Spaces:
Runtime error
Add support for image classification
Browse files
app.py
CHANGED
@@ -31,11 +31,12 @@ DATASETS_PREVIEW_API = os.getenv("DATASETS_PREVIEW_API")
|
|
31 |
TASK_TO_ID = {
|
32 |
"binary_classification": 1,
|
33 |
"multi_class_classification": 2,
|
34 |
-
# "multi_label_classification": 3, # Not fully supported in AutoTrain
|
35 |
"entity_extraction": 4,
|
36 |
"extractive_question_answering": 5,
|
37 |
"translation": 6,
|
38 |
"summarization": 8,
|
|
|
|
|
39 |
}
|
40 |
|
41 |
TASK_TO_DEFAULT_METRICS = {
|
@@ -50,8 +51,22 @@ TASK_TO_DEFAULT_METRICS = {
|
|
50 |
"extractive_question_answering": [],
|
51 |
"translation": ["sacrebleu"],
|
52 |
"summarization": ["rouge1", "rouge2", "rougeL", "rougeLsum"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
}
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
SUPPORTED_TASKS = list(TASK_TO_ID.keys())
|
56 |
|
57 |
# Extracted from utils.get_supported_metrics
|
@@ -355,6 +370,27 @@ with st.expander("Advanced configuration"):
|
|
355 |
col_mapping[question_col] = "question"
|
356 |
col_mapping[answers_text_col] = "answers.text"
|
357 |
col_mapping[answers_start_col] = "answers.answer_start"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
358 |
|
359 |
# Select metrics
|
360 |
st.markdown("**Select metrics**")
|
@@ -408,9 +444,9 @@ with st.form(key="form"):
|
|
408 |
"proj_name": f"eval-project-{project_id}",
|
409 |
"task": TASK_TO_ID[selected_task],
|
410 |
"config": {
|
411 |
-
"language":
|
412 |
-
if selected_task
|
413 |
-
else "
|
414 |
"max_models": 5,
|
415 |
"instance": {
|
416 |
"provider": "aws",
|
|
|
31 |
TASK_TO_ID = {
|
32 |
"binary_classification": 1,
|
33 |
"multi_class_classification": 2,
|
|
|
34 |
"entity_extraction": 4,
|
35 |
"extractive_question_answering": 5,
|
36 |
"translation": 6,
|
37 |
"summarization": 8,
|
38 |
+
"image_binary_classification": 17,
|
39 |
+
"image_multi_class_classification": 18,
|
40 |
}
|
41 |
|
42 |
TASK_TO_DEFAULT_METRICS = {
|
|
|
51 |
"extractive_question_answering": [],
|
52 |
"translation": ["sacrebleu"],
|
53 |
"summarization": ["rouge1", "rouge2", "rougeL", "rougeLsum"],
|
54 |
+
"image_binary_classification": ["f1", "precision", "recall", "auc", "accuracy"],
|
55 |
+
"image_multi_class_classification": [
|
56 |
+
"f1",
|
57 |
+
"precision",
|
58 |
+
"recall",
|
59 |
+
"accuracy",
|
60 |
+
],
|
61 |
}
|
62 |
|
63 |
+
AUTOTRAIN_TASK_TO_LANG = {
|
64 |
+
"translation": "en2de",
|
65 |
+
"image_binary_classification": "unk",
|
66 |
+
"image_multi_class_classification": "unk",
|
67 |
+
}
|
68 |
+
|
69 |
+
|
70 |
SUPPORTED_TASKS = list(TASK_TO_ID.keys())
|
71 |
|
72 |
# Extracted from utils.get_supported_metrics
|
|
|
370 |
col_mapping[question_col] = "question"
|
371 |
col_mapping[answers_text_col] = "answers.text"
|
372 |
col_mapping[answers_start_col] = "answers.answer_start"
|
373 |
+
elif selected_task in ["image_binary_classification", "image_multi_class_classification"]:
|
374 |
+
with col1:
|
375 |
+
st.markdown("`image` column")
|
376 |
+
st.text("")
|
377 |
+
st.text("")
|
378 |
+
st.text("")
|
379 |
+
st.text("")
|
380 |
+
st.markdown("`target` column")
|
381 |
+
with col2:
|
382 |
+
image_col = st.selectbox(
|
383 |
+
"This column should contain the images to be classified",
|
384 |
+
col_names,
|
385 |
+
index=col_names.index(get_key(metadata[0]["col_mapping"], "image")) if metadata is not None else 0,
|
386 |
+
)
|
387 |
+
target_col = st.selectbox(
|
388 |
+
"This column should contain the labels associated with the images",
|
389 |
+
col_names,
|
390 |
+
index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
|
391 |
+
)
|
392 |
+
col_mapping[image_col] = "image"
|
393 |
+
col_mapping[target_col] = "target"
|
394 |
|
395 |
# Select metrics
|
396 |
st.markdown("**Select metrics**")
|
|
|
444 |
"proj_name": f"eval-project-{project_id}",
|
445 |
"task": TASK_TO_ID[selected_task],
|
446 |
"config": {
|
447 |
+
"language": AUTOTRAIN_TASK_TO_LANG[selected_task]
|
448 |
+
if selected_task in AUTOTRAIN_TASK_TO_LANG
|
449 |
+
else "en",
|
450 |
"max_models": 5,
|
451 |
"instance": {
|
452 |
"provider": "aws",
|
utils.py
CHANGED
@@ -11,14 +11,15 @@ from tqdm import tqdm
|
|
11 |
AUTOTRAIN_TASK_TO_HUB_TASK = {
|
12 |
"binary_classification": "text-classification",
|
13 |
"multi_class_classification": "text-classification",
|
14 |
-
# "multi_label_classification": "text-classification", # Not fully supported in AutoTrain
|
15 |
"entity_extraction": "token-classification",
|
16 |
"extractive_question_answering": "question-answering",
|
17 |
"translation": "translation",
|
18 |
"summarization": "summarization",
|
19 |
-
|
|
|
20 |
}
|
21 |
|
|
|
22 |
HUB_TASK_TO_AUTOTRAIN_TASK = {v: k for k, v in AUTOTRAIN_TASK_TO_HUB_TASK.items()}
|
23 |
LOGS_REPO = "evaluation-job-logs"
|
24 |
|
|
|
11 |
AUTOTRAIN_TASK_TO_HUB_TASK = {
|
12 |
"binary_classification": "text-classification",
|
13 |
"multi_class_classification": "text-classification",
|
|
|
14 |
"entity_extraction": "token-classification",
|
15 |
"extractive_question_answering": "question-answering",
|
16 |
"translation": "translation",
|
17 |
"summarization": "summarization",
|
18 |
+
"image_binary_classification": "image-classification",
|
19 |
+
"image_multi_class_classification": "image-classification",
|
20 |
}
|
21 |
|
22 |
+
|
23 |
HUB_TASK_TO_AUTOTRAIN_TASK = {v: k for k, v in AUTOTRAIN_TASK_TO_HUB_TASK.items()}
|
24 |
LOGS_REPO = "evaluation-job-logs"
|
25 |
|