# NOTE: removed stray "Spaces:" / "Runtime error" lines — artifacts of a
# web-page capture, not part of the module source.
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json | |
import os | |
import pandas as pd | |
import plotly | |
import pyarrow.feather as feather | |
import utils | |
from dataclasses import asdict | |
from datasets import Dataset, get_dataset_infos, load_dataset, load_from_disk, \ | |
NamedSplit | |
from dotenv import load_dotenv | |
from huggingface_hub import Repository, list_datasets | |
from json2html import * | |
from os import getenv | |
from os.path import exists, isdir, join as pjoin | |
from pathlib import Path | |
# treating inf values as NaN as well, so pandas NaN handling
# (isna/dropna) also covers +/-inf
pd.set_option("use_inf_as_na", True)

## String names used in Hugging Face dataset configs.
HF_FEATURE_FIELD = "features"
HF_LABEL_FIELD = "label"
HF_DESC_FIELD = "description"
# Default name for the local cache directory (used as the `cache_dir`
# default in load_truncated_dataset).
CACHE_DIR = "cache_dir"

## String names we are using within this code.
# These are not coming from the stored dataset nor HF config,
# but rather used as identifiers in our dicts and dataframes.
TEXT_FIELD = "text"
PERPLEXITY_FIELD = "perplexity"
TOKENIZED_FIELD = "tokenized_text"
EMBEDDING_FIELD = "embedding"
LENGTH_FIELD = "length"
VOCAB = "vocab"
WORD = "word"
CNT = "count"
PROP = "proportion"
TEXT_NAN_CNT = "text_nan_count"
TXT_LEN = "text lengths"
TOT_WORDS = "total words"
TOT_OPEN_WORDS = "total open words"

# Dataset ids this tool knows how to measure.
_DATASET_LIST = [
    "c4",
    "squad",
    "squad_v2",
    "hate_speech18",
    "hate_speech_offensive",
    "glue",
    "super_glue",
    "wikitext",
    "imdb",
    "HuggingFaceM4/OBELICS",
]

# Subset of _DATASET_LIST that can be loaded with streaming=True
# (checked in load_truncated_dataset).
_STREAMABLE_DATASET_LIST = [
    "c4",
    "wikitext",
    "HuggingFaceM4/OBELICS",
]

# Default number of rows a dataset is truncated to.
_MAX_ROWS = 2000

logs = utils.prepare_logging(__file__)
def _load_dotenv_for_cache_on_hub(): | |
""" | |
This function loads and returns the organization name that you've set up on the | |
hub for storing your data measurements cache on the hub. It also loads the associated | |
access token. It expects you to have HUB_CACHE_ORGANIZATION=<the organization you've set up on the hub to store your cache> | |
and HF_TOKEN=<your hf token> on separate lines in a file named .env at the root of this repo. | |
Returns: | |
tuple of strings: hub_cache_organization, hf_token | |
""" | |
if Path(".env").is_file(): | |
load_dotenv(".env") | |
hf_token = getenv("HF_TOKEN") | |
hub_cache_organization = getenv("HUB_CACHE_ORGANIZATION") | |
return hub_cache_organization, hf_token | |
def get_cache_dir_naming(out_dir, dataset, config, split, feature):
    """Build the cache repo name and local cache path for one
    dataset/config/split/feature combination."""
    cache_name = f"{dataset}_{config}_{split}_{hyphenated(feature)}"
    local_cache_path = out_dir + "/" + cache_name
    return cache_name, local_cache_path
def initialize_cache_hub_repo(local_cache_dir, dataset_cache_name):
    """
    This function tries to initialize a dataset cache on the huggingface hub. The
    function expects you to have HUB_CACHE_ORGANIZATION=<the organization you've set up on the hub to store your cache>
    and HF_TOKEN=<your hf token> on separate lines in a file named .env at the root of this repo.

    Args:
        local_cache_dir (string):
            The path to the local dataset cache.
        dataset_cache_name (string):
            The name of the dataset repo on the huggingface hub that you want.
    """
    hub_cache_organization, hf_token = _load_dotenv_for_cache_on_hub()
    repo = Repository(
        local_dir=local_cache_dir,
        clone_from=pjoin(hub_cache_organization, dataset_cache_name),
        repo_type="dataset",
        use_auth_token=hf_token,
    )
    # Cache files are stored as feather; make sure git-lfs tracks them.
    repo.lfs_track(["*.feather"])
    return repo
