Joshua Lochner commited on
Commit
f9281a4
1 Parent(s): 7e65770

Cache classifier after download

Browse files
Files changed (1) hide show
  1. app.py +23 -22
app.py CHANGED
@@ -17,7 +17,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src
17
 
18
  from predict import SegmentationArguments, ClassifierArguments, predict as pred, seconds_to_time # noqa
19
  from evaluate import EvaluationArguments
20
- from shared import device
21
 
22
  st.set_page_config(
23
  page_title='SponsorBlock ML',
@@ -38,10 +38,11 @@ st.set_page_config(
38
 
39
 
40
  # Faster caching system for predictions (No need to hash)
41
- @st.cache(allow_output_mutation=True)
42
  def persistdata():
43
  return {}
44
 
 
45
  prediction_cache = persistdata()
46
 
47
  MODELS = {
@@ -65,16 +66,27 @@ for m in MODELS:
65
  if m not in prediction_cache:
66
  prediction_cache[m] = {}
67
 
68
- CATGEGORY_OPTIONS = {
69
- 'SPONSOR': 'Sponsor',
70
- 'SELFPROMO': 'Self/unpaid promo',
71
- 'INTERACTION': 'Interaction reminder',
72
- }
73
 
74
  CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
75
 
76
 
77
- @st.cache(allow_output_mutation=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def load_predict(model_id):
79
  model_info = MODELS[model_id]
80
 
@@ -88,17 +100,7 @@ def load_predict(model_id):
88
 
89
  tokenizer = AutoTokenizer.from_pretrained(evaluation_args.model_path)
90
 
91
- # Save classifier and vectorizer
92
- hf_hub_download(repo_id=CLASSIFIER_PATH,
93
- filename=classifier_args.classifier_file,
94
- cache_dir=classifier_args.classifier_dir,
95
- force_filename=classifier_args.classifier_file,
96
- )
97
- hf_hub_download(repo_id=CLASSIFIER_PATH,
98
- filename=classifier_args.vectorizer_file,
99
- cache_dir=classifier_args.classifier_dir,
100
- force_filename=classifier_args.vectorizer_file,
101
- )
102
 
103
  def predict_function(video_id):
104
  if video_id not in prediction_cache[model_id]:
@@ -187,9 +189,8 @@ def main():
187
  json_data = quote(json.dumps(submit_segments))
188
  link = f'[Submit Segments](https://www.youtube.com/watch?v={video_id}#segments={json_data})'
189
  st.markdown(link, unsafe_allow_html=True)
190
- wikiLink = f'[Review generated segments before submitting!](https://wiki.sponsor.ajay.app/w/Automating_Submissions)'
191
- st.markdown(wikiLink, unsafe_allow_html=True)
192
-
193
 
194
  if __name__ == '__main__':
195
  main()
 
17
 
18
  from predict import SegmentationArguments, ClassifierArguments, predict as pred, seconds_to_time # noqa
19
  from evaluate import EvaluationArguments
20
+ from shared import device, CATGEGORY_OPTIONS
21
 
22
  st.set_page_config(
23
  page_title='SponsorBlock ML',
 
38
 
39
 
40
  # Faster caching system for predictions (No need to hash)
41
+ @st.cache(persist=True, allow_output_mutation=True)
42
  def persistdata():
43
  return {}
44
 
45
+
46
  prediction_cache = persistdata()
47
 
48
  MODELS = {
 
66
  if m not in prediction_cache:
67
  prediction_cache[m] = {}
68
 
 
 
 
 
 
69
 
70
  CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
71
 
72
 
73
+ @st.cache(persist=True, allow_output_mutation=True)
74
+ def download_classifier(classifier_args):
75
+ # Save classifier and vectorizer
76
+ hf_hub_download(repo_id=CLASSIFIER_PATH,
77
+ filename=classifier_args.classifier_file,
78
+ cache_dir=classifier_args.classifier_dir,
79
+ force_filename=classifier_args.classifier_file,
80
+ )
81
+ hf_hub_download(repo_id=CLASSIFIER_PATH,
82
+ filename=classifier_args.vectorizer_file,
83
+ cache_dir=classifier_args.classifier_dir,
84
+ force_filename=classifier_args.vectorizer_file,
85
+ )
86
+ return True
87
+
88
+
89
+ @st.cache(persist=True, allow_output_mutation=True)
90
  def load_predict(model_id):
91
  model_info = MODELS[model_id]
92
 
 
100
 
101
  tokenizer = AutoTokenizer.from_pretrained(evaluation_args.model_path)
102
 
103
+ download_classifier(classifier_args)
 
 
 
 
 
 
 
 
 
 
104
 
105
  def predict_function(video_id):
106
  if video_id not in prediction_cache[model_id]:
 
189
  json_data = quote(json.dumps(submit_segments))
190
  link = f'[Submit Segments](https://www.youtube.com/watch?v={video_id}#segments={json_data})'
191
  st.markdown(link, unsafe_allow_html=True)
192
+ wiki_link = '[Review generated segments before submitting!](https://wiki.sponsor.ajay.app/w/Automating_Submissions)'
193
+ st.markdown(wiki_link, unsafe_allow_html=True)
 
194
 
195
  if __name__ == '__main__':
196
  main()