Spaces:
Runtime error
Runtime error
Add comment
Browse files
utils.py
CHANGED
@@ -73,11 +73,13 @@ def get_compatible_models(task: str, dataset_ids: List[str]) -> List[str]:
|
|
73 |
Returns:
|
74 |
A list of model IDs, sorted alphabetically.
|
75 |
"""
|
76 |
-
# TODO: relax filter on PyTorch models if TensorFlow supported in AutoTrain
|
77 |
compatible_models = []
|
|
|
|
|
78 |
if task == "extractive_question_answering":
|
79 |
dataset_ids.extend(["squad", "squad_v2"])
|
80 |
|
|
|
81 |
for dataset_id in dataset_ids:
|
82 |
model_filter = ModelFilter(
|
83 |
task=AUTOTRAIN_TASK_TO_HUB_TASK[task],
|
|
|
73 |
Returns:
|
74 |
A list of model IDs, sorted alphabetically.
|
75 |
"""
|
|
|
76 |
compatible_models = []
|
77 |
+
# Include models trained on SQuAD datasets, since these can be evaluated on
|
78 |
+
# other SQuAD-like datasets
|
79 |
if task == "extractive_question_answering":
|
80 |
dataset_ids.extend(["squad", "squad_v2"])
|
81 |
|
82 |
+
# TODO: relax filter on PyTorch models if TensorFlow supported in AutoTrain
|
83 |
for dataset_id in dataset_ids:
|
84 |
model_filter = ModelFilter(
|
85 |
task=AUTOTRAIN_TASK_TO_HUB_TASK[task],
|