def pull_cache_from_hub(cache_path, dataset_cache_dir):
    """
    This function tries to pull a datasets cache from the huggingface hub if a
    cache for the dataset does not already exist locally. The function expects you
    to have you HUB_CACHE_ORGANIZATION=<the organization you've set up on the hub to store your cache>
    and HF_TOKEN=<your hf token> on separate lines in a file named .env at the root of this repo.

    Args:
        cache_path (string):
            The path to the local dataset cache that you want.
        dataset_cache_dir (string):
            The name of the dataset repo on the huggingface hub.
    """
    hub_cache_organization, hf_token = _load_dotenv_for_cache_on_hub()
    clone_source = pjoin(hub_cache_organization, dataset_cache_dir)
    if isdir(cache_path):
        logs.warning("Already a local cache for the dataset, so not pulling from the hub.")
        return
    # Here, dataset_info.id is of the form: <hub cache organization>/<dataset cache dir>
    remote_cache_names = [
        dataset_info.id.split("/")[-1]
        for dataset_info in list_datasets(author=hub_cache_organization,
                                          use_auth_token=hf_token)
    ]
    if dataset_cache_dir in remote_cache_names:
        Repository(local_dir=cache_path,
                   clone_from=clone_source,
                   repo_type="dataset", use_auth_token=hf_token)
        logs.info("Pulled cache from hub!")
    else:
        logs.warning("Asking to pull cache from hub but cannot find cached repo on the hub.")
def load_truncated_dataset(
    dataset_name,
    config_name,
    split_name,
    num_rows=_MAX_ROWS,
    use_cache=True,
    cache_dir=CACHE_DIR,
    use_streaming=True,
    save=True,
):
    """
    This function loads the first `num_rows` items of a dataset for a
    given `config_name` and `split_name`.

    If `use_cache` and `cache_dir` exists, the truncated dataset is loaded
    from `cache_dir`. Otherwise, a new truncated dataset is created and,
    when `save` is set, immediately saved to disk.

    When the dataset is streamable, we iterate through the first
    `num_rows` examples in streaming mode and rebuild a regular Dataset
    from them. This is the most direct way to make a Dataset from an
    IterableDataset as of datasets version 1.6.1.
    Otherwise, we download the full dataset and select the first
    `num_rows` items.

    Args:
        dataset_name (string):
            dataset id in the dataset library
        config_name (string):
            dataset configuration
        split_name (string):
            split name
        num_rows (int) [optional]:
            number of rows to truncate the dataset to
        cache_dir (string):
            name of the cache directory
        use_cache (bool):
            whether to load from the cache if it exists
        use_streaming (bool):
            whether to use streaming when the dataset supports it
        save (bool):
            whether to save the dataset locally
    Returns:
        Dataset: the (truncated if specified) dataset as a Dataset object
    """
    logs.info("Loading or preparing dataset saved in %s " % cache_dir)
    if use_cache and exists(cache_dir):
        dataset = load_from_disk(cache_dir)
    else:
        if use_streaming and dataset_name in _STREAMABLE_DATASET_LIST:
            iterable_dataset = load_dataset(
                dataset_name,
                name=config_name,
                split=split_name,
                streaming=True,
            ).take(num_rows)
            rows = list(iterable_dataset)

            def gen():
                yield from rows

            # Rebuild a regular (non-iterable) Dataset from the streamed rows.
            dataset = Dataset.from_generator(gen, features=iterable_dataset.features)
            dataset._split = NamedSplit(split_name)
        else:
            full_dataset = load_dataset(
                dataset_name,
                name=config_name,
                split=split_name,
            )
            if len(full_dataset) >= num_rows:
                dataset = full_dataset.select(range(num_rows))
                # Make the directory name clear that it's not the full dataset.
                # NOTE(review): the save path gains this "_<num_rows>" suffix,
                # but the cache check above uses the unsuffixed `cache_dir`;
                # callers must pass the suffixed path to get a cache hit here.
                cache_dir = pjoin(cache_dir, ("_%s" % num_rows))
            else:
                dataset = full_dataset
        if save:
            dataset.save_to_disk(cache_dir)
    return dataset
def hyphenated(features):
    """Join the given feature names with hyphens, for use in
    filenames and titles when multiple features are asked for."""
    return "-".join(features)
