|
import types
from collections import OrderedDict
from pathlib import Path
from typing import Any

import gradio as gr
|
|
|
|
|
def _deserialize_components_fix(
    self,
    data_dir: Path,
    flag_data: list[Any],
    flag_option: str = "",
    username: str = "",
) -> tuple[dict[Any, Any], list[Any]]:
    """Deserialize components and return the corresponding row for the flagged sample.

    Replacement for ``HuggingFaceDatasetSaver._deserialize_components`` that is
    meant to be bound to a saver instance. Each component's sample is passed to
    ``component.flag(sample, save_dir)``. ``save_dir`` is a per-label
    sub-directory of ``data_dir`` and is created before the call.
    Images/audio are saved to disk as individual files.

    Args:
        self: The saver instance; ``self.components``, ``self.dataset_dir`` and
            ``self.dataset_id`` are read.
        data_dir: Directory under which per-component sub-directories are created.
        flag_data: One sample per component, aligned with ``self.components``
            (surplus items on either side are dropped by ``zip``).
        flag_option: Value stored in the ``"flag"`` column.
        username: Value stored in the ``"username"`` column.

    Returns:
        ``(features, row)``. ``features`` maps column names to dataset feature
        specs, and ``row`` holds the matching values in the same order.
        Audio/Image components add an extra ``"<label> file"`` column. It holds
        the Hub URL of the saved file, or ``""`` when the component did not
        produce an existing file inside ``self.dataset_dir``.
    """
    # Components that can have a file preview on dataset repos.
    file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

    features = OrderedDict()
    row = []
    for component, sample in zip(self.components, flag_data):
        label = component.label or ""
        save_dir = data_dir / gr.flagging.client_utils.strip_invalid_filename_characters(
            label
        )
        save_dir.mkdir(exist_ok=True, parents=True)
        deserialized = component.flag(sample, save_dir)

        features[label] = {"dtype": "string", "_type": "Value"}
        # Location relative to the dataset root when the component produced an
        # existing file there; None for any other kind of output.
        path_in_repo = _relative_path_in_dataset(deserialized, self.dataset_dir)
        if path_in_repo is not None:
            row.append(str(path_in_repo))
        else:
            row.append("" if deserialized is None else str(deserialized))

        # First matching preview type, in the dict's insertion order.
        preview_type = next(
            (
                type_name
                for component_cls, type_name in file_preview_types.items()
                if isinstance(component, component_cls)
            ),
            None,
        )
        if preview_type is None:
            continue
        features[label + " file"] = {"_type": preview_type}
        if path_in_repo is None:
            # Nothing was saved inside the dataset, so there is no file to link.
            # (Previously this could raise an uncaught ValueError.)
            row.append("")
        else:
            row.append(
                gr.flagging.huggingface_hub.hf_hub_url(
                    repo_id=self.dataset_id,
                    # URL paths need forward slashes, even on Windows.
                    filename=path_in_repo.as_posix(),
                    repo_type="dataset",
                )
            )

    features["flag"] = {"dtype": "string", "_type": "Value"}
    features["username"] = {"dtype": "string", "_type": "Value"}
    row.append(flag_option)
    row.append(username)
    return features, row


def _relative_path_in_dataset(value: Any, dataset_dir: Path) -> "Path | None":
    """Return ``value`` relative to ``dataset_dir`` if it names an existing path there.

    Returns None when ``value`` is empty or not path-like, or when it does not
    exist on disk. It also returns None when the path lies outside
    ``dataset_dir`` or cannot be stat'ed (e.g. a very long text output is
    "too long" as a file name).
    """
    if not value:
        return None
    try:
        path = Path(value)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not path.exists():
            return None
        return path.relative_to(dataset_dir)
    except (TypeError, ValueError, OSError):
        return None
|
|
|
|
|
def get_dataset_saver(*args, **kwargs):
    """Create a ``gr.HuggingFaceDatasetSaver`` patched with ``_deserialize_components_fix``.

    All arguments are forwarded unchanged to ``gr.HuggingFaceDatasetSaver``.

    The replacement is bound to this instance with ``types.MethodType``. A plain
    function stored on an instance attribute is not turned into a bound method.
    So a call such as ``saver._deserialize_components(data_dir, ...)``
    (presumably how gradio invokes it; confirm against the installed version)
    would otherwise receive ``data_dir`` as ``self``. Every other argument would
    then shift by one position.
    """
    saver = gr.HuggingFaceDatasetSaver(*args, **kwargs)
    saver._deserialize_components = types.MethodType(_deserialize_components_fix, saver)
    return saver
|
|