Commit
•
d9ea03e
1
Parent(s):
e7c8363
Update app.py
Browse files
app.py
CHANGED
@@ -23,6 +23,25 @@ def get_user_models(hf_username, env_tag, lib_tag):
|
|
23 |
return user_model_ids
|
24 |
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
def get_metadata(model_id):
|
27 |
"""
|
28 |
Get model metadata (contains evaluation data)
|
@@ -208,8 +227,8 @@ def certification(hf_username):
|
|
208 |
},
|
209 |
{
|
210 |
"unit": "Unit 8 PII",
|
211 |
-
"env": "
|
212 |
-
"library": "
|
213 |
"min_result": 100,
|
214 |
"best_result": 0,
|
215 |
"best_model_id": "",
|
@@ -217,8 +236,13 @@ def certification(hf_username):
|
|
217 |
},
|
218 |
]
|
219 |
for unit in results_certification:
|
220 |
-
|
221 |
-
|
|
|
|
|
|
|
|
|
|
|
222 |
|
223 |
# Calculate the best result and get the best_model_id
|
224 |
best_result, best_model_id = calculate_best_result(user_models)
|
|
|
23 |
return user_model_ids
|
24 |
|
25 |
|
26 |
+
def get_user_sf_models(hf_username, env_tag, lib_tag):
|
27 |
+
api = HfApi()
|
28 |
+
models_sf = []
|
29 |
+
models = api.list_models(author=hf_username, filter=["reinforcement-learning", lib_tag])
|
30 |
+
|
31 |
+
user_model_ids = [x.modelId for x in models]
|
32 |
+
|
33 |
+
for model in user_model_ids:
|
34 |
+
meta = get_metadata(model)
|
35 |
+
if meta is None:
|
36 |
+
continue
|
37 |
+
result = meta["model-index"][0]["results"][0]["dataset"]["name"]
|
38 |
+
if result == env_tag:
|
39 |
+
models_sf.append(model)
|
40 |
+
|
41 |
+
user_sf_models_ids = [x.modelId for x in models_sf]
|
42 |
+
return user_sf_models_ids
|
43 |
+
|
44 |
+
|
45 |
def get_metadata(model_id):
|
46 |
"""
|
47 |
Get model metadata (contains evaluation data)
|
|
|
227 |
},
|
228 |
{
|
229 |
"unit": "Unit 8 PII",
|
230 |
+
"env": "doom_health_gathering_supreme",
|
231 |
+
"library": "sample-factory",
|
232 |
"min_result": 100,
|
233 |
"best_result": 0,
|
234 |
"best_model_id": "",
|
|
|
236 |
},
|
237 |
]
|
238 |
for unit in results_certification:
|
239 |
+
if unit["unit"] != "Unit 8 PII":
|
240 |
+
# Get user model
|
241 |
+
user_models = get_user_models(hf_username, unit['env'], unit['library'])
|
242 |
+
# For sample factory vizdoom we don't have env tag for now
|
243 |
+
else:
|
244 |
+
user_models = get_user_sf_models(hf_username, unit['env'], unit['library'])
|
245 |
+
|
246 |
|
247 |
# Calculate the best result and get the best_model_id
|
248 |
best_result, best_model_id = calculate_best_result(user_models)
|