File size: 9,928 Bytes
46df0b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import evaluate
import logging
import os
import pandas as pd
import plotly.express as px
import utils
import utils.dataset_utils as ds_utils
from collections import Counter
from os.path import exists, isdir
from os.path import join as pjoin

# Label-related key names. LABEL_NAMES and LABEL_MEASUREMENT are the keys of
# the results dict built by make_label_results_dict() and read by
# make_label_fig().
# NOTE(review): LABEL_FIELD and LABEL_LIST appear unused in this file; confirm
# whether other modules import them before removing.
LABEL_FIELD = "labels"
LABEL_NAMES = "label_names"
LABEL_LIST = "label_list"
LABEL_MEASUREMENT = "label_measurement"
# Specific to the evaluate library
# EVAL_LABEL_MEASURE is both the measurement name passed to evaluate.load()
# and the key of the distribution entry inside the computed measurement;
# EVAL_LABEL_ID and EVAL_LABEL_FRAC are the keys inside that entry.
EVAL_LABEL_MEASURE = "label_distribution"
EVAL_LABEL_ID = "labels"
EVAL_LABEL_FRAC = "fractions"
# Key for the per-label counts that Labels.prepare_labels() adds to the
# measurement (the evaluate library does not return them).
# TODO: This should ideally be in what's returned from the evaluate library
EVAL_LABEL_SUM = "sums"

# Module logger, created by the project helper utils.prepare_logging.
logs = utils.prepare_logging(__file__)


def map_labels(label_field, ds_name_to_dict, ds_name, config_name):
    """Looks up the label names recorded for a label column in dataset info.

    Args:
        label_field: Name of the label column to look up.
        ds_name_to_dict: Nested mapping indexed as
            [ds_name][config_name]["features"][label_field]. The entry found
            there is a sequence whose first element is a
            (label_field, label_names) pair.
        ds_name: Dataset name.
        config_name: Dataset configuration name.

    Returns:
        The label names from the first pair. Returns [] when the entry is
        empty or when any key along the lookup path is missing.
    """
    try:
        # Walk the nested lookup path once rather than twice.
        label_entries = (
            ds_name_to_dict[ds_name][config_name]["features"][label_field])
        # An empty entry means no label names are recorded for this column.
        if len(label_entries) == 0:
            return []
        # Only the names of the first (field, names) pair are used; the
        # field part is discarded instead of overwriting the parameter.
        _, label_names = label_entries[0]
    except KeyError as e:
        logs.exception(e)
        logs.warning("Not returning a label-name mapping")
        return []
    return label_names


def make_label_results_dict(label_measurement, label_names):
    """Bundles a label measurement and its label names into one dict.

    Args:
        label_measurement: The measurement computed for the label column.
        label_names: The names associated with the label values.

    Returns:
        A dict mapping LABEL_MEASUREMENT to the measurement and LABEL_NAMES
        to the names, in that insertion order.
    """
    results = {}
    results[LABEL_MEASUREMENT] = label_measurement
    results[LABEL_NAMES] = label_names
    return results


def make_label_fig(label_results, chart_type="pie"):
    """Builds a plotly figure of the label distribution.

    Args:
        label_results: Dict as built by make_label_results_dict(). It holds
            the label names under LABEL_NAMES and the measurement under
            LABEL_MEASUREMENT. The measurement must contain EVAL_LABEL_SUM.
            For bar charts it must also contain the evaluate library's
            EVAL_LABEL_MEASURE entry, with EVAL_LABEL_ID and EVAL_LABEL_FRAC
            lists.
        chart_type: "pie" (default) or "bar". Any other value falls back to
            a pie chart.

    Returns:
        A plotly figure. Returns False when required keys are missing, or,
        for the pie chart, when the number of label names does not match
        the number of label sums.
    """
    try:
        label_names = label_results[LABEL_NAMES]
        label_measurement = label_results[LABEL_MEASUREMENT]
        label_sums = label_measurement[EVAL_LABEL_SUM]
        if chart_type == "bar":
            label_distribution = label_measurement[EVAL_LABEL_MEASURE]
            # Use plotly express, as the pie chart does. px.bar's first
            # positional parameter is a data frame, so x and y go by keyword.
            fig_labels = px.bar(x=label_distribution[EVAL_LABEL_ID],
                                y=label_distribution[EVAL_LABEL_FRAC])
        else:
            if chart_type != "pie":
                logs.info("Oops! Don't have that chart-type implemented.")
                logs.info("Making the default pie chart")
            # IMDB - unsupervised has a labels column where all values are -1,
            # which breaks the assumption that
            # the number of label_names == the number of label_sums.
            # This handles that case, assuming it will happen in other datasets.
            if len(label_names) != len(label_sums):
                logs.warning("Can't make a figure with the given label names: "
                             "We don't have the right amount of label types "
                             "to apply them to!")
                return False
            fig_labels = px.pie(names=label_names, values=label_sums)
    except KeyError:
        logs.info("Input label data missing required key(s).")
        logs.info("We require %s, %s" % (LABEL_NAMES, LABEL_MEASUREMENT))
        # Stringify keys so the diagnostic itself cannot raise on non-str keys.
        logs.info("We found: %s" % ",".join(map(str, label_results.keys())))
        return False
    return fig_labels


