ThomasSimonini HF staff commited on
Commit
d9ea03e
1 Parent(s): e7c8363

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -4
app.py CHANGED
@@ -23,6 +23,25 @@ def get_user_models(hf_username, env_tag, lib_tag):
23
  return user_model_ids
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def get_metadata(model_id):
27
  """
28
  Get model metadata (contains evaluation data)
@@ -208,8 +227,8 @@ def certification(hf_username):
208
  },
209
  {
210
  "unit": "Unit 8 PII",
211
- "env": "Vizdoom-Battle",
212
- "library": "cleanrl",
213
  "min_result": 100,
214
  "best_result": 0,
215
  "best_model_id": "",
@@ -217,8 +236,13 @@ def certification(hf_username):
217
  },
218
  ]
219
  for unit in results_certification:
220
- # Get user model
221
- user_models = get_user_models(hf_username, unit['env'], unit['library'])
 
 
 
 
 
222
 
223
  # Calculate the best result and get the best_model_id
224
  best_result, best_model_id = calculate_best_result(user_models)
 
23
  return user_model_ids
24
 
25
 
26
+ def get_user_sf_models(hf_username, env_tag, lib_tag):
27
+ api = HfApi()
28
+ models_sf = []
29
+ models = api.list_models(author=hf_username, filter=["reinforcement-learning", lib_tag])
30
+
31
+ user_model_ids = [x.modelId for x in models]
32
+
33
+ for model in user_model_ids:
34
+ meta = get_metadata(model)
35
+ if meta is None:
36
+ continue
37
+ result = meta["model-index"][0]["results"][0]["dataset"]["name"]
38
+ if result == env_tag:
39
+ models_sf.append(model)
40
+
41
+ user_sf_models_ids = [x.modelId for x in models_sf]
42
+ return user_sf_models_ids
43
+
44
+
45
  def get_metadata(model_id):
46
  """
47
  Get model metadata (contains evaluation data)
 
227
  },
228
  {
229
  "unit": "Unit 8 PII",
230
+ "env": "doom_health_gathering_supreme",
231
+ "library": "sample-factory",
232
  "min_result": 100,
233
  "best_result": 0,
234
  "best_model_id": "",
 
236
  },
237
  ]
238
  for unit in results_certification:
239
+ if unit["unit"] != "Unit 8 PII":
240
+ # Get user model
241
+ user_models = get_user_models(hf_username, unit['env'], unit['library'])
242
+ # For sample factory vizdoom we don't have env tag for now
243
+ else:
244
+ user_models = get_user_sf_models(hf_username, unit['env'], unit['library'])
245
+
246
 
247
  # Calculate the best result and get the best_model_id
248
  best_result, best_model_id = calculate_best_result(user_models)