TristanThrush mathemakitten lewtun HF staff helen commited on
Commit
d68b7d5
1 Parent(s): fe69431

add zero shot classification task (#45)

Browse files

* add zero shot classification task

* fix default metric list for zero shot classification

* Update enum

* Rename to text_zero_shot_classification

* Merge conflict

* Lewis refactor incorporate

* Adhere to Hub naming conventions

* Incorporate Autotrain changes for deprecated data endpoint

* Sagemaker update

* Sagemaker changes

Co-authored-by: mathemakitten <[email protected]>
Co-authored-by: Lewis Tunstall <[email protected]>
Co-authored-by: helen <[email protected]>

Files changed (2) hide show
  1. app.py +51 -8
  2. utils.py +8 -4
app.py CHANGED
@@ -43,6 +43,7 @@ TASK_TO_ID = {
43
  "extractive_question_answering": 5,
44
  "translation": 6,
45
  "summarization": 8,
 
46
  }
47
 
48
  TASK_TO_DEFAULT_METRICS = {
@@ -65,6 +66,7 @@ TASK_TO_DEFAULT_METRICS = {
65
  "recall",
66
  "accuracy",
67
  ],
 
68
  }
69
 
70
  AUTOTRAIN_TASK_TO_LANG = {
@@ -73,6 +75,8 @@ AUTOTRAIN_TASK_TO_LANG = {
73
  "image_multi_class_classification": "unk",
74
  }
75
 
 
 
76
 
77
  SUPPORTED_TASKS = list(TASK_TO_ID.keys())
78
 
@@ -273,6 +277,45 @@ with st.expander("Advanced configuration"):
273
  col_mapping[text_col] = "text"
274
  col_mapping[target_col] = "target"
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  if selected_task in ["natural_language_inference"]:
277
  config_metadata = get_config_metadata(selected_config, metadata)
278
  with col1:
@@ -533,8 +576,10 @@ with st.form(key="form"):
533
  else "en",
534
  "max_models": 5,
535
  "instance": {
536
- "provider": "aws",
537
- "instance_type": "ml.g4dn.4xlarge",
 
 
538
  "max_runtime_seconds": 172800,
539
  "num_instances": 1,
540
  "disk_size_gb": 150,
@@ -560,17 +605,15 @@ with st.form(key="form"):
560
  "split": 4, # use "auto" split choice in AutoTrain
561
  "col_mapping": col_mapping,
562
  "load_config": {"max_size_bytes": 0, "shuffle": False},
 
 
 
563
  }
564
  data_json_resp = http_post(
565
- path=f"/projects/{project_json_resp['id']}/data/{selected_dataset}",
566
  payload=data_payload,
567
  token=HF_TOKEN,
568
  domain=AUTOTRAIN_BACKEND_API,
569
- params={
570
- "type": "dataset",
571
- "config_name": selected_config,
572
- "split_name": selected_split,
573
- },
574
  ).json()
575
  print(f"INFO -- Dataset creation response: {data_json_resp}")
576
  if data_json_resp["download_status"] == 1:
 
43
  "extractive_question_answering": 5,
44
  "translation": 6,
45
  "summarization": 8,
46
+ "text_zero_shot_classification": 23,
47
  }
48
 
49
  TASK_TO_DEFAULT_METRICS = {
 
66
  "recall",
67
  "accuracy",
68
  ],
69
+ "text_zero_shot_classification": ["accuracy", "loss"],
70
  }
71
 
72
  AUTOTRAIN_TASK_TO_LANG = {
 
75
  "image_multi_class_classification": "unk",
76
  }
77
 
78
+ AUTOTRAIN_MACHINE = {"text_zero_shot_classification": "r5.16x"}
79
+
80
 
81
  SUPPORTED_TASKS = list(TASK_TO_ID.keys())
82
 
 
277
  col_mapping[text_col] = "text"
278
  col_mapping[target_col] = "target"
279
 
280
+ elif selected_task == "text_zero_shot_classification":
281
+ with col1:
282
+ st.markdown("`text` column")
283
+ st.text("")
284
+ st.text("")
285
+ st.text("")
286
+ st.text("")
287
+ st.markdown("`classes` column")
288
+ st.text("")
289
+ st.text("")
290
+ st.text("")
291
+ st.text("")
292
+ st.markdown("`target` column")
293
+ with col2:
294
+ text_col = st.selectbox(
295
+ "This column should contain the text to be classified",
296
+ col_names,
297
+ index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
298
+ if config_metadata is not None
299
+ else 0,
300
+ )
301
+ classes_col = st.selectbox(
302
+ "This column should contain the classes associated with the text",
303
+ col_names,
304
+ index=col_names.index(get_key(config_metadata["col_mapping"], "classes"))
305
+ if config_metadata is not None
306
+ else 0,
307
+ )
308
+ target_col = st.selectbox(
309
+ "This column should contain the index of the correct class",
310
+ col_names,
311
+ index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
312
+ if config_metadata is not None
313
+ else 0,
314
+ )
315
+ col_mapping[text_col] = "text"
316
+ col_mapping[classes_col] = "classes"
317
+ col_mapping[target_col] = "target"
318
+
319
  if selected_task in ["natural_language_inference"]:
