Pendrokar committed
Commit e6bb72e
1 Parent(s): ac04cbe

multithread only when one zerogpu space; cache sample even if one errors out
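
In outline: when both selected Spaces run on ZeroGPU hardware, the two model requests are now made one at a time, with each sample cached as soon as its request finishes, so an error from one model no longer prevents the other's sample from being cached; in every other case the two requests still run on parallel threads. Below is a minimal sketch of that dispatch pattern, with `is_zero_gpu`, `query_model`, and `cache_sample` as hypothetical stand-ins for the app's `HF_SPACES` lookup, `predict_and_update_result`, and `_cache_sample`:

import threading
import time

def dispatch(text, model_a, model_b, is_zero_gpu, query_model, cache_sample):
    """Query two TTS models, serializing when both run on ZeroGPU hardware."""
    results = {}

    if is_zero_gpu(model_a) and is_zero_gpu(model_b):
        # ZeroGPU Spaces share GPU quota, so query them one at a time and
        # cache each sample as soon as its request finishes; if the second
        # request errors out, the first sample is already cached.
        for model in (model_a, model_b):
            query_model(text, model, results)
            cache_sample(text, model, results)
    else:
        # dedicated hardware: query both models on parallel threads
        thread_a = threading.Thread(target=query_model, args=(text, model_a, results))
        thread_b = threading.Thread(target=query_model, args=(text, model_b, results))
        thread_a.start()
        time.sleep(3)        # stagger the requests to the hf.space domain
        thread_b.start()
        thread_a.join(120)   # give each request up to two minutes
        thread_b.join(120)
        # cache whatever finished, even if one model errored out
        for model in (model_a, model_b):
            cache_sample(text, model, results)

    return results

The committed code reads the flag from HF_SPACES[model]['is_zero_gpu_space'], as the diff below shows.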

Files changed (1)
  app.py  +55 -31
app.py CHANGED
@@ -1140,6 +1140,28 @@ def synthandreturn(text, request: gr.Request):
             pass
 
         return inputs
+
+    def _cache_sample(text, model):
+        # skip caching if not hardcoded sentence
+        if (text not in sents):
+            return False
+
+        already_cached = False
+        # check if already cached
+        for cached_sample in cached_samples:
+            # TODO:replace cached with newer version
+            if (cached_sample.transcript == text and cached_sample.modelName == model):
+                already_cached = True
+                return True
+
+        if (already_cached):
+            return False
+
+        try:
+            cached_samples.append(Sample(results[model], text, model))
+        except:
+            print('Error when trying to cache sample')
+        return False
 
     mdl1k = mdl1
     mdl2k = mdl2
@@ -1148,15 +1170,39 @@ def synthandreturn(text, request: gr.Request):
     if mdl2 in AVAILABLE_MODELS.keys(): mdl2k=AVAILABLE_MODELS[mdl2]
     results = {}
     print(f"Sending models {mdl1k} and {mdl2k} to API")
-    # thread1 = threading.Thread(target=predict_and_update_result, args=(text, mdl1k, results, request))
-    # thread2 = threading.Thread(target=predict_and_update_result, args=(text, mdl2k, results, request))
-
-    # thread1.start()
-    # thread2.start()
-    # thread1.join(180)
-    # thread2.join(180)
-    predict_and_update_result(text, mdl1k, results, request)
-    predict_and_update_result(text, mdl2k, results, request)
+
+    # do not use multithreading when both spaces are ZeroGPU type
+    if (
+        # exists
+        'is_zero_gpu_space' in HF_SPACES[mdl1]
+        # is True
+        and HF_SPACES[mdl1]['is_zero_gpu_space']
+        and 'is_zero_gpu_space' in HF_SPACES[mdl2]
+        and HF_SPACES[mdl2]['is_zero_gpu_space']
+    ):
+        # run Zero-GPU spaces one at a time
+        predict_and_update_result(text, mdl1k, results, request)
+        _cache_sample(text, mdl1k)
+
+        predict_and_update_result(text, mdl2k, results, request)
+        _cache_sample(text, mdl2k)
+    else:
+        # use multithreading
+        thread1 = threading.Thread(target=predict_and_update_result, args=(text, mdl1k, results, request))
+        thread2 = threading.Thread(target=predict_and_update_result, args=(text, mdl2k, results, request))
+
+        thread1.start()
+        # wait 3 seconds to calm hf.space domain
+        time.sleep(3)
+        thread2.start()
+        # timeout in 2 minutes
+        thread1.join(120)
+        thread2.join(120)
+
+        # cache the result
+        for model in [mdl1k, mdl2k]:
+            _cache_sample(text, model)
+
     #debug
     # print(results)
     # print(list(results.keys())[0])
@@ -1168,28 +1214,6 @@ def synthandreturn(text, request: gr.Request):
     #debug
     # outputs = [text, btn, r2, model1, model2, aud1, aud2, abetter, bbetter, prevmodel1, prevmodel2, nxtroundbtn]
 
-    # cache the result
-    for model in [mdl1k, mdl2k]:
-        # skip caching if not hardcoded sentence
-        if (text not in sents):
-            break
-
-        already_cached = False
-        # check if already cached
-        for cached_sample in cached_samples:
-            # TODO:replace cached
-            if (cached_sample.transcript == text and cached_sample.modelName == model):
-                already_cached = True
-                break
-
-        if (already_cached):
-            continue
-
-        try:
-            cached_samples.append(Sample(results[model], text, model))
-        except:
-            pass
-
     # all_pairs = generate_matching_pairs(cached_samples)
 
     print(f"Retrieving models {mdl1k} and {mdl2k} from API")