def get_typed_features(features, ftype="string", parents=None):
    """
    Recursively get a list of all features of a certain dtype.

    :param features: mapping of feature name -> feature definition (nested dicts)
    :param ftype: dtype string to look for, e.g. "string", "int32"
    :param parents: key path accumulated so far (used by the recursion)
    :return: a list of tuples > e.g. ('A', 'B', 'C') for feature example['A']['B']['C']
    """
    if parents is None:
        parents = []
    typed_features = []
    for name, feat in features.items():
        if isinstance(feat, dict):
            # BUG FIX: this previously read `.get(("dtype", None) == ftype)`,
            # i.e. `.get(False)`, due to a misplaced parenthesis, so the
            # nested feature dtype was never actually looked up here.
            if (feat.get("dtype", None) == ftype
                    or feat.get("feature", {}).get("dtype", None) == ftype):
                typed_features += [tuple(parents + [name])]
            elif "feature" in feat:
                if feat["feature"].get("dtype", None) == ftype:
                    typed_features += [tuple(parents + [name])]
                elif isinstance(feat["feature"], dict):
                    typed_features += get_typed_features(
                        feat["feature"], ftype, parents + [name]
                    )
            else:
                # No dtype / nested feature here: descend into sub-dicts.
                for k, v in feat.items():
                    if isinstance(v, dict):
                        typed_features += get_typed_features(
                            v, ftype, parents + [name, k]
                        )
        elif name == "dtype" and feat == ftype:
            typed_features += [tuple(parents)]
    return typed_features
def get_label_features(features, parents=None):
    """
    Recursively get a list of all features that are ClassLabels.

    :param features: mapping of feature name -> feature definition (nested dicts)
    :param parents: key path accumulated so far (used by the recursion)
    :return: pairs of tuples as above and the list of class names
    """
    if parents is None:
        parents = []
    label_features = []
    for name, feat in features.items():
        if isinstance(feat, dict):
            if "names" in feat:
                label_features += [(tuple(parents + [name]), feat["names"])]
            elif "feature" in feat:
                nested = feat["feature"]
                # BUG FIX: this previously re-tested `"names" in feat`, which
                # is always False on this branch (the outer `if` already
                # handled it); the nested feature's "names" is what matters.
                if isinstance(nested, dict) and "names" in nested:
                    label_features += [
                        (tuple(parents + [name]), nested["names"])
                    ]
                elif isinstance(nested, dict):
                    label_features += get_label_features(
                        nested, parents + [name]
                    )
            else:
                # No names / nested feature here: descend into sub-dicts.
                for k, v in feat.items():
                    if isinstance(v, dict):
                        label_features += get_label_features(v, parents + [name, k])
        elif name == "names":
            label_features += [(tuple(parents), feat)]
    return label_features
# get the info we need for the app sidebar in dict format
def dictionarize_info(dset_info):
    """Condense a DatasetInfo object into the dict used by the app sidebar."""
    info_dict = asdict(dset_info)
    feature_defs = info_dict["features"]
    # NOTE(review): split sizes are hard-coded to 100 rather than read from
    # the split info — presumably a deliberate placeholder; confirm before
    # relying on these numbers.
    split_sizes = {spl: 100 for spl in info_dict["splits"]}
    feature_summary = {
        "string": get_typed_features(feature_defs, "string"),
        "int32": get_typed_features(feature_defs, "int32"),
        "float32": get_typed_features(feature_defs, "float32"),
        "label": get_label_features(feature_defs),
    }
    return {
        "config_name": info_dict["config_name"],
        "splits": split_sizes,
        "features": feature_summary,
        "description": dset_info.description,
    }
def get_dataset_info_dicts(dataset_id=None):
    """
    Creates a dict from dataset configs.
    Uses the datasets lib's get_dataset_infos.

    :param dataset_id: a single dataset id to summarize; when None, every
        dataset in _DATASET_LIST is summarized.
    :return: Dictionary mapping dataset names to their configurations
    """
    dataset_ids = _DATASET_LIST if dataset_id is None else [dataset_id]
    return {
        ds_id: {
            config_name: dictionarize_info(config_info)
            for config_name, config_info in get_dataset_infos(ds_id).items()
        }
        for ds_id in dataset_ids
    }
