Commit d68b7d5
Parent: fe69431
add zero shot classification task (#45)
* add zero shot classification task
* fix default metric list for zero shot classification
* Update enum
* Rename to text_zero_shot_classification
* Merge conflict
* Incorporate Lewis refactor
* Adhere to Hub naming conventions
* Incorporate Autotrain changes for deprecated data endpoint
* Sagemaker update
* Sagemaker changes
Co-authored-by: mathemakitten <[email protected]>
Co-authored-by: Lewis Tunstall <[email protected]>
Co-authored-by: helen <[email protected]>
app.py
CHANGED
@@ -43,6 +43,7 @@ TASK_TO_ID = {
     "extractive_question_answering": 5,
     "translation": 6,
     "summarization": 8,
+    "text_zero_shot_classification": 23,
 }
 
 TASK_TO_DEFAULT_METRICS = {
@@ -65,6 +66,7 @@ TASK_TO_DEFAULT_METRICS = {
         "recall",
         "accuracy",
     ],
+    "text_zero_shot_classification": ["accuracy", "loss"],
 }
 
 AUTOTRAIN_TASK_TO_LANG = {
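The new task defaults to `accuracy` and `loss`. As a rough sketch (not part of this commit; the rows and predictions below are hypothetical), a zero-shot example pairs a text with its candidate classes and the index of the correct one, so accuracy reduces to plain index matching:

```python
from evaluate import load  # Hugging Face `evaluate` library

# Hypothetical zero-shot rows: `classes` lists the candidate labels,
# `target` is the index of the correct one.
rows = [
    {"text": "The movie was a delight.", "classes": ["negative", "positive"], "target": 1},
    {"text": "Terrible pacing, worse acting.", "classes": ["negative", "positive"], "target": 0},
]
predictions = [1, 1]  # class indices predicted by some zero-shot model

accuracy = load("accuracy")
print(accuracy.compute(predictions=predictions, references=[r["target"] for r in rows]))
# {'accuracy': 0.5}
```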
@@ -73,6 +75,8 @@ AUTOTRAIN_TASK_TO_LANG = {
     "image_multi_class_classification": "unk",
 }
 
+AUTOTRAIN_MACHINE = {"text_zero_shot_classification": "r5.16x"}
+
 
 SUPPORTED_TASKS = list(TASK_TO_ID.keys())
 
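`AUTOTRAIN_MACHINE` routes the new task onto a large `r5.16x` instance; every other task falls through to the `"p3"` default chosen further down. The same lookup can be written with `dict.get`, which is what the later ternary amounts to (standalone sketch; `selected_task` stands in for the UI selection):

```python
AUTOTRAIN_MACHINE = {"text_zero_shot_classification": "r5.16x"}

selected_task = "text_zero_shot_classification"  # stand-in for the UI selection
print(AUTOTRAIN_MACHINE.get(selected_task, "p3"))   # -> r5.16x
print(AUTOTRAIN_MACHINE.get("translation", "p3"))   # -> p3
```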
@@ -273,6 +277,45 @@ with st.expander("Advanced configuration"):
             col_mapping[text_col] = "text"
             col_mapping[target_col] = "target"
 
+    elif selected_task == "text_zero_shot_classification":
+        with col1:
+            st.markdown("`text` column")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.markdown("`classes` column")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.text("")
+            st.markdown("`target` column")
+        with col2:
+            text_col = st.selectbox(
+                "This column should contain the text to be classified",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
+                if config_metadata is not None
+                else 0,
+            )
+            classes_col = st.selectbox(
+                "This column should contain the classes associated with the text",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "classes"))
+                if config_metadata is not None
+                else 0,
+            )
+            target_col = st.selectbox(
+                "This column should contain the index of the correct class",
+                col_names,
+                index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
+                if config_metadata is not None
+                else 0,
+            )
+            col_mapping[text_col] = "text"
+            col_mapping[classes_col] = "classes"
+            col_mapping[target_col] = "target"
+
     if selected_task in ["natural_language_inference"]:
         config_metadata = get_config_metadata(selected_config, metadata)
         with col1:
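The branch mirrors the existing column-mapping UI: three selectboxes on the right, labels on the left, and the chosen dataset columns inverted into `col_mapping`, keyed by source column. For a hypothetical dataset with differently named columns, the resulting mapping comes out as:

```python
# Hypothetical columns picked in the three selectboxes above
text_col, classes_col, target_col = "sentence", "candidate_labels", "label_idx"

col_mapping = {}
col_mapping[text_col] = "text"
col_mapping[classes_col] = "classes"
col_mapping[target_col] = "target"

print(col_mapping)
# {'sentence': 'text', 'candidate_labels': 'classes', 'label_idx': 'target'}
```

Since the dict is keyed by the source column, picking the same column for two roles would silently collapse one entry; the last assignment wins.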
@@ -533,8 +576,10 @@ with st.form(key="form"):
             else "en",
             "max_models": 5,
             "instance": {
-                "provider": "
-                "instance_type":
+                "provider": "sagemaker",
+                "instance_type": AUTOTRAIN_MACHINE[selected_task]
+                if selected_task in AUTOTRAIN_MACHINE.keys()
+                else "p3",
                 "max_runtime_seconds": 172800,
                 "num_instances": 1,
                 "disk_size_gb": 150,
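The project payload now pins the provider to SageMaker and picks the instance type per task. The new `instance` block in isolation (same logic as the diff; `in AUTOTRAIN_MACHINE` behaves identically to the `in AUTOTRAIN_MACHINE.keys()` membership test spelled out above):

```python
AUTOTRAIN_MACHINE = {"text_zero_shot_classification": "r5.16x"}
selected_task = "text_zero_shot_classification"  # stand-in for the UI selection

instance = {
    "provider": "sagemaker",
    "instance_type": AUTOTRAIN_MACHINE[selected_task]
    if selected_task in AUTOTRAIN_MACHINE  # same as `.keys()` membership
    else "p3",
    "max_runtime_seconds": 172800,  # 48 hours
    "num_instances": 1,
    "disk_size_gb": 150,
}
print(instance["instance_type"])  # -> r5.16x
```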
@@ -560,17 +605,15 @@ with st.form(key="form"):
             "split": 4, # use "auto" split choice in AutoTrain
             "col_mapping": col_mapping,
             "load_config": {"max_size_bytes": 0, "shuffle": False},
+            "dataset_id": selected_dataset,
+            "dataset_config": selected_config,
+            "dataset_split": selected_split,
         }
         data_json_resp = http_post(
-            path=f"/projects/{project_json_resp['id']}/data/
+            path=f"/projects/{project_json_resp['id']}/data/dataset",
             payload=data_payload,
             token=HF_TOKEN,
             domain=AUTOTRAIN_BACKEND_API,
-            params={
-                "type": "dataset",
-                "config_name": selected_config,
-                "split_name": selected_split,
-            },
         ).json()
         print(f"INFO -- Dataset creation response: {data_json_resp}")
         if data_json_resp["download_status"] == 1:
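Per the commit message, this moves off AutoTrain's deprecated data endpoint: the dataset coordinates leave the query string (`type`, `config_name`, `split_name`) and travel in the JSON payload instead (`dataset_id`, `dataset_config`, `dataset_split`), and the upload path becomes `/data/dataset`. A minimal sketch of the new request shape, assuming `http_post` is a thin wrapper over `requests.post` (the URL scheme and auth header here are assumptions, not taken from this repo):

```python
import requests

def register_dataset(domain: str, project_id: int, token: str, data_payload: dict) -> dict:
    """Sketch: POST the dataset coordinates to the new /data/dataset endpoint."""
    resp = requests.post(
        f"{domain}/projects/{project_id}/data/dataset",
        json=data_payload,  # carries dataset_id, dataset_config, dataset_split
        headers={"Authorization": f"Bearer {token}"},
    )
    resp.raise_for_status()
    return resp.json()
```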
utils.py
CHANGED
@@ -19,6 +19,7 @@ AUTOTRAIN_TASK_TO_HUB_TASK = {
     "summarization": "summarization",
     "image_binary_classification": "image-classification",
     "image_multi_class_classification": "image-classification",
+    "text_zero_shot_classification": "text-generation",
 }
 
 
@@ -82,7 +83,8 @@ def get_compatible_models(task: str, dataset_ids: List[str]) -> List[str]:
     """
     compatible_models = []
     # Allow any summarization model to be used for summarization tasks
-    if task == "summarization":
+    # and allow any text-generation model to be used for text_zero_shot_classification
+    if task in ("summarization", "text_zero_shot_classification"):
         model_filter = ModelFilter(
             task=AUTOTRAIN_TASK_TO_HUB_TASK[task],
             library=["transformers", "pytorch"],
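With the condition widened, zero-shot classification reuses the summarization path: any Hub checkpoint whose pipeline tag matches the mapped task (here `text-generation`, per the `AUTOTRAIN_TASK_TO_HUB_TASK` change above) passes the filter. A standalone sketch of the listing call, mirroring the `ModelFilter` usage in the diff (the `limit` argument is only there to keep the query cheap):

```python
from huggingface_hub import HfApi, ModelFilter

api = HfApi()
model_filter = ModelFilter(task="text-generation", library=["transformers", "pytorch"])
# Fetch a handful of generative checkpoints that would qualify for
# text_zero_shot_classification under the new condition.
models = api.list_models(filter=model_filter, limit=5)
print([m.modelId for m in models])
```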
@@ -195,9 +197,11 @@ def create_autotrain_project_name(dataset_id: str, dataset_config: str) -> str:
     """Creates an AutoTrain project name for the given dataset ID."""
     # Project names cannot have "/", so we need to format community datasets accordingly
     dataset_id_formatted = dataset_id.replace("/", "__")
-
-
-
+    dataset_config_formatted = dataset_config.replace("--", "__")
+    # Project names need to be unique, so we append a random string to guarantee this while adhering to naming rules
+    basename = f"eval-{dataset_id_formatted}-{dataset_config_formatted}"
+    basename = basename[:60] if len(basename) > 60 else basename  # Hub naming limitation
+    return f"{basename}-{str(uuid.uuid4())[:6]}"
 
 
 def get_config_metadata(config: str, metadata: List[Dict] = None) -> Union[Dict, None]:
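The helper now folds the config into the name, truncates to 60 characters for the Hub's repo-name limit, and appends a 6-character UUID slice for uniqueness. A condensed, self-contained replica for illustration (the dataset ID is hypothetical and the suffix is random):

```python
import uuid

def create_autotrain_project_name(dataset_id: str, dataset_config: str) -> str:
    # Condensed replica of the updated helper, for illustration only.
    dataset_id_formatted = dataset_id.replace("/", "__")
    dataset_config_formatted = dataset_config.replace("--", "__")
    basename = f"eval-{dataset_id_formatted}-{dataset_config_formatted}"
    return f"{basename[:60]}-{str(uuid.uuid4())[:6]}"

print(create_autotrain_project_name("someuser/some_dataset", "plain_text"))
# e.g. 'eval-someuser__some_dataset-plain_text-1a2b3c'
```

Note that `basename[:60]` already returns the whole string when it is shorter than 60 characters, so the `if len(basename) > 60` guard in the diff is redundant, though harmless.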
|