Spaces:
Running
Running
Joshua Lochner
commited on
Commit
•
c415610
1
Parent(s):
85661b3
Use custom caching system for loading models
Browse files
app.py
CHANGED
@@ -40,11 +40,17 @@ st.set_page_config(
|
|
40 |
|
41 |
# Faster caching system for predictions (No need to hash)
|
42 |
@st.cache(persist=True, allow_output_mutation=True)
|
43 |
-
def
|
44 |
return {}
|
45 |
|
46 |
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
MODELS = {
|
50 |
'Small (293 MB)': {
|
@@ -100,23 +106,27 @@ def predict_function(model_id, model, tokenizer, segmentation_args, classifier_a
|
|
100 |
return prediction_cache[model_id][video_id]
|
101 |
|
102 |
|
103 |
-
@st.cache(persist=True, allow_output_mutation=True)
|
104 |
def load_predict(model_id):
|
105 |
model_info = MODELS[model_id]
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
-
|
113 |
-
model.to(device())
|
114 |
|
115 |
-
|
116 |
|
117 |
-
|
|
|
118 |
|
119 |
-
return
|
120 |
|
121 |
|
122 |
def main():
|
|
|
40 |
|
41 |
# Faster caching system for predictions (No need to hash)
|
42 |
@st.cache(persist=True, allow_output_mutation=True)
|
43 |
+
def create_prediction_cache():
|
44 |
return {}
|
45 |
|
46 |
|
47 |
+
@st.cache(persist=True, allow_output_mutation=True)
|
48 |
+
def create_function_cache():
|
49 |
+
return {}
|
50 |
+
|
51 |
+
|
52 |
+
prediction_cache = create_prediction_cache()
|
53 |
+
prediction_function_cache = create_function_cache()
|
54 |
|
55 |
MODELS = {
|
56 |
'Small (293 MB)': {
|
|
|
106 |
return prediction_cache[model_id][video_id]
|
107 |
|
108 |
|
|
|
109 |
def load_predict(model_id):
|
110 |
model_info = MODELS[model_id]
|
111 |
|
112 |
+
if model_id not in prediction_function_cache:
|
113 |
+
# Use default segmentation and classification arguments
|
114 |
+
evaluation_args = EvaluationArguments(model_path=model_info['repo_id'])
|
115 |
+
segmentation_args = SegmentationArguments()
|
116 |
+
classifier_args = ClassifierArguments()
|
117 |
+
|
118 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
119 |
+
evaluation_args.model_path)
|
120 |
+
model.to(device())
|
121 |
|
122 |
+
tokenizer = AutoTokenizer.from_pretrained(evaluation_args.model_path)
|
|
|
123 |
|
124 |
+
download_classifier(classifier_args)
|
125 |
|
126 |
+
prediction_function_cache[model_id] = partial(
|
127 |
+
predict_function, model_id, model, tokenizer, segmentation_args, classifier_args)
|
128 |
|
129 |
+
return prediction_function_cache[model_id]
|
130 |
|
131 |
|
132 |
def main():
|