Joshua Lochner committed on
Commit
c415610
1 Parent(s): 85661b3

Use custom caching system for loading models

Browse files
Files changed (1) hide show
  1. app.py +22 -12
app.py CHANGED
@@ -40,11 +40,17 @@ st.set_page_config(
40
 
41
  # Faster caching system for predictions (No need to hash)
42
  @st.cache(persist=True, allow_output_mutation=True)
43
- def persistdata():
44
  return {}
45
 
46
 
47
- prediction_cache = persistdata()
 
 
 
 
 
 
48
 
49
  MODELS = {
50
  'Small (293 MB)': {
@@ -100,23 +106,27 @@ def predict_function(model_id, model, tokenizer, segmentation_args, classifier_a
100
  return prediction_cache[model_id][video_id]
101
 
102
 
103
- @st.cache(persist=True, allow_output_mutation=True)
104
  def load_predict(model_id):
105
  model_info = MODELS[model_id]
106
 
107
- # Use default segmentation and classification arguments
108
- evaluation_args = EvaluationArguments(model_path=model_info['repo_id'])
109
- segmentation_args = SegmentationArguments()
110
- classifier_args = ClassifierArguments()
 
 
 
 
 
111
 
112
- model = AutoModelForSeq2SeqLM.from_pretrained(evaluation_args.model_path)
113
- model.to(device())
114
 
115
- tokenizer = AutoTokenizer.from_pretrained(evaluation_args.model_path)
116
 
117
- download_classifier(classifier_args)
 
118
 
119
- return partial(predict_function, model_id, model, tokenizer, segmentation_args, classifier_args)
120
 
121
 
122
  def main():
 
40
 
41
  # Faster caching system for predictions (No need to hash)
42
  @st.cache(persist=True, allow_output_mutation=True)
43
+ def create_prediction_cache():
44
  return {}
45
 
46
 
47
+ @st.cache(persist=True, allow_output_mutation=True)
48
+ def create_function_cache():
49
+ return {}
50
+
51
+
52
+ prediction_cache = create_prediction_cache()
53
+ prediction_function_cache = create_function_cache()
54
 
55
  MODELS = {
56
  'Small (293 MB)': {
 
106
  return prediction_cache[model_id][video_id]
107
 
108
 
 
109
  def load_predict(model_id):
110
  model_info = MODELS[model_id]
111
 
112
+ if model_id not in prediction_function_cache:
113
+ # Use default segmentation and classification arguments
114
+ evaluation_args = EvaluationArguments(model_path=model_info['repo_id'])
115
+ segmentation_args = SegmentationArguments()
116
+ classifier_args = ClassifierArguments()
117
+
118
+ model = AutoModelForSeq2SeqLM.from_pretrained(
119
+ evaluation_args.model_path)
120
+ model.to(device())
121
 
122
+ tokenizer = AutoTokenizer.from_pretrained(evaluation_args.model_path)
 
123
 
124
+ download_classifier(classifier_args)
125
 
126
+ prediction_function_cache[model_id] = partial(
127
+ predict_function, model_id, model, tokenizer, segmentation_args, classifier_args)
128
 
129
+ return prediction_function_cache[model_id]
130
 
131
 
132
  def main():