import evaluate
import logging
import os
import pandas as pd
import plotly.express as px
import utils
import utils.dataset_utils as ds_utils
from collections import Counter
from os.path import exists, isdir
from os.path import join as pjoin

LABEL_FIELD = "labels"
LABEL_NAMES = "label_names"
LABEL_LIST = "label_list"
LABEL_MEASUREMENT = "label_measurement"

# Specific to the evaluate library
EVAL_LABEL_MEASURE = "label_distribution"
EVAL_LABEL_ID = "labels"
EVAL_LABEL_FRAC = "fractions"
# TODO: This should ideally be in what's returned from the evaluate library.
EVAL_LABEL_SUM = "sums"
logs = utils.prepare_logging(__file__)


def map_labels(label_field, ds_name_to_dict, ds_name, config_name):
    try:
        label_field, label_names = (
            ds_name_to_dict[ds_name][config_name]["features"][label_field][0]
            if len(
                ds_name_to_dict[ds_name][config_name]["features"][label_field]) > 0
            else ((), [])
        )
    except KeyError as e:
        logs.exception(e)
        logs.warning("Not returning a label-name mapping")
        return []
    return label_names
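# For reference, map_labels expects the dataset-info dict to nest
# (field, names) pairs under a "features" key, roughly:
#     {ds_name: {config_name: {"features": {"labels": [("labels", ["neg", "pos"])]}}}}
# The exact shape comes from ds_utils.get_dataset_info_dicts; the values above
# are illustrative only.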


def make_label_results_dict(label_measurement, label_names):
    label_dict = {LABEL_MEASUREMENT: label_measurement,
                  LABEL_NAMES: label_names}
    return label_dict
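# The resulting dict is what ends up cached to labels.json, e.g. (illustrative):
#     {"label_measurement": {"label_distribution": {...}, "sums": [...]},
#      "label_names": ["neg", "pos"]}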


def make_label_fig(label_results, chart_type="pie"):
    try:
        label_names = label_results[LABEL_NAMES]
        label_measurement = label_results[LABEL_MEASUREMENT]
        label_sums = label_measurement[EVAL_LABEL_SUM]
        if chart_type == "bar":
            # Use plotly (imported as px) so the figure can be serialized with
            # the write_plotly/write_html helpers downstream.
            fig_labels = px.bar(
                x=label_measurement[EVAL_LABEL_MEASURE][EVAL_LABEL_ID],
                y=label_measurement[EVAL_LABEL_MEASURE][EVAL_LABEL_FRAC])
        else:
            if chart_type != "pie":
                logs.info("Oops! Don't have that chart-type implemented.")
                logs.info("Making the default pie chart")
            # IMDB - unsupervised has a labels column where all values are -1,
            # which breaks the assumption that
            # the number of label_names == the number of label_sums.
            # This handles that case, assuming it will happen in other datasets.
            if len(label_names) != len(label_sums):
                logs.warning("Can't make a figure with the given label names: "
                             "We don't have the right amount of label types "
                             "to apply them to!")
                return False
            fig_labels = px.pie(names=label_names, values=label_sums)
    except KeyError:
        logs.info("Input label data is missing required key(s).")
        logs.info("We require %s, %s" % (LABEL_NAMES, LABEL_MEASUREMENT))
        logs.info("We found: %s" % ",".join(label_results.keys()))
        return False
    return fig_labels
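# Minimal usage sketch (hypothetical values), given a results dict in the shape
# produced by make_label_results_dict:
#     results = {LABEL_MEASUREMENT: {EVAL_LABEL_MEASURE: {EVAL_LABEL_ID: [0, 1],
#                                                         EVAL_LABEL_FRAC: [0.5, 0.5]},
#                                    EVAL_LABEL_SUM: [50, 50]},
#                LABEL_NAMES: ["neg", "pos"]}
#     fig = make_label_fig(results)  # plotly pie figure, or False on bad input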


def extract_label_names(label_field, ds_name, config_name):
    ds_name_to_dict = ds_utils.get_dataset_info_dicts(ds_name)
    label_names = map_labels(label_field, ds_name_to_dict, ds_name, config_name)
    return label_names


class DMTHelper:
    """Helper class for the Data Measurements Tool.

    This allows us to keep all variables and functions related to labels
    in one file.
    """

    def __init__(self, dstats, load_only, save):
        logs.info("Initializing labels.")
        # -- Data Measurements Tool variables
        self.label_results = dstats.label_results
        self.fig_labels = dstats.fig_labels
        self.use_cache = dstats.use_cache
        self.cache_dir = dstats.dataset_cache_dir
        self.load_only = load_only
        self.save = save
        # -- Hugging Face Dataset variables
        self.label_field = dstats.label_field
        # Input HuggingFace dataset
        self.dset = dstats.dset
        self.dset_name = dstats.dset_name
        self.dset_config = dstats.dset_config
        self.label_names = dstats.label_names
        # -- Filenames
        self.label_dir = "labels"
        label_json = "labels.json"
        label_fig_json = "labels_fig.json"
        label_fig_html = "labels_fig.html"
        self.labels_json_fid = pjoin(self.cache_dir, self.label_dir,
                                     label_json)
        self.labels_fig_json_fid = pjoin(self.cache_dir, self.label_dir,
                                         label_fig_json)
        self.labels_fig_html_fid = pjoin(self.cache_dir, self.label_dir,
                                         label_fig_html)

    def run_DMT_processing(self):
        """
        Loads or prepares the Labels measurements and figure as specified by
        the DMT options.
        """
        # First look to see what we can load from cache.
        if self.use_cache:
            logs.info("Trying to load labels.")
            self.fig_labels, self.label_results = self._load_label_cache()
            if self.fig_labels:
                logs.info("Loaded cached label figure.")
            if self.label_results:
                logs.info("Loaded cached label results.")
        # If we can prepare the results afresh...
        if not self.load_only:
            # If we didn't load them already, compute label statistics.
            if not self.label_results:
                logs.info("Preparing labels.")
                self.label_results = self._prepare_labels()
            # If we didn't load it already, create the figure.
            if not self.fig_labels:
                logs.info("Creating label figure.")
                self.fig_labels = make_label_fig(self.label_results)
        # Finish
        if self.save:
            self._write_label_cache()

    def _load_label_cache(self):
        fig_labels = {}
        label_results = {}
        # Measurements exist. Load them.
        if exists(self.labels_json_fid):
            # Loads the label list, names, and results.
            label_results = ds_utils.read_json(self.labels_json_fid)
        # Figure exists. Load it.
        if exists(self.labels_fig_json_fid):
            fig_labels = ds_utils.read_plotly(self.labels_fig_json_fid)
        return fig_labels, label_results

    def _prepare_labels(self):
        """Loads a Labels object and computes label statistics."""
        # Label object for the dataset
        label_obj = Labels(dataset=self.dset,
                           dataset_name=self.dset_name,
                           config_name=self.dset_config)
        # TODO: Handle the case where there are multiple label columns.
        # The logic throughout the code assumes only one.
        if isinstance(self.label_field, tuple):
            label_field = self.label_field[0]
        elif isinstance(self.label_field, str):
            label_field = self.label_field
        else:
            logs.warning("Unexpected format %s for label column name(s). "
                         "Not computing label statistics." %
                         type(self.label_field))
            return {}
        label_results = label_obj.prepare_labels(label_field, self.label_names)
        return label_results

    def _write_label_cache(self):
        ds_utils.make_path(pjoin(self.cache_dir, self.label_dir))
        if self.label_results:
            ds_utils.write_json(self.label_results, self.labels_json_fid)
        if self.fig_labels:
            ds_utils.write_plotly(self.fig_labels, self.labels_fig_json_fid)
            self.fig_labels.write_html(self.labels_fig_html_fid)

    def get_label_filenames(self):
        label_fid_dict = {"statistics": self.labels_json_fid,
                          "figure json": self.labels_fig_json_fid,
                          "figure html": self.labels_fig_html_fid}
        return label_fid_dict


class Labels:
    """Generic class for label processing.

    Uses the Dataset to extract the label column and compute label
    measurements.
    """

    def __init__(self, dataset, dataset_name=None, config_name=None):
        # Input HuggingFace Dataset.
        self.dset = dataset
        # The dataset name and config name are used to extract label names
        # when the label column we are working with holds only ints that
        # correspond to names stored elsewhere in the Dataset object.
        self.ds_name = dataset_name
        self.config_name = config_name
        # For measurement data and additional metadata.
        self.label_results_dict = {}

    def prepare_labels(self, label_field, label_names=None):
        """Uses the evaluate library to return the label distribution."""
        logs.info("Inside main label calculation function.")
        logs.debug("Looking for label field called '%s'" % label_field)
        # If the label field is not in the input Dataset, there is nothing to
        # measure; warn and return an empty result.
        if label_field in self.dset.features:
            label_list = self.dset[label_field]
        else:
            logs.warning("No label column found -- nothing to do. Returning.")
            logs.debug(self.dset.features)
            return {}
        # Get the evaluate library's measurement for the label distribution.
        label_distribution = evaluate.load(EVAL_LABEL_MEASURE)
        # Measure the label distribution.
        label_measurement = label_distribution.compute(data=label_list)
        # TODO: Incorporate this summation into what the evaluate library
        # returns.
        label_sum_dict = Counter(label_list)
        label_sums = [label_sum_dict[key] for key in sorted(label_sum_dict)]
        label_measurement[EVAL_LABEL_SUM] = label_sums
        if not label_names:
            # Extract the label names from the Dataset object when the actual
            # dataset column holds only ints representing the label names.
            label_names = extract_label_names(label_field, self.ds_name,
                                              self.config_name)
        label_results = make_label_results_dict(label_measurement, label_names)
        return label_results
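

# Minimal standalone usage sketch (illustrative; the dataset values are made
# up, and the DMT normally drives all of this through DMTHelper):
#
#     from datasets import Dataset
#     toy = Dataset.from_dict({"text": ["good", "bad"], "labels": [1, 0]})
#     labels = Labels(dataset=toy, dataset_name="toy", config_name="default")
#     results = labels.prepare_labels("labels", label_names=["neg", "pos"])
#     fig = make_label_fig(results)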