commit-message-editing / hf_dataset_saver_builder.py
Petr Tsvetkov
Apply fix to the HF dataset saver
a52ecf0
raw
history blame
2.87 kB
from collections import OrderedDict
from pathlib import Path
from typing import Any
import gradio as gr
def _deserialize_components_fix(
    self,
    data_dir: Path,
    flag_data: list[Any],
    flag_option: str = "",
    username: str = "",
) -> tuple[dict[Any, Any], list[Any]]:
    """Deserialize components and return the corresponding row for the flagged sample.

    Images/audio are saved to disk as individual files.

    Args:
        self: Saver-like object; this function reads ``self.components``,
            ``self.dataset_dir`` and ``self.dataset_id``.
        data_dir: Directory under which one sub-directory per component
            label is created and handed to ``component.flag``.
        flag_data: One sample per component, paired with ``self.components``
            by position (``zip`` stops at the shorter of the two).
        flag_option: Value stored in the trailing ``"flag"`` column.
        username: Value stored in the trailing ``"username"`` column.

    Returns:
        ``(features, row)`` where ``features`` is an ordered mapping of column
        name to feature description and ``row`` is the list of cell values in
        matching order.
    """
    # Components that can have a preview on dataset repos
    file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

    # Generate the row corresponding to the flagged sample
    features = OrderedDict()
    row = []

    for component, sample in zip(self.components, flag_data):
        # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
        label = component.label or ""
        save_dir = data_dir / gr.flagging.client_utils.strip_invalid_filename_characters(label)
        save_dir.mkdir(exist_ok=True, parents=True)
        deserialized = component.flag(sample, save_dir)

        # Add deserialized object to row
        features[label] = {"dtype": "string", "_type": "Value"}
        try:
            saved_path = Path(deserialized)
            # Explicit check instead of `assert`: assertions are removed when
            # Python runs with -O, which would silently skip this test.
            if not saved_path.exists():
                raise FileNotFoundError(f"File {deserialized} not found")
            row.append(str(saved_path.relative_to(self.dataset_dir)))
        except (TypeError, ValueError, OSError):
            # Not a usable path inside the dataset directory (plain value,
            # None, missing file, or a string the OS rejects as a path):
            # store its string form instead.
            deserialized = "" if deserialized is None else str(deserialized)
            row.append(deserialized)

        # If component is eligible for a preview, add the URL of the file
        # Be mindful that images and audio can be None
        preview_type = next(
            (
                type_name
                for component_cls, type_name in file_preview_types.items()
                if isinstance(component, component_cls)
            ),
            None,
        )
        if preview_type is None:
            continue

        features[label + " file"] = {"_type": preview_type}
        if deserialized:
            # returned filepath is absolute, we want it relative to compute URL
            path_in_repo = str(
                Path(deserialized).relative_to(self.dataset_dir)
            ).replace("\\", "/")
            row.append(
                gr.flagging.huggingface_hub.hf_hub_url(
                    repo_id=self.dataset_id,
                    filename=path_in_repo,
                    repo_type="dataset",
                )
            )
        else:
            row.append("")

    features["flag"] = {"dtype": "string", "_type": "Value"}
    features["username"] = {"dtype": "string", "_type": "Value"}
    row.append(flag_option)
    row.append(username)
    return features, row
def get_dataset_saver(*args, **kwargs):
    """Create a ``gr.HuggingFaceDatasetSaver`` that uses the patched deserializer.

    All positional and keyword arguments are forwarded unchanged to
    ``gr.HuggingFaceDatasetSaver``.

    Returns:
        The saver instance, with its ``_deserialize_components`` attribute
        replaced by ``_deserialize_components_fix`` bound to that instance.
    """
    from types import MethodType

    saver = gr.HuggingFaceDatasetSaver(*args, **kwargs)
    # A plain function stored on an instance is NOT bound automatically, so
    # `saver._deserialize_components(data_dir, ...)` would receive `data_dir`
    # as `self`. Bind it explicitly so it behaves like a regular method.
    saver._deserialize_components = MethodType(_deserialize_components_fix, saver)
    return saver