Ezi's picture
Upload 312 files
46df0b6
raw
history blame
9.93 kB
import evaluate
import logging
import os
import pandas as pd
import plotly.express as px
import utils
import utils.dataset_utils as ds_utils
from collections import Counter
from os.path import exists, isdir
from os.path import join as pjoin
# -- Names/keys for this module's label results.
# LABEL_NAMES and LABEL_MEASUREMENT are the keys of the dict built by
# make_label_results_dict; LABEL_FIELD and LABEL_LIST are not used in this
# file (presumably referenced by other modules -- verify before removing).
LABEL_FIELD: str = "labels"
LABEL_NAMES: str = "label_names"
LABEL_LIST: str = "label_list"
LABEL_MEASUREMENT: str = "label_measurement"

# -- Names specific to the evaluate library.
# EVAL_LABEL_MEASURE is both the measurement name passed to evaluate.load and
# the key of the distribution inside the computed measurement.
EVAL_LABEL_MEASURE: str = "label_distribution"
EVAL_LABEL_ID: str = "labels"
EVAL_LABEL_FRAC: str = "fractions"
# Key under which this module adds its own per-label counts to the
# measurement.
# TODO: This should ideally be in what's returned from the evaluate library
EVAL_LABEL_SUM: str = "sums"
# Module-level logger used by every function and class in this file; it is
# created by the project's utils.prepare_logging helper (configuration lives
# there, not here).
logs = utils.prepare_logging(__file__)
def map_labels(label_field, ds_name_to_dict, ds_name, config_name):
    """Look up the label names for a label column in dataset-info dicts.

    The entry at ``ds_name_to_dict[ds_name][config_name]["features"][label_field]``
    is expected to be a sequence whose first element is a ``(field, names)``
    pair (presumably as produced by ``ds_utils.get_dataset_info_dicts`` --
    verify against that helper).

    Args:
        label_field: Name of the label column to look up.
        ds_name_to_dict: Nested dataset-info mapping, keyed by dataset name,
            then config name.
        ds_name: Dataset name (first-level key).
        config_name: Configuration name (second-level key).

    Returns:
        The label names for ``label_field``; ``[]`` if the entry is empty,
        missing, or malformed.
    """
    try:
        # Walk the nested mapping once rather than twice.
        feature_entries = (
            ds_name_to_dict[ds_name][config_name]["features"][label_field])
        if len(feature_entries) == 0:
            return []
        # Only the names half of the (field, names) pair is needed; the
        # caller-supplied label_field is deliberately not rebound.
        _, label_names = feature_entries[0]
    except (KeyError, TypeError, ValueError) as e:
        # Missing or malformed dataset info is not fatal: fall back to no
        # label-name mapping, as the original KeyError path already did.
        logs.exception(e)
        logs.warning("Not returning a label-name mapping")
        return []
    return label_names
def make_label_results_dict(label_measurement, label_names):
    """Bundle a label measurement together with its label names.

    Args:
        label_measurement: The measurement computed for the label column.
        label_names: Names for the label values.

    Returns:
        dict mapping LABEL_MEASUREMENT -> label_measurement and
        LABEL_NAMES -> label_names.
    """
    return {
        LABEL_MEASUREMENT: label_measurement,
        LABEL_NAMES: label_names,
    }
def make_label_fig(label_results, chart_type="pie"):
    """Build a plotly figure of the label distribution.

    Args:
        label_results: Dict as built by ``make_label_results_dict``. It must
            contain LABEL_NAMES and LABEL_MEASUREMENT, and the measurement
            must contain EVAL_LABEL_SUM; bar charts additionally read
            ``measurement[EVAL_LABEL_MEASURE][EVAL_LABEL_ID / EVAL_LABEL_FRAC]``.
        chart_type: "pie" (default) or "bar". Any other value is logged and
            falls back to a pie chart.

    Returns:
        A plotly figure, or False if required keys are missing or the label
        names cannot be matched to the label sums.
    """
    try:
        label_names = label_results[LABEL_NAMES]
        label_measurement = label_results[LABEL_MEASUREMENT]
        label_sums = label_measurement[EVAL_LABEL_SUM]
        if chart_type == "bar":
            distribution = label_measurement[EVAL_LABEL_MEASURE]
            # Use plotly (like the pie branch) -- `plt` was never imported,
            # and callers serialize the figure with plotly APIs
            # (write_plotly / write_html).
            fig_labels = px.bar(x=distribution[EVAL_LABEL_ID],
                                y=distribution[EVAL_LABEL_FRAC])
        else:
            if chart_type != "pie":
                logs.info("Oops! Don't have that chart-type implemented.")
                logs.info("Making the default pie chart")
            # IMDB - unsupervised has a labels column where all values are -1,
            # which breaks the assumption that
            # the number of label_names == the number of label_sums.
            # This handles that case, assuming it will happen in other datasets.
            if len(label_names) != len(label_sums):
                logs.warning("Can't make a figure with the given label names: "
                             "We don't have the right amount of label types "
                             "to apply them to!")
                return False
            # Names and sums are paired positionally.
            fig_labels = px.pie(names=label_names, values=label_sums)
    except KeyError:
        logs.info("Input label data missing required key(s).")
        logs.info("We require %s, %s", LABEL_NAMES, LABEL_MEASUREMENT)
        # map(str, ...) so non-string keys cannot break the diagnostic itself.
        logs.info("We found: %s", ",".join(map(str, label_results.keys())))
        return False
    return fig_labels
