File size: 3,045 Bytes
a52ecf0 7dd87f6 a52ecf0 7dd87f6 a52ecf0 7dd87f6 a52ecf0 7dd87f6 a52ecf0 7dd87f6 a52ecf0 7dd87f6 a52ecf0 7dd87f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
from collections import OrderedDict
from pathlib import Path
from typing import Any
import gradio as gr
class HFDatasetSaverFixed(gr.HuggingFaceDatasetSaver):
    """Dataset saver that overrides how flagged samples are turned into rows.

    NOTE(review): the "Fixed" suffix presumably refers to tolerating ``OSError``
    while probing whether a flagged value is a file on disk -- confirm against
    the upstream gradio implementation.
    """

    def _deserialize_components(
        self,
        data_dir: Path,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> tuple[dict[Any, Any], list[Any]]:
        """Deserialize components and return the corresponding row for the flagged sample.

        Images/audio are saved to disk as individual files.

        Args:
            data_dir: Directory under which one sub-directory per component
                label is created and passed to ``component.flag``.
            flag_data: Flagged values, paired positionally with
                ``self.components``.
            flag_option: Value written to the ``flag`` column.
            username: Value written to the ``username`` column.

        Returns:
            A ``(features, row)`` tuple: an ordered mapping from column name to
            its feature description, and the list of cell values.
        """
        # Components that can have a preview on dataset repos
        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

        # Generate the row corresponding to the flagged sample
        features = OrderedDict()
        row = []
        for component, sample in zip(self.components, flag_data):
            # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
            label = component.label or ""
            save_dir = data_dir / gr.flagging.client_utils.strip_invalid_filename_characters(label)
            save_dir.mkdir(exist_ok=True, parents=True)
            deserialized = component.flag(sample, save_dir)

            # Add deserialized object to row
            features[label] = {"dtype": "string", "_type": "Value"}

            # Decide whether the value is an existing file inside the dataset
            # directory. An explicit check is used instead of ``assert`` so the
            # behavior does not change when Python runs with ``-O``.
            relative_path = None
            try:
                candidate = Path(deserialized)
                if candidate.exists():
                    relative_path = candidate.relative_to(self.dataset_dir)
            except (TypeError, ValueError, OSError):
                # Not path-like, not under the dataset dir, or the filesystem
                # probe failed: fall through and store the value as plain text.
                relative_path = None

            if relative_path is not None:
                row.append(str(relative_path))
            else:
                deserialized = "" if deserialized is None else str(deserialized)
                row.append(deserialized)

            # If component is eligible for a preview, add the URL of the file
            # Be mindful that images and audio can be None
            preview_type = next(
                (
                    type_name
                    for component_cls, type_name in file_preview_types.items()
                    if isinstance(component, component_cls)
                ),
                None,
            )
            if preview_type is None:
                continue

            features[label + " file"] = {"_type": preview_type}
            preview_url = ""
            if deserialized:
                try:
                    # returned filepath is absolute, we want it relative to compute URL
                    path_in_repo = str(
                        Path(deserialized).relative_to(self.dataset_dir)
                    ).replace("\\", "/")
                except ValueError:
                    # The value does not live under the dataset directory, so no
                    # repo URL can be derived. Keep an empty cell so that ``row``
                    # stays aligned with ``features``.
                    path_in_repo = None
                if path_in_repo is not None:
                    preview_url = gr.flagging.huggingface_hub.hf_hub_url(
                        repo_id=self.dataset_id,
                        filename=path_in_repo,
                        repo_type="dataset",
                    )
            row.append(preview_url)

        features["flag"] = {"dtype": "string", "_type": "Value"}
        features["username"] = {"dtype": "string", "_type": "Value"}
        row.append(flag_option)
        row.append(username)
        return features, row
def get_dataset_saver(*args, **kwargs):
    """Build an ``HFDatasetSaverFixed``, forwarding all arguments unchanged."""
    saver = HFDatasetSaverFixed(*args, **kwargs)
    return saver
|