lewtun HF staff commited on
Commit
eef70c0
1 Parent(s): 624eaa5
Files changed (3) hide show
  1. app.py +1 -6
  2. evaluation.py +1 -1
  3. utils.py +4 -1
app.py CHANGED
@@ -346,12 +346,7 @@ with st.expander("Advanced configuration"):
346
  )
347
 
348
  with st.form(key="form"):
349
- # Grab all models fine-tuned on SQuAD for question answering tasks
350
- if selected_task == "extractive_question_answering":
351
- compatible_models = get_compatible_models(selected_task, [selected_dataset, "squad", "squad_v2"])
352
- else:
353
- compatible_models = get_compatible_models(selected_task, [selected_dataset])
354
-
355
  selected_models = st.multiselect(
356
  "Select the models you wish to evaluate",
357
  compatible_models,
 
346
  )
347
 
348
  with st.form(key="form"):
349
+ compatible_models = get_compatible_models(selected_task, [selected_dataset])
 
 
 
 
 
350
  selected_models = st.multiselect(
351
  "Select the models you wish to evaluate",
352
  compatible_models,
evaluation.py CHANGED
@@ -43,7 +43,7 @@ def filter_evaluated_models(models, task, dataset_name, dataset_config, dataset_
43
  )
44
  candidate_id = hash(evaluation_info)
45
  if candidate_id in evaluation_ids:
46
- st.info(f"Model {model} has already been evaluated on this configuration. Skipping evaluation...")
47
  models.pop(idx)
48
 
49
  return models
 
43
  )
44
  candidate_id = hash(evaluation_info)
45
  if candidate_id in evaluation_ids:
46
+ st.info(f"Model `{model}` has already been evaluated on this configuration. Skipping evaluation...")
47
  models.pop(idx)
48
 
49
  return models
utils.py CHANGED
@@ -75,6 +75,9 @@ def get_compatible_models(task: str, dataset_ids: List[str]) -> List[str]:
75
  """
76
  # TODO: relax filter on PyTorch models if TensorFlow supported in AutoTrain
77
  compatible_models = []
 
 
 
78
  for dataset_id in dataset_ids:
79
  model_filter = ModelFilter(
80
  task=AUTOTRAIN_TASK_TO_HUB_TASK[task],
@@ -82,7 +85,7 @@ def get_compatible_models(task: str, dataset_ids: List[str]) -> List[str]:
82
  library=["transformers", "pytorch"],
83
  )
84
  compatible_models.extend(HfApi().list_models(filter=model_filter))
85
- return sorted([model.modelId for model in compatible_models])
86
 
87
 
88
  def get_key(col_mapping, val):
 
75
  """
76
  # TODO: relax filter on PyTorch models if TensorFlow supported in AutoTrain
77
  compatible_models = []
78
+ if task == "extractive_question_answering":
79
+ dataset_ids.extend(["squad", "squad_v2"])
80
+
81
  for dataset_id in dataset_ids:
82
  model_filter = ModelFilter(
83
  task=AUTOTRAIN_TASK_TO_HUB_TASK[task],
 
85
  library=["transformers", "pytorch"],
86
  )
87
  compatible_models.extend(HfApi().list_models(filter=model_filter))
88
+ return set(sorted([model.modelId for model in compatible_models]))
89
 
90
 
91
  def get_key(col_mapping, val):