def extract_label_names(label_field, ds_name, config_name):
    """Fetch the label names for ``label_field`` from the dataset's info dicts.

    Loads the info dicts for ``ds_name`` via ``ds_utils`` and delegates the
    lookup to ``map_labels``.

    Returns:
        Whatever ``map_labels`` returns for this dataset/config.
    """
    return map_labels(
        label_field,
        ds_utils.get_dataset_info_dicts(ds_name),
        ds_name,
        config_name,
    )
class DMTHelper:
    """Helper class for the Data Measurements Tool.

    This allows us to keep all variables and functions related to labels
    in one file.
    """

    def __init__(self, dstats, load_only, save):
        """Copy the label-related settings and data out of ``dstats``.

        Args:
            dstats: The DMT statistics object; its label results/figure,
                cache settings, dataset, and label field/names are read here.
            load_only: If true, only load from cache; never compute.
            save: If true, write results to the cache after computing.
        """
        logs.info("Initializing labels.")
        # -- Data Measurements Tool variables
        self.label_results = dstats.label_results
        self.fig_labels = dstats.fig_labels
        self.use_cache = dstats.use_cache
        self.cache_dir = dstats.dataset_cache_dir
        self.load_only = load_only
        self.save = save
        # -- Hugging Face Dataset variables
        self.label_field = dstats.label_field
        # Input HuggingFace dataset
        self.dset = dstats.dset
        self.dset_name = dstats.dset_name
        self.dset_config = dstats.dset_config
        self.label_names = dstats.label_names
        # -- Filenames
        self.label_dir = "labels"
        label_json = "labels.json"
        label_fig_json = "labels_fig.json"
        label_fig_html = "labels_fig.html"
        self.labels_json_fid = pjoin(self.cache_dir, self.label_dir,
                                     label_json)
        self.labels_fig_json_fid = pjoin(self.cache_dir, self.label_dir,
                                         label_fig_json)
        self.labels_fig_html_fid = pjoin(self.cache_dir, self.label_dir,
                                         label_fig_html)

    def run_DMT_processing(self):
        """
        Loads or prepares the Labels measurements and figure as specified by
        the DMT options.
        """
        # First look to see what we can load from cache.
        if self.use_cache:
            logs.info("Trying to load labels.")
            self.fig_labels, self.label_results = self._load_label_cache()
            if self.fig_labels:
                logs.info("Loaded cached label figure.")
            if self.label_results:
                logs.info("Loaded cached label results.")
        # If we can prepare the results afresh...
        if not self.load_only:
            # If we didn't load them already, compute label statistics.
            if not self.label_results:
                logs.info("Preparing labels.")
                self.label_results = self._prepare_labels()
            # If we didn't load it already, create figure.
            if not self.fig_labels:
                logs.info("Creating label figure.")
                self.fig_labels = make_label_fig(self.label_results)
            # Finish: persist whatever was loaded or computed.
            if self.save:
                self._write_label_cache()

    def _load_label_cache(self):
        """Return ``(fig_labels, label_results)`` read from the cache files.

        Each item is ``{}`` when its cache file does not exist.
        """
        fig_labels = {}
        label_results = {}
        # Measurements exist. Load them.
        if exists(self.labels_json_fid):
            # Loads the label list, names, and results
            label_results = ds_utils.read_json(self.labels_json_fid)
        # Image exists. Load it.
        if exists(self.labels_fig_json_fid):
            fig_labels = ds_utils.read_plotly(self.labels_fig_json_fid)
        return fig_labels, label_results

    @staticmethod
    def _resolve_label_field(label_field):
        """Return the single label column name to measure, or None.

        A string is used as-is. For a non-empty tuple or list only the first
        entry is used; anything else (including an empty tuple) yields None.
        """
        # TODO: Handle the case where there are multiple label columns.
        # The logic throughout the code assumes only one.
        if isinstance(label_field, str):
            return label_field
        if isinstance(label_field, (tuple, list)) and len(label_field) > 0:
            return label_field[0]
        return None

    def _prepare_labels(self):
        """Loads a Labels object and computes label statistics.

        Returns:
            The results of ``Labels.prepare_labels``, or ``{}`` when the
            configured label field cannot be resolved to a column name.
        """
        label_field = self._resolve_label_field(self.label_field)
        if label_field is None:
            logs.warning("Unexpected format %s for label column name(s). "
                         "Not computing label statistics.",
                         type(self.label_field))
            return {}
        # Label object for the dataset (built only once the field is valid).
        label_obj = Labels(dataset=self.dset,
                           dataset_name=self.dset_name,
                           config_name=self.dset_config)
        return label_obj.prepare_labels(label_field, self.label_names)

    def _write_label_cache(self):
        """Write the label results and figure (JSON + HTML) to the cache dir."""
        ds_utils.make_path(pjoin(self.cache_dir, self.label_dir))
        if self.label_results:
            ds_utils.write_json(self.label_results, self.labels_json_fid)
        if self.fig_labels:
            ds_utils.write_plotly(self.fig_labels, self.labels_fig_json_fid)
            self.fig_labels.write_html(self.labels_fig_html_fid)

    def get_label_filenames(self):
        """Return a mapping from artifact description to its cache path."""
        label_fid_dict = {"statistics": self.labels_json_fid,
                          "figure json": self.labels_fig_json_fid,
                          "figure html": self.labels_fig_html_fid}
        return label_fid_dict
