"""run_data_measurements.py: script to generate the data-measurements-tool cache."""
import argparse
import json
import textwrap
from itertools import combinations
from os.path import join as pjoin
from data_measurements import dataset_statistics
from data_measurements import dataset_utils
def load_or_prepare_widgets(ds_args, show_embeddings=False, use_cache=False):
"""
    Loader specifically for the widgets used in the app.

    Args:
        ds_args: Keyword arguments passed through to DatasetStatisticsCacheClass
            (dataset name, config, split, text/label fields, cache dir, etc.).
        show_embeddings: Whether to also compute the (slow) text embeddings.
        use_cache: Whether to reuse previously computed cache files.

    Returns:
        None. Results are written to the cache directory.
"""
dstats = dataset_statistics.DatasetStatisticsCacheClass(**ds_args,
use_cache=use_cache)
# Header widget
dstats.load_or_prepare_dset_peek()
# General stats widget
dstats.load_or_prepare_general_stats()
# Labels widget
try:
dstats.set_label_field("label")
dstats.load_or_prepare_labels()
    except Exception:
        # Not all datasets have a "label" field; skip the labels widget if so.
        pass
# Text lengths widget
dstats.load_or_prepare_text_lengths()
if show_embeddings:
# Embeddings widget
dstats.load_or_prepare_embeddings()
# Text duplicates widget
dstats.load_or_prepare_text_duplicates()
# nPMI widget
dstats.load_or_prepare_npmi()
npmi_stats = dstats.npmi_stats
    # Compute all term pairs up front; in the UI, users select the pairs they want.
do_npmi(npmi_stats)
# Zipf widget
dstats.load_or_prepare_zipf()
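
# Illustrative call of the loader above; the dict keys mirror dataset_args as
# built in get_text_label_df below, and the values here are hypothetical:
#
# load_or_prepare_widgets(
#     {
#         "dset_name": "imdb",
#         "dset_config": "plain_text",
#         "split_name": "train",
#         "text_field": "text",
#         "label_field": "label",
#         "label_names": [],
#         "calculation": "",
#         "cache_dir": "cache_dir",
#     },
#     use_cache=True,
# )
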
def load_or_prepare(dataset_args, do_html=False, use_cache=False):
    """Runs the calculation requested in dataset_args, or all of them when none is given."""
    do_all = False
dstats = dataset_statistics.DatasetStatisticsCacheClass(**dataset_args, use_cache=use_cache)
print("Loading dataset.")
dstats.load_or_prepare_dataset()
print("Dataset loaded. Preparing vocab.")
dstats.load_or_prepare_vocab()
print("Vocab prepared.")
    if not dataset_args["calculation"]:
        do_all = True
    if do_all or dataset_args["calculation"] == "general":
print("\n* Calculating general statistics.")
dstats.load_or_prepare_general_stats()
print("Done!")
print("Basic text statistics now available at %s." % dstats.general_stats_json_fid)
print(
"Text duplicates now available at %s." % dstats.dup_counts_df_fid
)
    if do_all or dataset_args["calculation"] == "lengths":
print("\n* Calculating text lengths.")
fig_tok_length_fid = pjoin(dstats.cache_path, "lengths_fig.html")
tok_length_json_fid = pjoin(dstats.cache_path, "lengths.json")
dstats.load_or_prepare_text_lengths()
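        # Note: fig_tok_length.to_json() already returns a JSON-encoded string,
        # so json.dump stores it as a single quoted string (the same pattern is
        # used for the labels and Zipf figures below).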
with open(tok_length_json_fid, "w+") as f:
json.dump(dstats.fig_tok_length.to_json(), f)
print("Token lengths now available at %s." % tok_length_json_fid)
if do_html:
dstats.fig_tok_length.write_html(fig_tok_length_fid)
print("Figure saved to %s." % fig_tok_length_fid)
print("Done!")
    if do_all or dataset_args["calculation"] == "labels":
if not dstats.label_field:
print("Warning: You asked for label calculation, but didn't provide "
"the labels field name. Assuming it is 'label'...")
dstats.set_label_field("label")
print("\n* Calculating label distribution.")
dstats.load_or_prepare_labels()
fig_label_html = pjoin(dstats.cache_path, "labels_fig.html")
fig_label_json = pjoin(dstats.cache_path, "labels.json")
dstats.fig_labels.write_html(fig_label_html)
with open(fig_label_json, "w+") as f:
json.dump(dstats.fig_labels.to_json(), f)
print("Done!")
print("Label distribution now available at %s." % dstats.label_dset_fid)
print("Figure saved to %s." % fig_label_html)
    if do_all or dataset_args["calculation"] == "npmi":
print("\n* Preparing nPMI.")
npmi_stats = dataset_statistics.nPMIStatisticsCacheClass(
dstats, use_cache=use_cache
)
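        # nPMI gets its own cache wrapper around dstats; do_npmi then walks
        # every pair of identity terms and computes their joint statistics.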
do_npmi(npmi_stats, use_cache=use_cache)
print("Done!")
print(
"nPMI results now available in %s for all identity terms that "
"occur more than 10 times and all words that "
"co-occur with both terms."
% npmi_stats.pmi_cache_path
)
    if do_all or dataset_args["calculation"] == "zipf":
print("\n* Preparing Zipf.")
zipf_fig_fid = pjoin(dstats.cache_path, "zipf_fig.html")
zipf_json_fid = pjoin(dstats.cache_path, "zipf_fig.json")
dstats.load_or_prepare_zipf()
zipf_fig = dstats.zipf_fig
with open(zipf_json_fid, "w+") as f:
json.dump(zipf_fig.to_json(), f)
zipf_fig.write_html(zipf_fig_fid)
print("Done!")
print("Zipf results now available at %s." % dstats.zipf_fid)
print(
"Figure saved to %s, with corresponding json at %s."
% (zipf_fig_fid, zipf_json_fid)
)
    # Don't do this one unless someone specifically asks for it -- it takes a while.
if dataset_args["calculation"] == "embeddings":
print("\n* Preparing text embeddings.")
dstats.load_or_prepare_embeddings()
def do_npmi(npmi_stats, use_cache=True):
    """Computes joint nPMI for every unordered pair of available identity terms.

    Note: use_cache is unused in this function; the load_or_prepare_* methods
    on npmi_stats handle caching themselves.
    """
    available_terms = npmi_stats.load_or_prepare_npmi_terms()
    print("Iterating through terms for joint npmi.")
    # combinations() yields each unordered pair exactly once, so no manual
    # bookkeeping of completed pairs is needed.
    for term1, term2 in combinations(sorted(available_terms), 2):
        print("Computing nPMI statistics for %s and %s" % (term1, term2))
        _ = npmi_stats.load_or_prepare_joint_npmi((term1, term2))
def get_text_label_df(
ds_name,
config_name,
split_name,
text_field,
label_field,
calculation,
out_dir,
do_html=False,
use_cache=True,
):
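    """Builds the dataset_args dict from the command-line arguments and prepares
    the widget caches. Note: despite its name, this function returns nothing; it
    calls load_or_prepare_widgets for its side effects (cache files in out_dir),
    and do_html is currently unused here.
    """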
if not use_cache:
print("Not using any cache; starting afresh")
ds_name_to_dict = dataset_utils.get_dataset_info_dicts(ds_name)
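    # Look up the (label field, label names) pair recorded in the dataset info,
    # falling back to an empty field and no label names when it is absent.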
if label_field:
label_field, label_names = (
ds_name_to_dict[ds_name][config_name]["features"][label_field][0]
if len(ds_name_to_dict[ds_name][config_name]["features"][label_field]) > 0
else ((), [])
)
else:
label_field = ()
label_names = []
dataset_args = {
"dset_name": ds_name,
"dset_config": config_name,
"split_name": split_name,
"text_field": text_field,
"label_field": label_field,
"label_names": label_names,
"calculation": calculation,
"cache_dir": out_dir,
}
load_or_prepare_widgets(dataset_args, use_cache=use_cache)
def main():
# TODO: Make this the Hugging Face arg parser
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
description=textwrap.dedent(
"""
    Example for the hate_speech18 dataset:
python3 run_data_measurements.py --dataset="hate_speech18" --config="default" --split="train" --feature="text"
    Example for the IMDB dataset:
python3 run_data_measurements.py --dataset="imdb" --config="plain_text" --split="train" --label_field="label" --feature="text"
"""
),
)
parser.add_argument(
"-d", "--dataset", required=True, help="Name of dataset to prepare"
)
parser.add_argument(
"-c", "--config", required=True, help="Dataset configuration to prepare"
)
parser.add_argument(
"-s", "--split", required=True, type=str, help="Dataset split to prepare"
)
parser.add_argument(
"-f",
"--feature",
required=True,
type=str,
help="Text column to prepare",
)
parser.add_argument(
"-w",
"--calculation",
        help="""What to calculate (defaults to everything except embeddings).
        Options are:
        - `general`: duplicate counts, missing values, length statistics
        - `lengths`: text length distribution
        - `labels`: label distribution
        - `embeddings`: text embeddings (warning: slow)
        - `npmi`: word associations
        - `zipf`: Zipfian statistics
""",
)
parser.add_argument(
"-l",
"--label_field",
type=str,
required=False,
default="",
        help="Field name for the label column in the dataset (required if you want label information)",
)
parser.add_argument(
"--cached",
default=False,
required=False,
action="store_true",
help="Whether to use cached files (Optional)",
)
parser.add_argument(
"--do_html",
default=False,
required=False,
action="store_true",
help="Whether to write out corresponding HTML files (Optional)",
)
    parser.add_argument("--out_dir", default="cache_dir", help="Directory to write the cache files to.")
args = parser.parse_args()
print("Proceeding with the following arguments:")
print(args)
    # Example: run_data_measurements.py -d hate_speech18 -c default -s train -f text -w npmi
get_text_label_df(
args.dataset,
args.config,
args.split,
args.feature,
args.label_field,
args.calculation,
args.out_dir,
do_html=args.do_html,
use_cache=args.cached,
)
print()
if __name__ == "__main__":
main()