def extract_label_names(label_field, ds_name, config_name):
    """Fetches the label names for a label column from the dataset's info.

    Args:
        label_field: Name of the label column.
        ds_name: Dataset name, used to fetch the info dicts.
        config_name: Dataset configuration name.

    Returns:
        Whatever map_labels() returns for these arguments: the label names,
        or [] when none can be found.
    """
    return map_labels(label_field,
                      ds_utils.get_dataset_info_dicts(ds_name),
                      ds_name,
                      config_name)


class DMTHelper:
    """Helper class for the Data Measurements Tool.
    This allows us to keep all variables and functions related to labels
    in one file: loading cached label results and figures, computing them
    when allowed, and writing them back to the cache directory.
    """

    def __init__(self, dstats, load_only, save):
        """
        Args:
            dstats: DMT statistics object. This helper reads its
                label_results, fig_labels, use_cache, dataset_cache_dir,
                label_field, dset, dset_name, dset_config and label_names
                attributes.
            load_only: If true, only load from cache and never compute.
            save: If true, write results and figure to the cache after
                processing.
        """
        logs.info("Initializing labels.")
        # -- Data Measurements Tool variables
        self.label_results = dstats.label_results
        self.fig_labels = dstats.fig_labels
        self.use_cache = dstats.use_cache
        self.cache_dir = dstats.dataset_cache_dir
        self.load_only = load_only
        self.save = save
        # -- Hugging Face Dataset variables
        self.label_field = dstats.label_field
        # Input HuggingFace dataset
        self.dset = dstats.dset
        self.dset_name = dstats.dset_name
        self.dset_config = dstats.dset_config
        self.label_names = dstats.label_names
        # -- Filenames
        self.label_dir = "labels"
        label_json = "labels.json"
        label_fig_json = "labels_fig.json"
        label_fig_html = "labels_fig.html"
        self.labels_json_fid = pjoin(self.cache_dir, self.label_dir,
                                     label_json)
        self.labels_fig_json_fid = pjoin(self.cache_dir, self.label_dir,
                                         label_fig_json)
        self.labels_fig_html_fid = pjoin(self.cache_dir, self.label_dir,
                                         label_fig_html)

    def run_DMT_processing(self):
        """
        Loads or prepares the Labels measurements and figure as specified by
        the DMT options.

        If use_cache is set, cached results and figure are loaded first.
        Unless load_only is set, anything still missing is computed, and the
        cache is written when save is set.
        """
        # First look to see what we can load from cache.
        if self.use_cache:
            logs.info("Trying to load labels.")
            self.fig_labels, self.label_results = self._load_label_cache()
            if self.fig_labels:
                logs.info("Loaded cached label figure.")
            if self.label_results:
                logs.info("Loaded cached label results.")
        # If we can prepare the results afresh...
        if not self.load_only:
            # If we didn't load them already, compute label statistics.
            if not self.label_results:
                logs.info("Preparing labels.")
                self.label_results = self._prepare_labels()
            # If we didn't load it already, create figure.
            # make_label_fig returns False when it cannot build a figure;
            # _write_label_cache skips falsy figures.
            if not self.fig_labels:
                logs.info("Creating label figure.")
                self.fig_labels = make_label_fig(self.label_results)
            # Finish
            if self.save:
                self._write_label_cache()

    def _load_label_cache(self):
        """Loads cached label results and figure, if present.

        Returns:
            A (fig_labels, label_results) tuple. Each element is {} when its
            cache file does not exist.
        """
        fig_labels = {}
        label_results = {}
        # Measurements exist. Load them.
        if exists(self.labels_json_fid):
            # Loads the label list, names, and results
            label_results = ds_utils.read_json(self.labels_json_fid)
        # Image exists. Load it.
        if exists(self.labels_fig_json_fid):
            fig_labels = ds_utils.read_plotly(self.labels_fig_json_fid)
        return fig_labels, label_results

    def _prepare_labels(self):
        """Computes label statistics for the configured label column.

        Returns:
            The results dict from Labels.prepare_labels. Returns {} when the
            label column name is neither a str nor a non-empty tuple.
        """
        # TODO: Handle the case where there are multiple label columns.
        # The logic throughout the code assumes only one.
        # isinstance also accepts subclasses (e.g. namedtuples); an empty
        # tuple has no first column and is treated as an unexpected format.
        if isinstance(self.label_field, tuple) and self.label_field:
            label_field = self.label_field[0]
        elif isinstance(self.label_field, str):
            label_field = self.label_field
        else:
            logs.warning("Unexpected format %s for label column name(s). "
                         "Not computing label statistics." %
                         type(self.label_field))
            return {}
        # Label object for the dataset, built only once the column is known.
        label_obj = Labels(dataset=self.dset,
                           dataset_name=self.dset_name,
                           config_name=self.dset_config)
        return label_obj.prepare_labels(label_field, self.label_names)

    def _write_label_cache(self):
        """Writes the label results and figure (JSON and HTML) to the cache.

        Only non-empty results and truthy figures are written.
        """
        # presumably creates the labels cache directory if missing — verify
        # against ds_utils.make_path
        ds_utils.make_path(pjoin(self.cache_dir, self.label_dir))
        if self.label_results:
            ds_utils.write_json(self.label_results, self.labels_json_fid)
        if self.fig_labels:
            ds_utils.write_plotly(self.fig_labels, self.labels_fig_json_fid)
            self.fig_labels.write_html(self.labels_fig_html_fid)

    def get_label_filenames(self):
        """Returns the cache file paths for the label statistics and figure."""
        label_fid_dict = {"statistics": self.labels_json_fid,
                          "figure json": self.labels_fig_json_fid,
                          "figure html": self.labels_fig_html_fid}
        return label_fid_dict