# get all instances of a specific field in a dataset
def extract_field(examples, field_path, new_field_name=None):
    """
    Collect every value reachable at `field_path` in `examples`.

    Args:
        examples: mapping of column name -> list of items (a dataset batch).
        field_path: a single field name, or a sequence of names to follow
            through nested list/dict structures.
        new_field_name: key for the returned dict; defaults to the path
            components joined with underscores.

    Returns:
        dict mapping `new_field_name` to the flat list of extracted values.
    """
    # BUG FIX: normalize a plain-string path BEFORE deriving the default
    # name; previously `"_".join("text")` produced names like "t_e_x_t"
    # because a str is itself iterable. (Without the check, a str path
    # also broke the CLI.)
    if isinstance(field_path, str):
        field_path = [field_path]
    if new_field_name is None:
        new_field_name = "_".join(field_path)
    item_list = examples[field_path[0]]
    # Walk down the remaining path components, flattening any lists
    # encountered along the way.
    for field_name in field_path[1:]:
        item_list = [
            next_item
            for item in item_list
            for next_item in (
                item[field_name]
                if isinstance(item[field_name], list)
                else [item[field_name]]
            )
        ]
    field_list = [
        field
        for item in item_list
        for field in (item if isinstance(item, list) else [item])
    ]
    return {new_field_name: field_list}
def make_path(path):
    """Create directory `path` (with any missing parents); no error if it
    already exists."""
    Path(path).mkdir(parents=True, exist_ok=True)
def counter_dict_to_df(dict_input, key_as_column=False):
    """
    Turn an {instance: count} dict into a DataFrame sorted by count, descending.

    Args:
        dict_input: mapping of instance -> count.
        key_as_column: when True, instances become an "instance" column;
            otherwise they remain in the index.

    Returns:
        pd.DataFrame with a "count" column, largest counts first.
    """
    counts_df = pd.DataFrame(dict_input, index=[0]).T
    if not key_as_column:
        counts_df.columns = ["count"]
    else:
        counts_df.reset_index(inplace=True)
        counts_df.columns = ["instance", "count"]
    return counts_df.sort_values(by="count", ascending=False)
def write_plotly(fig, fid):
    """Serialize a plotly figure to JSON and write it to `fid`."""
    fig_json = plotly.io.to_json(fig)
    write_json(fig_json, fid)
def read_plotly(fid):
    """Read back a plotly figure from a JSON file written by `write_plotly`.

    Returns:
        The deserialized plotly figure.
    """
    # Use a context manager so the file handle is closed; the previous
    # version leaked the handle returned by the bare open() call.
    with open(fid, encoding="utf-8") as f:
        fig = plotly.io.from_json(json.load(f))
    return fig
def write_json_as_html(input_json, html_fid):
    """Render `input_json` as an HTML table (via json2html) and write it
    to `html_fid`."""
    html_dict = json2html.convert(json=input_json)
    # Write UTF-8 explicitly: the platform default encoding is
    # locale-dependent and can fail on non-ASCII content.
    with open(html_fid, "w+", encoding="utf-8") as f:
        f.write(html_dict)
def df_to_write_html(input_df, html_fid):
    """Writes a dataframe to an HTML file.

    Args:
        input_df: pd.DataFrame to render.
        html_fid: path of the HTML file to write.
    """
    # BUG FIX: the pandas method is `to_html`; the previous `to_HTML`
    # raised AttributeError at runtime.
    input_df.to_html(html_fid)
def read_df(df_fid):
    """Load a dataframe previously stored with `write_df` (index-oriented
    JSON)."""
    df_dict = read_json(df_fid)
    return pd.DataFrame.from_dict(df_dict, orient="index")
def write_df(df, df_fid):
    """In order to preserve the index of our dataframes, we can't
    use the compressed pandas dataframe file format .feather.
    There's a preference for json amongst HF devs, so we use that here."""
    write_json(df.to_dict('index'), df_fid)
def write_json(json_dict, json_fid):
    """Dump `json_dict` to the file `json_fid` as UTF-8 JSON."""
    with open(json_fid, "w", encoding="utf-8") as json_file:
        json.dump(json_dict, json_file)
def read_json(json_fid):
    """Read and return the JSON content of `json_fid`.

    Returns:
        The deserialized JSON value (dict, list, ...).
    """
    # Use a context manager so the file handle is closed; the previous
    # version leaked the handle returned by the bare open() call.
    with open(json_fid, encoding="utf-8") as f:
        return json.load(f)