class Labels:
    """Generic class for label processing.

    Uses the Dataset to extract the label column and compute label measurements.
    """

    def __init__(self, dataset, dataset_name=None, config_name=None):
        # Input HuggingFace Dataset.
        self.dset = dataset
        # These are used to extract label names, when the label names
        # are stored in the Dataset object but not in the "label" column
        # we are working with, which may instead just be ints corresponding to
        # the names
        self.ds_name = dataset_name
        self.config_name = config_name
        # For measurement data and additional metadata.
        self.label_results_dict = {}

    @staticmethod
    def _count_label_sums(label_list):
        """Return the count of each distinct label, ordered by sorted label value."""
        label_sum_dict = Counter(label_list)
        return [label_sum_dict[key] for key in sorted(label_sum_dict)]

    def prepare_labels(self, label_field, label_names=None):
        """Use the evaluate library to return the label distribution.

        Args:
            label_field: Name of the label column in ``self.dset``.
            label_names: Optional names for the label values. When empty or
                None they are looked up from the dataset info via
                ``extract_label_names``.

        Returns:
            A dict from ``make_label_results_dict``, or ``{}`` when the label
            column is missing or empty.
        """
        logs.info("Inside main label calculation function.")
        logs.debug("Looking for label field called '%s'", label_field)
        # A missing label column is not an error: log it and return nothing.
        if label_field not in self.dset.features:
            logs.warning("No label column found -- nothing to do. Returning.")
            logs.debug(self.dset.features)
            return {}
        label_list = self.dset[label_field]
        # An empty column has no distribution to measure; don't hand empty
        # data to evaluate. len() rather than truthiness so array-like
        # columns work too.
        if len(label_list) == 0:
            logs.warning("Label column '%s' is empty -- nothing to do. "
                         "Returning.", label_field)
            return {}
        # Get the evaluate library's measurement for label distro.
        label_distribution = evaluate.load(EVAL_LABEL_MEASURE)
        # Measure the label distro.
        label_measurement = label_distribution.compute(data=label_list)
        # TODO: Incorporate this summation into what the evaluate library returns.
        # make_label_fig pairs these sums positionally with label_names;
        # presumably label_names are ordered by label id -- confirm.
        label_measurement[EVAL_LABEL_SUM] = self._count_label_sums(label_list)
        if not label_names:
            # Have to extract the label names from the Dataset object when the
            # actual dataset columns are just ints representing the label names.
            label_names = extract_label_names(label_field, self.ds_name,
                                              self.config_name)
        return make_label_results_dict(label_measurement, label_names)