class Labels:
    """Generic class for label processing.
    Uses the Dataset to extract the label column and compute label measurements.
    """

    def __init__(self, dataset, dataset_name=None, config_name=None):
        """
        Args:
            dataset: Input HuggingFace Dataset.
            dataset_name: Dataset name, used to look up label names.
            config_name: Dataset configuration name, used the same way.
        """
        # Input HuggingFace Dataset.
        self.dset = dataset
        # These are used to extract label names, when the label names
        # are stored in the Dataset object but not in the "label" column
        # we are working with, which may instead just be ints corresponding to
        # the names
        self.ds_name = dataset_name
        self.config_name = config_name
        # For measurement data and additional metadata.
        self.label_results_dict = {}

    def prepare_labels(self, label_field, label_names=None):
        """Uses the evaluate library to return the label distribution.

        Args:
            label_field: Name of the label column in the dataset.
            label_names: Optional names for the label values. When None or
                empty, they are looked up with extract_label_names().

        Returns:
            The dict built by make_label_results_dict(). The measurement is
            augmented with per-label counts under EVAL_LABEL_SUM. Returns {}
            when label_field is not one of the dataset's features.
        """
        logs.info("Inside main label calculation function.")
        logs.debug("Looking for label field called '%s'", label_field)
        # A missing label column is not an error: log it and return {}.
        if label_field not in self.dset.features:
            logs.warning("No label column found -- nothing to do. Returning.")
            logs.debug(self.dset.features)
            return {}
        label_list = self.dset[label_field]
        # Get the evaluate library's measurement for label distro.
        label_distribution = evaluate.load(EVAL_LABEL_MEASURE)
        # Measure the label distro.
        label_measurement = label_distribution.compute(data=label_list)
        # TODO: Incorporate this summation into what the evaluate library returns.
        # Counts are ordered by sorted label value. make_label_fig pairs them
        # positionally with label_names, which presumably are indexed by
        # integer label value — confirm for non-int labels.
        label_sum_dict = Counter(label_list)
        label_measurement[EVAL_LABEL_SUM] = [
            label_sum_dict[key] for key in sorted(label_sum_dict)]
        if not label_names:
            # Have to extract the label names from the Dataset object when the
            # actual dataset columns are just ints representing the label names.
            label_names = extract_label_names(label_field, self.ds_name,
                                              self.config_name)
        return make_label_results_dict(label_measurement, label_names)