import argparse
import shutil
import smtplib
import ssl
import textwrap
from os import getenv
from os.path import exists
from pathlib import Path

from dotenv import load_dotenv

import utils
from data_measurements import dataset_statistics
from data_measurements.zipf import zipf
from utils import dataset_utils

logs = utils.prepare_logging(__file__)

def load_or_prepare_widgets(ds_args, show_embeddings=False,
                            show_perplexities=False, use_cache=False):
    """
    Loader specifically for the widgets used in the app.

    Args:
        ds_args: Dictionary of dataset arguments, passed through to
            DatasetStatisticsCacheClass.
        show_embeddings: Whether to also compute text embeddings (slow).
        show_perplexities: Whether to also compute text perplexities (slow).
        use_cache: Whether to reuse previously cached results.

    Returns:
        None. Results are written to the dataset cache directory.
    """
    dstats = dataset_statistics.DatasetStatisticsCacheClass(**ds_args,
                                                            use_cache=use_cache)
    # Header widget
    dstats.load_or_prepare_dset_peek()
    # General stats widget
    dstats.load_or_prepare_general_stats()
    # Labels widget
    dstats.load_or_prepare_labels()
    # Text lengths widget
    dstats.load_or_prepare_text_lengths()
    if show_embeddings:
        # Embeddings widget
        dstats.load_or_prepare_embeddings()
    if show_perplexities:
        # Text perplexities widget
        dstats.load_or_prepare_text_perplexities()
    # Text duplicates widget
    dstats.load_or_prepare_text_duplicates()
    # nPMI widget
    dstats.load_or_prepare_npmi()
    # Zipf widget
    dstats.load_or_prepare_zipf()
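

# Dispatches on `calculation` to run a single measurement, or (when no
# calculation is given) everything except the slow embeddings and
# perplexities measurements.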
def load_or_prepare(dataset_args, calculation=False, use_cache=False):
    # TODO: Catch error exceptions for each measurement, so that an error
    # for one measurement doesn't break the calculation of all of them.
    do_all = False
    dstats = dataset_statistics.DatasetStatisticsCacheClass(**dataset_args,
                                                            use_cache=use_cache)
    logs.info("Tokenizing dataset.")
    dstats.load_or_prepare_tokenized_df()
    logs.info("Calculating vocab.")
    dstats.load_or_prepare_vocab()
    if not calculation:
        do_all = True
    if do_all or calculation == "general":
        logs.info("\n* Calculating general statistics.")
        dstats.load_or_prepare_general_stats()
        logs.info("Done!")
        logs.info("Basic text statistics now available at %s."
                  % dstats.general_stats_json_fid)
    if do_all or calculation == "duplicates":
        logs.info("\n* Calculating text duplicates.")
        dstats.load_or_prepare_text_duplicates()
        duplicates_fid_dict = dstats.duplicates_files
        logs.info("If all went well, then results are in the following files:")
        for key, value in duplicates_fid_dict.items():
            logs.info("%s: %s" % (key, value))
if do_all or calculation == "lengths": | |
logs.info("\n* Calculating text lengths.") | |
dstats.load_or_prepare_text_lengths() | |
length_fid_dict = dstats.length_obj.get_filenames() | |
print("If all went well, then results are in the following files:") | |
for key, value in length_fid_dict.items(): | |
print("%s: %s" % (key, value)) | |
print() | |
if do_all or calculation == "labels": | |
logs.info("\n* Calculating label statistics.") | |
if dstats.label_field not in dstats.dset.features: | |
logs.warning("No label field found.") | |
logs.info("No label statistics to calculate.") | |
else: | |
dstats.load_or_prepare_labels() | |
npmi_fid_dict = dstats.label_files | |
print("If all went well, then results are in the following files:") | |
for key, value in npmi_fid_dict.items(): | |
print("%s: %s" % (key, value)) | |
print() | |
if do_all or calculation == "npmi": | |
print("\n* Preparing nPMI.") | |
dstats.load_or_prepare_npmi() | |
npmi_fid_dict = dstats.npmi_files | |
print("If all went well, then results are in the following files:") | |
for key, value in npmi_fid_dict.items(): | |
if isinstance(value, dict): | |
print(key + ":") | |
for key2, value2 in value.items(): | |
print("\t%s: %s" % (key2, value2)) | |
else: | |
print("%s: %s" % (key, value)) | |
print() | |
if do_all or calculation == "zipf": | |
logs.info("\n* Preparing Zipf.") | |
dstats.load_or_prepare_zipf() | |
logs.info("Done!") | |
zipf_json_fid, zipf_fig_json_fid, zipf_fig_html_fid = zipf.get_zipf_fids( | |
dstats.dataset_cache_dir) | |
logs.info("Zipf results now available at %s." % zipf_json_fid) | |
logs.info( | |
"Figure saved to %s, with corresponding json at %s." | |
% (zipf_fig_html_fid, zipf_fig_json_fid) | |
) | |
# Don't do this one until someone specifically asks for it -- takes awhile. | |
if calculation == "embeddings": | |
logs.info("\n* Preparing text embeddings.") | |
dstats.load_or_prepare_embeddings() | |
# Don't do this one until someone specifically asks for it -- takes awhile. | |
if calculation == "perplexities": | |
logs.info("\n* Preparing text perplexities.") | |
dstats.load_or_prepare_text_perplexities() | |
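

# Bundles the parsed CLI arguments into the keyword dictionary expected by
# DatasetStatisticsCacheClass, then dispatches to the GUI loader or the
# single-calculation loader above.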
def pass_args_to_DMT(dset_name, dset_config, split_name, text_field,
                     label_field, label_names, calculation, dataset_cache_dir,
                     prepare_gui=False, use_cache=True):
    if not use_cache:
        logs.info("Not using any cache; starting afresh")
    dataset_args = {
        "dset_name": dset_name,
        "dset_config": dset_config,
        "split_name": split_name,
        "text_field": text_field,
        "label_field": label_field,
        "label_names": label_names,
        "dataset_cache_dir": dataset_cache_dir,
    }
    if prepare_gui:
        load_or_prepare_widgets(dataset_args, use_cache=use_cache)
    else:
        load_or_prepare(dataset_args, calculation=calculation,
                        use_cache=use_cache)
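

# Fills in sensible defaults for any optional CLI arguments that were left
# unspecified.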
def set_defaults(args):
    if not args.config:
        args.config = "default"
        logs.info("Config name not specified. Assuming it's 'default'.")
    if not args.split:
        args.split = "train"
        logs.info("Split name not specified. Assuming it's 'train'.")
    if not args.feature:
        args.feature = "text"
        logs.info("Text column name not given. Assuming it's 'text'.")
    if not args.label_field:
        args.label_field = "label"
        logs.info("Label column name not given. Assuming it's 'label'.")
    return args
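

# CLI entry point: parses arguments, prepares the local (and optionally hub)
# cache directory, runs the requested measurements, and reports success or
# failure (optionally by email).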
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=textwrap.dedent(
            """
            Example for the hate_speech18 dataset:
            python3 run_data_measurements.py --dataset="hate_speech18" --config="default" --split="train" --feature="text"

            Example for the IMDB dataset:
            python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="train" --label_field="label" --feature="text"
            """
        ),
    )
    parser.add_argument(
        "-d", "--dataset", required=True, help="Name of dataset to prepare"
    )
    parser.add_argument(
        "-c", "--config", required=False, default="",
        help="Dataset configuration to prepare"
    )
    parser.add_argument(
        "-s", "--split", required=False, default="", type=str,
        help="Dataset split to prepare"
    )
    parser.add_argument(
        "-f",
        "--feature",
        "-t",
        "--text-field",
        required=False,
        nargs="+",
        type=str,
        default="",
        help="Column to prepare (handled as text)",
    )
    parser.add_argument(
        "-w",
        "--calculation",
        help="""What to calculate (defaults to everything except embeddings and perplexities).\n
        Options are:\n
        - `general` for duplicate counts, missing values, and length statistics\n
        - `duplicates` for duplicate counts\n
        - `lengths` for the text length distribution\n
        - `labels` for the label distribution\n
        - `embeddings` (Warning: Slow.)\n
        - `perplexities` (Warning: Slow.)\n
        - `npmi` for word associations\n
        - `zipf` for Zipfian statistics
        """,
    )
    parser.add_argument(
        "-l",
        "--label_field",
        type=str,
        required=False,
        default="",
        help="Field name for the label column in the dataset (required if "
             "there is a label field that you want information about)",
    )
    parser.add_argument('-n', '--label_names', nargs='+', default=[],
                        help="Names of the label values (Optional)")
    parser.add_argument(
        "--use_cache",
        default=False,
        required=False,
        action="store_true",
        help="Whether to use cached files (Optional)",
    )
    parser.add_argument("--out_dir", default="cache_dir",
                        help="Where to write the results.")
    parser.add_argument(
        "--overwrite_previous",
        default=False,
        required=False,
        action="store_true",
        help="Whether to overwrite a previous local cache for these same "
             "arguments (Optional)",
    )
    parser.add_argument(
        "--email",
        default=None,
        help="An email address that receives a message about whether the "
             "computation was successful. If email is not None, then you must "
             "have EMAIL_PASSWORD=<your email password> for the sender email "
             "([email protected]) in a file named .env at the root of "
             "this repo.")
    parser.add_argument(
        "--push_cache_to_hub",
        default=False,
        required=False,
        action="store_true",
        help="Whether to push the cache to an organization on the hub. If you "
             "are using this option, you must have "
             "HUB_CACHE_ORGANIZATION=<the organization you've set up on the "
             "hub to store your cache> and HF_TOKEN=<your hf token> on "
             "separate lines in a file named .env at the root of this repo.",
    )
    parser.add_argument("--prepare_GUI_data", default=False, required=False,
                        action="store_true",
                        help="Use this to process all of the stats used in "
                             "the GUI.")
parser.add_argument("--keep_local", default=True, required=False, | |
action="store_true", | |
help="Whether to save the data locally.") | |
    orig_args = parser.parse_args()
    args = set_defaults(orig_args)
    logs.info("Proceeding with the following arguments:")
    logs.info(args)
    # run_data_measurements.py -d hate_speech18 -c default -s train -f text -w npmi
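    # Optional email notification. As described in the --email help text, this
    # assumes the sender account [email protected], whose password is
    # read from a .env file at the repo root containing a line like:
    #   EMAIL_PASSWORD=<your email password>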
    if args.email is not None:
        if Path(".env").is_file():
            load_dotenv(".env")
        EMAIL_PASSWORD = getenv("EMAIL_PASSWORD")
        if EMAIL_PASSWORD is None:
            raise ValueError("--email was given, but no EMAIL_PASSWORD was "
                             "found. Please add it to a .env file at the root "
                             "of this repo.")
        context = ssl.create_default_context()
        port = 465
        server = smtplib.SMTP_SSL("smtp.gmail.com", port, context=context)
        server.login("[email protected]", EMAIL_PASSWORD)
    dataset_cache_name, local_dataset_cache_dir = dataset_utils.get_cache_dir_naming(
        args.out_dir, args.dataset, args.config, args.split, args.feature)
    if not args.use_cache and exists(local_dataset_cache_dir):
        if args.overwrite_previous:
            shutil.rmtree(local_dataset_cache_dir)
        else:
            raise OSError("Cached results for this dataset already exist at %s. "
                          "Delete them or use the --overwrite_previous argument."
                          % local_dataset_cache_dir)
    # Initialize the local cache directory
    dataset_utils.make_path(local_dataset_cache_dir)
    # Initialize the repository
    # TODO: Print out the local or hub cache directory location.
    if args.push_cache_to_hub:
        repo = dataset_utils.initialize_cache_hub_repo(local_dataset_cache_dir,
                                                       dataset_cache_name)
    # Run the measurements.
    try:
        pass_args_to_DMT(
            dset_name=args.dataset,
            dset_config=args.config,
            split_name=args.split,
            text_field=args.feature,
            label_field=args.label_field,
            label_names=args.label_names,
            calculation=args.calculation,
            dataset_cache_dir=local_dataset_cache_dir,
            prepare_gui=args.prepare_GUI_data,
            use_cache=args.use_cache,
        )
        if args.push_cache_to_hub:
            repo.push_to_hub(commit_message="Added dataset cache.")
        computed_message = f"Data measurements have been computed for the " \
                           f"dataset with these arguments: {args}."
        logs.info(computed_message)
        if args.email is not None:
            computed_message += "\nYou can return to the data measurements " \
                                "tool to view them."
            server.sendmail("[email protected]", args.email,
                            "Subject: Data Measurements Computed!\n\n" +
                            computed_message)
    except Exception as e:
        logs.exception(e)
        error_message = f"An error occurred in computing data measurements " \
                        f"for the dataset with these arguments: {args}. " \
                        f"Feel free to make an issue here: " \
                        f"https://github.com/huggingface/data-measurements-tool/issues"
        if args.email is not None:
            server.sendmail("[email protected]", args.email,
                            "Subject: Data Measurements not Computed\n\n" +
                            error_message)
        logs.warning("Data measurements not computed. ☹️")
        logs.warning(error_message)
        return
    if not args.keep_local:
        # Remove the dataset from local storage; we only want it stored on the hub.
        logs.warning("Deleting measurements data locally at %s"
                     % local_dataset_cache_dir)
        shutil.rmtree(local_dataset_cache_dir)
    else:
        logs.info("Measurements made available locally at %s"
                  % local_dataset_cache_dir)


if __name__ == "__main__":
    main()