Spaces:
Runtime error
Runtime error
Handle multiple configs
Browse files
app.py
CHANGED
@@ -16,6 +16,7 @@ from utils import (
|
|
16 |
create_autotrain_project_name,
|
17 |
format_col_mapping,
|
18 |
get_compatible_models,
|
|
|
19 |
get_dataset_card_url,
|
20 |
get_key,
|
21 |
get_metadata,
|
@@ -123,16 +124,6 @@ SUPPORTED_METRICS = [
|
|
123 |
]
|
124 |
|
125 |
|
126 |
-
def get_config_metadata(config, metadata=None):
|
127 |
-
if metadata is None:
|
128 |
-
return None
|
129 |
-
config_metadata = [m for m in metadata if m["config"] == config]
|
130 |
-
if len(config_metadata) == 1:
|
131 |
-
return config_metadata[0]
|
132 |
-
else:
|
133 |
-
return None
|
134 |
-
|
135 |
-
|
136 |
#######
|
137 |
# APP #
|
138 |
#######
|
@@ -190,10 +181,6 @@ if metadata is None:
|
|
190 |
|
191 |
with st.expander("Advanced configuration"):
|
192 |
# Select task
|
193 |
-
# Hack to filter for unsupported tasks
|
194 |
-
# TODO(lewtun): remove this once we have SQuAD metrics support
|
195 |
-
if metadata is not None and metadata[0]["task_id"] in UNSUPPORTED_TASKS:
|
196 |
-
metadata = None
|
197 |
selected_task = st.selectbox(
|
198 |
"Select a task",
|
199 |
SUPPORTED_TASKS,
|
@@ -211,6 +198,9 @@ with st.expander("Advanced configuration"):
|
|
211 |
See the [docs](https://huggingface.co/docs/datasets/master/en/load_hub#configurations) for more details.
|
212 |
""",
|
213 |
)
|
|
|
|
|
|
|
214 |
|
215 |
# Select splits
|
216 |
splits_resp = http_get(
|
@@ -225,8 +215,8 @@ with st.expander("Advanced configuration"):
|
|
225 |
if split["config"] == selected_config:
|
226 |
split_names.append(split["split"])
|
227 |
|
228 |
-
if
|
229 |
-
eval_split =
|
230 |
else:
|
231 |
eval_split = None
|
232 |
selected_split = st.selectbox(
|
@@ -270,12 +260,16 @@ with st.expander("Advanced configuration"):
|
|
270 |
text_col = st.selectbox(
|
271 |
"This column should contain the text to be classified",
|
272 |
col_names,
|
273 |
-
index=col_names.index(get_key(
|
|
|
|
|
274 |
)
|
275 |
target_col = st.selectbox(
|
276 |
"This column should contain the labels associated with the text",
|
277 |
col_names,
|
278 |
-
index=col_names.index(get_key(
|
|
|
|
|
279 |
)
|
280 |
col_mapping[text_col] = "text"
|
281 |
col_mapping[target_col] = "target"
|
@@ -289,11 +283,13 @@ with st.expander("Advanced configuration"):
|
|
289 |
st.text("")
|
290 |
st.text("")
|
291 |
st.text("")
|
|
|
292 |
st.markdown("`text2` column")
|
293 |
st.text("")
|
294 |
st.text("")
|
295 |
st.text("")
|
296 |
st.text("")
|
|
|
297 |
st.markdown("`target` column")
|
298 |
with col2:
|
299 |
text1_col = st.selectbox(
|
@@ -333,12 +329,16 @@ with st.expander("Advanced configuration"):
|
|
333 |
tokens_col = st.selectbox(
|
334 |
"This column should contain the array of tokens to be classified",
|
335 |
col_names,
|
336 |
-
index=col_names.index(get_key(
|
|
|
|
|
337 |
)
|
338 |
tags_col = st.selectbox(
|
339 |
"This column should contain the labels associated with each part of the text",
|
340 |
col_names,
|
341 |
-
index=col_names.index(get_key(
|
|
|
|
|
342 |
)
|
343 |
col_mapping[tokens_col] = "tokens"
|
344 |
col_mapping[tags_col] = "tags"
|
@@ -355,12 +355,16 @@ with st.expander("Advanced configuration"):
|
|
355 |
text_col = st.selectbox(
|
356 |
"This column should contain the text to be translated",
|
357 |
col_names,
|
358 |
-
index=col_names.index(get_key(
|
|
|
|
|
359 |
)
|
360 |
target_col = st.selectbox(
|
361 |
"This column should contain the target translation",
|
362 |
col_names,
|
363 |
-
index=col_names.index(get_key(
|
|
|
|
|
364 |
)
|
365 |
col_mapping[text_col] = "source"
|
366 |
col_mapping[target_col] = "target"
|
@@ -377,19 +381,23 @@ with st.expander("Advanced configuration"):
|
|
377 |
text_col = st.selectbox(
|
378 |
"This column should contain the text to be summarized",
|
379 |
col_names,
|
380 |
-
index=col_names.index(get_key(
|
|
|
|
|
381 |
)
|
382 |
target_col = st.selectbox(
|
383 |
"This column should contain the target summary",
|
384 |
col_names,
|
385 |
-
index=col_names.index(get_key(
|
|
|
|
|
386 |
)
|
387 |
col_mapping[text_col] = "text"
|
388 |
col_mapping[target_col] = "target"
|
389 |
|
390 |
elif selected_task == "extractive_question_answering":
|
391 |
-
if
|
392 |
-
col_mapping =
|
393 |
# Hub YAML parser converts periods to hyphens, so we remap them here
|
394 |
col_mapping = format_col_mapping(col_mapping)
|
395 |
with col1:
|
@@ -413,22 +421,24 @@ with st.expander("Advanced configuration"):
|
|
413 |
context_col = st.selectbox(
|
414 |
"This column should contain the question's context",
|
415 |
col_names,
|
416 |
-
index=col_names.index(get_key(col_mapping, "context")) if
|
417 |
)
|
418 |
question_col = st.selectbox(
|
419 |
"This column should contain the question to be answered, given the context",
|
420 |
col_names,
|
421 |
-
index=col_names.index(get_key(col_mapping, "question")) if
|
422 |
)
|
423 |
answers_text_col = st.selectbox(
|
424 |
"This column should contain example answers to the question, extracted from the context",
|
425 |
col_names,
|
426 |
-
index=col_names.index(get_key(col_mapping, "answers.text")) if
|
427 |
)
|
428 |
answers_start_col = st.selectbox(
|
429 |
"This column should contain the indices in the context of the first character of each `answers.text`",
|
430 |
col_names,
|
431 |
-
index=col_names.index(get_key(col_mapping, "answers.answer_start"))
|
|
|
|
|
432 |
)
|
433 |
col_mapping[context_col] = "context"
|
434 |
col_mapping[question_col] = "question"
|
@@ -446,12 +456,16 @@ with st.expander("Advanced configuration"):
|
|
446 |
image_col = st.selectbox(
|
447 |
"This column should contain the images to be classified",
|
448 |
col_names,
|
449 |
-
index=col_names.index(get_key(
|
|
|
|
|
450 |
)
|
451 |
target_col = st.selectbox(
|
452 |
"This column should contain the labels associated with the images",
|
453 |
col_names,
|
454 |
-
index=col_names.index(get_key(
|
|
|
|
|
455 |
)
|
456 |
col_mapping[image_col] = "image"
|
457 |
col_mapping[target_col] = "target"
|
|
|
16 |
create_autotrain_project_name,
|
17 |
format_col_mapping,
|
18 |
get_compatible_models,
|
19 |
+
get_config_metadata,
|
20 |
get_dataset_card_url,
|
21 |
get_key,
|
22 |
get_metadata,
|
|
|
124 |
]
|
125 |
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
#######
|
128 |
# APP #
|
129 |
#######
|
|
|
181 |
|
182 |
with st.expander("Advanced configuration"):
|
183 |
# Select task
|
|
|
|
|
|
|
|
|
184 |
selected_task = st.selectbox(
|
185 |
"Select a task",
|
186 |
SUPPORTED_TASKS,
|
|
|
198 |
See the [docs](https://huggingface.co/docs/datasets/master/en/load_hub#configurations) for more details.
|
199 |
""",
|
200 |
)
|
201 |
+
# Get metadata for config
|
202 |
+
config_metadata = get_config_metadata(selected_config, metadata)
|
203 |
+
print(f"INFO -- Config metadata: {config_metadata}")
|
204 |
|
205 |
# Select splits
|
206 |
splits_resp = http_get(
|
|
|
215 |
if split["config"] == selected_config:
|
216 |
split_names.append(split["split"])
|
217 |
|
218 |
+
if config_metadata is not None:
|
219 |
+
eval_split = config_metadata["splits"].get("eval_split", None)
|
220 |
else:
|
221 |
eval_split = None
|
222 |
selected_split = st.selectbox(
|
|
|
260 |
text_col = st.selectbox(
|
261 |
"This column should contain the text to be classified",
|
262 |
col_names,
|
263 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
|
264 |
+
if config_metadata is not None
|
265 |
+
else 0,
|
266 |
)
|
267 |
target_col = st.selectbox(
|
268 |
"This column should contain the labels associated with the text",
|
269 |
col_names,
|
270 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
|
271 |
+
if config_metadata is not None
|
272 |
+
else 0,
|
273 |
)
|
274 |
col_mapping[text_col] = "text"
|
275 |
col_mapping[target_col] = "target"
|
|
|
283 |
st.text("")
|
284 |
st.text("")
|
285 |
st.text("")
|
286 |
+
st.text("")
|
287 |
st.markdown("`text2` column")
|
288 |
st.text("")
|
289 |
st.text("")
|
290 |
st.text("")
|
291 |
st.text("")
|
292 |
+
st.text("")
|
293 |
st.markdown("`target` column")
|
294 |
with col2:
|
295 |
text1_col = st.selectbox(
|
|
|
329 |
tokens_col = st.selectbox(
|
330 |
"This column should contain the array of tokens to be classified",
|
331 |
col_names,
|
332 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "tokens"))
|
333 |
+
if config_metadata is not None
|
334 |
+
else 0,
|
335 |
)
|
336 |
tags_col = st.selectbox(
|
337 |
"This column should contain the labels associated with each part of the text",
|
338 |
col_names,
|
339 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "tags"))
|
340 |
+
if config_metadata is not None
|
341 |
+
else 0,
|
342 |
)
|
343 |
col_mapping[tokens_col] = "tokens"
|
344 |
col_mapping[tags_col] = "tags"
|
|
|
355 |
text_col = st.selectbox(
|
356 |
"This column should contain the text to be translated",
|
357 |
col_names,
|
358 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "source"))
|
359 |
+
if config_metadata is not None
|
360 |
+
else 0,
|
361 |
)
|
362 |
target_col = st.selectbox(
|
363 |
"This column should contain the target translation",
|
364 |
col_names,
|
365 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
|
366 |
+
if config_metadata is not None
|
367 |
+
else 0,
|
368 |
)
|
369 |
col_mapping[text_col] = "source"
|
370 |
col_mapping[target_col] = "target"
|
|
|
381 |
text_col = st.selectbox(
|
382 |
"This column should contain the text to be summarized",
|
383 |
col_names,
|
384 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
|
385 |
+
if config_metadata is not None
|
386 |
+
else 0,
|
387 |
)
|
388 |
target_col = st.selectbox(
|
389 |
"This column should contain the target summary",
|
390 |
col_names,
|
391 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
|
392 |
+
if config_metadata is not None
|
393 |
+
else 0,
|
394 |
)
|
395 |
col_mapping[text_col] = "text"
|
396 |
col_mapping[target_col] = "target"
|
397 |
|
398 |
elif selected_task == "extractive_question_answering":
|
399 |
+
if config_metadata is not None:
|
400 |
+
col_mapping = config_metadata["col_mapping"]
|
401 |
# Hub YAML parser converts periods to hyphens, so we remap them here
|
402 |
col_mapping = format_col_mapping(col_mapping)
|
403 |
with col1:
|
|
|
421 |
context_col = st.selectbox(
|
422 |
"This column should contain the question's context",
|
423 |
col_names,
|
424 |
+
index=col_names.index(get_key(col_mapping, "context")) if config_metadata is not None else 0,
|
425 |
)
|
426 |
question_col = st.selectbox(
|
427 |
"This column should contain the question to be answered, given the context",
|
428 |
col_names,
|
429 |
+
index=col_names.index(get_key(col_mapping, "question")) if config_metadata is not None else 0,
|
430 |
)
|
431 |
answers_text_col = st.selectbox(
|
432 |
"This column should contain example answers to the question, extracted from the context",
|
433 |
col_names,
|
434 |
+
index=col_names.index(get_key(col_mapping, "answers.text")) if config_metadata is not None else 0,
|
435 |
)
|
436 |
answers_start_col = st.selectbox(
|
437 |
"This column should contain the indices in the context of the first character of each `answers.text`",
|
438 |
col_names,
|
439 |
+
index=col_names.index(get_key(col_mapping, "answers.answer_start"))
|
440 |
+
if config_metadata is not None
|
441 |
+
else 0,
|
442 |
)
|
443 |
col_mapping[context_col] = "context"
|
444 |
col_mapping[question_col] = "question"
|
|
|
456 |
image_col = st.selectbox(
|
457 |
"This column should contain the images to be classified",
|
458 |
col_names,
|
459 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "image"))
|
460 |
+
if config_metadata is not None
|
461 |
+
else 0,
|
462 |
)
|
463 |
target_col = st.selectbox(
|
464 |
"This column should contain the labels associated with the images",
|
465 |
col_names,
|
466 |
+
index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
|
467 |
+
if config_metadata is not None
|
468 |
+
else 0,
|
469 |
)
|
470 |
col_mapping[image_col] = "image"
|
471 |
col_mapping[target_col] = "target"
|
utils.py
CHANGED
@@ -198,3 +198,14 @@ def create_autotrain_project_name(dataset_id: str) -> str:
|
|
198 |
# Project names need to be unique, so we append a random string to guarantee this
|
199 |
project_id = str(uuid.uuid4())[:8]
|
200 |
return f"eval-project-{dataset_id_formatted}-{project_id}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
# Project names need to be unique, so we append a random string to guarantee this
|
199 |
project_id = str(uuid.uuid4())[:8]
|
200 |
return f"eval-project-{dataset_id_formatted}-{project_id}"
|
201 |
+
|
202 |
+
|
203 |
+
def get_config_metadata(config: str, metadata: Union[List[Dict], None] = None) -> Union[Dict, None]:
    """Gets the dataset card metadata for the given config.

    Args:
        config: Name of the dataset configuration to look up.
        metadata: Metadata entries to search. Each entry is matched on the
            value stored under its ``"config"`` key. May be ``None`` or empty.

    Returns:
        The single entry whose ``"config"`` value equals ``config``, or ``None``
        when ``metadata`` is ``None``/empty, when nothing matches, or when more
        than one entry matches (an ambiguous result is treated as no result).
    """
    if not metadata:
        return None
    # Entries that are not dicts or that lack a "config" key are skipped instead
    # of raising, so one malformed entry cannot break the lookup for the others.
    matches = [
        entry
        for entry in metadata
        if isinstance(entry, dict) and "config" in entry and entry["config"] == config
    ]
    # Only an unambiguous match is usable by callers.
    return matches[0] if len(matches) == 1 else None
|