320
  config_metadata = get_config_metadata(selected_config, metadata)
321
  with col1:
 
576
  else "en",
577
  "max_models": 5,
578
  "instance": {
579
+ "provider": "sagemaker",
580
+ "instance_type": AUTOTRAIN_MACHINE[selected_task]
581
+ if selected_task in AUTOTRAIN_MACHINE.keys()
582
+ else "p3",
583
  "max_runtime_seconds": 172800,
584
  "num_instances": 1,
585
  "disk_size_gb": 150,
 
605
  "split": 4, # use "auto" split choice in AutoTrain
606
  "col_mapping": col_mapping,
607
  "load_config": {"max_size_bytes": 0, "shuffle": False},
608
+ "dataset_id": selected_dataset,
609
+ "dataset_config": selected_config,
610
+ "dataset_split": selected_split,
611
  }
612
  data_json_resp = http_post(
613
+ path=f"/projects/{project_json_resp['id']}/data/dataset",
614
  payload=data_payload,
615
  token=HF_TOKEN,
616
  domain=AUTOTRAIN_BACKEND_API,
 
 
 
 
 
617
  ).json()
618
  print(f"INFO -- Dataset creation response: {data_json_resp}")
619
  if data_json_resp["download_status"] == 1:
utils.py CHANGED
@@ -19,6 +19,7 @@ AUTOTRAIN_TASK_TO_HUB_TASK = {
19
  "summarization": "summarization",
20
  "image_binary_classification": "image-classification",
21
  "image_multi_class_classification": "image-classification",
 
22
  }
23
 
24
 
@@ -82,7 +83,8 @@ def get_compatible_models(task: str, dataset_ids: List[str]) -> List[str]:
82
  """
83
  compatible_models = []
84
  # Allow any summarization model to be used for summarization tasks
85
- if task == "summarization":
 
86
  model_filter = ModelFilter(
87
  task=AUTOTRAIN_TASK_TO_HUB_TASK[task],
88
  library=["transformers", "pytorch"],
@@ -195,9 +197,11 @@ def create_autotrain_project_name(dataset_id: str, dataset_config: str) -> str:
195
  """Creates an AutoTrain project name for the given dataset ID."""
196
  # Project names cannot have "/", so we need to format community datasets accordingly
197
  dataset_id_formatted = dataset_id.replace("/", "__")
198
- # Project names need to be unique, so we append a random string to guarantee this
199
- project_id = str(uuid.uuid4())[:6]
200
- return f"eval-{dataset_id_formatted}-{dataset_config}-{project_id}"
 
 
201
 
202
 
203
  def get_config_metadata(config: str, metadata: List[Dict] = None) -> Union[Dict, None]:
 
19
  "summarization": "summarization",
20
  "image_binary_classification": "image-classification",
21
  "image_multi_class_classification": "image-classification",
22
+ "text_zero_shot_classification": "text-generation",
23
  }
24
 
25
 
 
83
  """
84
  compatible_models = []
85
  # Allow any summarization model to be used for summarization tasks
86
+ # and allow any text-generation model to be used for text_zero_shot_classification
87
+ if task in ("summarization", "text_zero_shot_classification"):
88
  model_filter = ModelFilter(
89
  task=AUTOTRAIN_TASK_TO_HUB_TASK[task],
90
  library=["transformers", "pytorch"],
 
197
  """Creates an AutoTrain project name for the given dataset ID."""
198
  # Project names cannot have "/", so we need to format community datasets accordingly
199
  dataset_id_formatted = dataset_id.replace("/", "__")
200
+ dataset_config_formatted = dataset_config.replace("--", "__")
201
+ # Project names need to be unique, so we append a random string to guarantee this while adhering to naming rules
202
+ basename = f"eval-{dataset_id_formatted}-{dataset_config_formatted}"
203
+ basename = basename[:60] if len(basename) > 60 else basename # Hub naming limitation
204
+ return f"{basename}-{str(uuid.uuid4())[:6]}"
205
 
206
 
207
  def get_config_metadata(config: str, metadata: List[Dict] = None) -> Union[Dict, None]: