meg-huggingface committed on
Commit e8ac901 (1 parent: 2981bb2)

Merging back dataset statistics

data_measurements/dataset_statistics.py ADDED
@@ -0,0 +1,1313 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import logging
17
+ import statistics
18
+ import torch
19
+ from os import mkdir
20
+ from os.path import exists, isdir
21
+ from os.path import join as pjoin
22
+
23
+ import nltk
24
+ import numpy as np
25
+ import pandas as pd
26
+ import plotly
27
+ import plotly.express as px
28
+ import plotly.figure_factory as ff
29
+ import plotly.graph_objects as go
30
+ import pyarrow.feather as feather
31
+ import matplotlib.pyplot as plt
32
+ import matplotlib.image as mpimg
33
+ import seaborn as sns
34
+ from datasets import load_from_disk
35
+ from nltk.corpus import stopwords
36
+ from sklearn.feature_extraction.text import CountVectorizer
37
+
38
+ from .dataset_utils import (
39
+ TOT_WORDS,
40
+ TOT_OPEN_WORDS,
41
+ CNT,
42
+ DEDUP_TOT,
43
+ EMBEDDING_FIELD,
44
+ LENGTH_FIELD,
45
+ OUR_LABEL_FIELD,
46
+ OUR_TEXT_FIELD,
47
+ PROP,
48
+ TEXT_NAN_CNT,
49
+ TOKENIZED_FIELD,
50
+ TXT_LEN,
51
+ VOCAB,
52
+ WORD,
53
+ extract_field,
54
+ load_truncated_dataset,
55
+ )
56
+ from .embeddings import Embeddings
57
+ from .npmi import nPMI
58
+ from .zipf import Zipf
59
+
60
+ pd.options.display.float_format = "{:,.3f}".format
61
+
62
+ logs = logging.getLogger(__name__)
63
+ logs.setLevel(logging.WARNING)
64
+ logs.propagate = False
65
+
66
+ if not logs.handlers:
67
+
68
+ # Logging info to log file
69
+ file = logging.FileHandler("./log_files/dataset_statistics.log")
70
+ fileformat = logging.Formatter("%(asctime)s:%(message)s")
71
+ file.setLevel(logging.INFO)
72
+ file.setFormatter(fileformat)
73
+
74
+ # Logging debug messages to stream
75
+ stream = logging.StreamHandler()
76
+ streamformat = logging.Formatter("[data_measurements_tool] %(message)s")
77
+ stream.setLevel(logging.WARNING)
78
+ stream.setFormatter(streamformat)
79
+
80
+ logs.addHandler(file)
81
+ logs.addHandler(stream)
82
+
83
+
84
+ # TODO: Read this in depending on chosen language / expand beyond english
85
+ nltk.download("stopwords")
86
+ _CLOSED_CLASS = (
87
+ stopwords.words("english")
88
+ + [
89
+ "t",
90
+ "n",
91
+ "ll",
92
+ "d",
93
+ "wasn",
94
+ "weren",
95
+ "won",
96
+ "aren",
97
+ "wouldn",
98
+ "shouldn",
99
+ "didn",
100
+ "don",
101
+ "hasn",
102
+ "ain",
103
+ "couldn",
104
+ "doesn",
105
+ "hadn",
106
+ "haven",
107
+ "isn",
108
+ "mightn",
109
+ "mustn",
110
+ "needn",
111
+ "shan",
112
+ "would",
113
+ "could",
114
+ "dont",
115
+ "u",
116
+ ]
117
+ + [str(i) for i in range(0, 21)]
118
+ )
119
+ _IDENTITY_TERMS = [
120
+ "man",
121
+ "woman",
122
+ "non-binary",
123
+ "gay",
124
+ "lesbian",
125
+ "queer",
126
+ "trans",
127
+ "straight",
128
+ "cis",
129
+ "she",
130
+ "her",
131
+ "hers",
132
+ "he",
133
+ "him",
134
+ "his",
135
+ "they",
136
+ "them",
137
+ "their",
138
+ "theirs",
139
+ "himself",
140
+ "herself",
141
+ ]
142
+ # treating inf values as NaN as well
143
+ pd.set_option("use_inf_as_na", True)
144
+
145
+ _MIN_VOCAB_COUNT = 10
146
+ _TREE_DEPTH = 12
147
+ _TREE_MIN_NODES = 250
148
+ # as long as we're using sklearn - already pushing the resources
149
+ _MAX_CLUSTER_EXAMPLES = 5000
150
+ _NUM_VOCAB_BATCHES = 2000
151
+ _TOP_N = 100
152
+ _CVEC = CountVectorizer(token_pattern="(?u)\\b\\w+\\b", lowercase=True)
153
+
154
+ class DatasetStatisticsCacheClass:
155
+ def __init__(
156
+ self,
157
+ cache_dir,
158
+ dset_name,
159
+ dset_config,
160
+ split_name,
161
+ text_field,
162
+ label_field,
163
+ label_names,
164
+ calculation=None,
165
+ use_cache=False,
166
+ ):
167
+ # This is only used for standalone runs for each kind of measurement.
168
+ self.calculation = calculation
169
+ self.our_text_field = OUR_TEXT_FIELD
170
+ self.our_length_field = LENGTH_FIELD
171
+ self.our_label_field = OUR_LABEL_FIELD
172
+ self.our_tokenized_field = TOKENIZED_FIELD
173
+ self.our_embedding_field = EMBEDDING_FIELD
174
+ self.cache_dir = cache_dir
175
+ # Use stored data if there; otherwise calculate afresh
176
+ self.use_cache = use_cache
177
+ ### What are we analyzing?
178
+ # name of the Hugging Face dataset
179
+ self.dset_name = dset_name
180
+ # name of the dataset config
181
+ self.dset_config = dset_config
182
+ # name of the split to analyze
183
+ self.split_name = split_name
184
+ # TODO: Should this be "feature"?
185
+ # which text fields are we analysing?
186
+ self.text_field = text_field
187
+ # which label fields are we analysing?
188
+ self.label_field = label_field
189
+ # what are the names of the classes?
190
+ self.label_names = label_names
191
+ ## Hugging Face dataset objects
192
+ self.dset = None # original dataset
193
+ # HF dataset with all of the self.text_field instances in self.dset
194
+ self.text_dset = None
195
+ self.dset_peek = None
196
+ # HF dataset with text embeddings in the same order as self.text_dset
197
+ self.embeddings_dset = None
198
+ # HF dataset with all of the self.label_field instances in self.dset
199
+ self.label_dset = None
200
+ ## Data frames
201
+ # Tokenized text
202
+ self.tokenized_df = None
203
+ # save sentence length histogram in the class so it doesn't get re-computed
204
+ self.length_df = None
205
+ self.fig_tok_length = None
206
+ # Data Frame version of self.label_dset
207
+ self.label_df = None
208
+ # save label pie chart in the class so it doesn't get re-computed
209
+ self.fig_labels = None
210
+ # Vocabulary with word counts in the dataset
211
+ self.vocab_counts_df = None
212
+ # Vocabulary filtered to remove stopwords
213
+ self.vocab_counts_filtered_df = None
214
+ self.sorted_top_vocab_df = None
215
+ ## General statistics and duplicates
216
+ self.total_words = 0
217
+ self.total_open_words = 0
218
+ # Number of NaN values (NOT empty strings)
219
+ self.text_nan_count = 0
220
+ # Number of text items that appear more than once in the dataset
221
+ self.dedup_total = 0
222
+ # Duplicated text items along with their number of occurrences ("count")
223
+ self.dup_counts_df = None
224
+ self.avg_length = None
225
+ self.std_length = None
226
+ self.general_stats_dict = None
227
+ self.num_uniq_lengths = 0
228
+ # clustering text by embeddings
229
+ # the hierarchical clustering tree is represented as a list of nodes,
230
+ # the first is the root
231
+ self.node_list = []
232
+ # save tree figure in the class so it doesn't get re-computed
233
+ self.fig_tree = None
234
+ # keep Embeddings object around to explore clusters
235
+ self.embeddings = None
236
+ # nPMI
237
+ # Holds a nPMIStatisticsCacheClass object
238
+ self.npmi_stats = None
239
+ # TODO: Have lowercase be an option for a user to set.
240
+ self.to_lowercase = True
241
+ # The minimum number of times a word should occur to be included in
242
+ # word-count-based calculations (currently just relevant to nPMI)
243
+ self.min_vocab_count = _MIN_VOCAB_COUNT
244
+ # zipf
245
+ self.z = None
246
+ self.zipf_fig = None
247
+ self.cvec = _CVEC
248
+ # File definitions
249
+ # path to the directory used for caching
250
+ if not isinstance(text_field, str):
251
+ text_field = "-".join(text_field)
252
+ #if isinstance(label_field, str):
253
+ # label_field = label_field
254
+ #else:
255
+ # label_field = "-".join(label_field)
256
+ self.cache_path = pjoin(
257
+ self.cache_dir,
258
+ f"{dset_name}_{dset_config}_{split_name}_{text_field}", #{label_field},
259
+ )
260
+ if not isdir(self.cache_path):
261
+ logs.warning("Creating cache directory %s." % self.cache_path)
262
+ mkdir(self.cache_path)
263
+
264
+ # Cache files not needed for UI
265
+ self.dset_fid = pjoin(self.cache_path, "base_dset")
266
+ self.tokenized_df_fid = pjoin(self.cache_path, "tokenized_df.feather")
267
+ self.label_dset_fid = pjoin(self.cache_path, "label_dset")
268
+
269
+ # Needed for UI -- embeddings
270
+ self.text_dset_fid = pjoin(self.cache_path, "text_dset")
271
+ # Needed for UI
272
+ self.dset_peek_json_fid = pjoin(self.cache_path, "dset_peek.json")
273
+
274
+ ## Label cache files.
275
+ # Needed for UI
276
+ self.fig_labels_json_fid = pjoin(self.cache_path, "fig_labels.json")
277
+
278
+ ## Length cache files
279
+ # Needed for UI
280
+ self.length_df_fid = pjoin(self.cache_path, "length_df.feather")
281
+ # Needed for UI
282
+ self.length_stats_json_fid = pjoin(self.cache_path, "length_stats.json")
283
+ self.vocab_counts_df_fid = pjoin(self.cache_path, "vocab_counts.feather")
284
+ # Needed for UI
285
+ self.dup_counts_df_fid = pjoin(self.cache_path, "dup_counts_df.feather")
286
+ # Needed for UI
287
+ self.fig_tok_length_fid = pjoin(self.cache_path, "fig_tok_length.json")
288
+
289
+ ## General text stats
290
+ # Needed for UI
291
+ self.general_stats_json_fid = pjoin(self.cache_path, "general_stats_dict.json")
292
+ # Needed for UI
293
+ self.sorted_top_vocab_df_fid = pjoin(self.cache_path,
294
+ "sorted_top_vocab.feather")
295
+ ## Zipf cache files
296
+ # Needed for UI
297
+ self.zipf_fid = pjoin(self.cache_path, "zipf_basic_stats.json")
298
+ # Needed for UI
299
+ self.zipf_fig_fid = pjoin(self.cache_path, "zipf_fig.json")
300
+
301
+ ## Embeddings cache files
302
+ # Needed for UI
303
+ self.node_list_fid = pjoin(self.cache_path, "node_list.th")
304
+ # Needed for UI
305
+ self.fig_tree_json_fid = pjoin(self.cache_path, "fig_tree.json")
306
+ self.zipf_counts = None
307
+
308
+ self.live = False
309
+
310
+ def set_deployment(self, live=True):
311
+ """
312
+ Called when we deploy, so that cache files are not written out or
+ recalculated; instead, that part of the UI is skipped.
314
+ """
315
+ self.live = live
316
+
317
+ def get_base_dataset(self):
318
+ """Gets a pointer to the truncated base dataset object."""
319
+ if not self.dset:
320
+ self.dset = load_truncated_dataset(
321
+ self.dset_name,
322
+ self.dset_config,
323
+ self.split_name,
324
+ cache_name=self.dset_fid,
325
+ use_cache=True,
326
+ use_streaming=True,
327
+ )
328
+
329
+ def load_or_prepare_general_stats(self, save=True):
330
+ """
331
+ Content for expander_general_stats widget.
332
+ Provides statistics for total words, total open words,
333
+ the sorted top vocab, the NaN count, and the duplicate count.
334
+ Args:
335
+
336
+ Returns:
337
+
338
+ """
339
+ # General statistics
340
+ if (
341
+ self.use_cache
342
+ and exists(self.general_stats_json_fid)
343
+ and exists(self.dup_counts_df_fid)
344
+ and exists(self.sorted_top_vocab_df_fid)
345
+ ):
346
+ logs.info('Loading cached general stats')
347
+ self.load_general_stats()
348
+ else:
349
+ if not self.live:
350
+ logs.info('Preparing general stats')
351
+ self.prepare_general_stats()
352
+ if save:
353
+ write_df(self.sorted_top_vocab_df, self.sorted_top_vocab_df_fid)
354
+ write_df(self.dup_counts_df, self.dup_counts_df_fid)
355
+ write_json(self.general_stats_dict, self.general_stats_json_fid)
356
+
357
+
358
+ def load_or_prepare_text_lengths(self, save=True):
359
+ """
360
+ The text length widget relies on this function, which provides
361
+ a figure of the text lengths, some text length statistics, and
362
+ a text length dataframe to peruse.
363
+ Args:
364
+ save:
365
+ Returns:
366
+
367
+ """
368
+ # Text length figure
369
+ if (self.use_cache and exists(self.fig_tok_length_fid)):
370
+ self.fig_tok_length_png = mpimg.imread(self.fig_tok_length_fid)
371
+ self.fig_tok_length = read_plotly(self.fig_tok_length_fid)
372
+ else:
373
+ if not self.live:
374
+ self.prepare_fig_text_lengths()
375
+ if save:
376
+ write_plotly(self.fig_tok_length, self.fig_tok_length_fid)
377
+
378
+ # Text length dataframe
379
+ if self.use_cache and exists(self.length_df_fid):
380
+ self.length_df = feather.read_feather(self.length_df_fid)
381
+ else:
382
+ if not self.live:
383
+ self.prepare_length_df()
384
+ if save:
385
+ write_df(self.length_df, self.length_df_fid)
386
+
387
+ # Text length stats.
388
+ if self.use_cache and exists(self.length_stats_json_fid):
389
+ with open(self.length_stats_json_fid, "r") as f:
390
+ self.length_stats_dict = json.load(f)
391
+ self.avg_length = self.length_stats_dict["avg length"]
392
+ self.std_length = self.length_stats_dict["std length"]
393
+ self.num_uniq_lengths = self.length_stats_dict["num lengths"]
394
+ else:
395
+ if not self.live:
396
+ self.prepare_text_length_stats()
397
+ if save:
398
+ write_json(self.length_stats_dict, self.length_stats_json_fid)
399
+
400
+ def prepare_length_df(self):
401
+ if not self.live:
402
+ if self.tokenized_df is None:
403
+ self.tokenized_df = self.do_tokenization()
404
+ self.tokenized_df[LENGTH_FIELD] = self.tokenized_df[
405
+ TOKENIZED_FIELD].apply(len)
406
+ self.length_df = self.tokenized_df[
407
+ [LENGTH_FIELD, OUR_TEXT_FIELD]].sort_values(
408
+ by=[LENGTH_FIELD], ascending=True
409
+ )
410
+
411
+ def prepare_text_length_stats(self):
412
+ if not self.live:
413
+ if self.tokenized_df is None or LENGTH_FIELD not in self.tokenized_df.columns or self.length_df is None:
414
+ self.prepare_length_df()
415
+ avg_length = sum(self.tokenized_df[LENGTH_FIELD])/len(self.tokenized_df[LENGTH_FIELD])
416
+ self.avg_length = round(avg_length, 1)
417
+ std_length = statistics.stdev(self.tokenized_df[LENGTH_FIELD])
418
+ self.std_length = round(std_length, 1)
419
+ self.num_uniq_lengths = len(self.length_df["length"].unique())
420
+ self.length_stats_dict = {"avg length": self.avg_length,
421
+ "std length": self.std_length,
422
+ "num lengths": self.num_uniq_lengths}
423
+
424
+ def prepare_fig_text_lengths(self):
425
+ if not self.live:
426
+ if self.tokenized_df is None or LENGTH_FIELD not in self.tokenized_df.columns:
427
+ self.prepare_length_df()
428
+ self.fig_tok_length = make_fig_lengths(self.tokenized_df, LENGTH_FIELD)
429
+
430
+ def load_or_prepare_embeddings(self, save=True):
431
+ if self.use_cache and exists(self.node_list_fid) and exists(self.fig_tree_json_fid):
432
+ self.node_list = torch.load(self.node_list_fid)
433
+ self.fig_tree = read_plotly(self.fig_tree_json_fid)
434
+ elif self.use_cache and exists(self.node_list_fid):
435
+ self.node_list = torch.load(self.node_list_fid)
436
+ self.fig_tree = make_tree_plot(self.node_list,
437
+ self.text_dset)
438
+ if save:
439
+ write_plotly(self.fig_tree, self.fig_tree_json_fid)
440
+ else:
441
+ self.embeddings = Embeddings(self, use_cache=self.use_cache)
442
+ self.embeddings.make_hierarchical_clustering()
443
+ self.node_list = self.embeddings.node_list
444
+ self.fig_tree = make_tree_plot(self.node_list,
445
+ self.text_dset)
446
+ if save:
447
+ torch.save(self.node_list, self.node_list_fid)
448
+ write_plotly(self.fig_tree, self.fig_tree_json_fid)
449
+
450
+ # get vocab with word counts
451
+ def load_or_prepare_vocab(self, save=True):
452
+ """
453
+ Calculates the vocabulary count from the tokenized text.
454
+ The resulting dataframes may be used in nPMI calculations, zipf, etc.
455
+ :param
456
+ :return:
457
+ """
458
+ if (
459
+ self.use_cache
460
+ and exists(self.vocab_counts_df_fid)
461
+ ):
462
+ logs.info("Reading vocab from cache")
463
+ self.load_vocab()
464
+ self.vocab_counts_filtered_df = filter_vocab(self.vocab_counts_df)
465
+ else:
466
+ logs.info("Calculating vocab afresh")
467
+ if self.tokenized_df is None or len(self.tokenized_df) == 0:
468
+ self.tokenized_df = self.do_tokenization()
469
+ if save:
470
+ logs.info("Writing out.")
471
+ write_df(self.tokenized_df, self.tokenized_df_fid)
472
+ word_count_df = count_vocab_frequencies(self.tokenized_df)
473
+ logs.info("Making dfs with proportion.")
474
+ self.vocab_counts_df = calc_p_word(word_count_df)
475
+ self.vocab_counts_filtered_df = filter_vocab(self.vocab_counts_df)
476
+ if save:
477
+ logs.info("Writing out.")
478
+ write_df(self.vocab_counts_df, self.vocab_counts_df_fid)
479
+ logs.info("unfiltered vocab")
480
+ logs.info(self.vocab_counts_df)
481
+ logs.info("filtered vocab")
482
+ logs.info(self.vocab_counts_filtered_df)
483
+
484
+ def load_vocab(self):
485
+ with open(self.vocab_counts_df_fid, "rb") as f:
486
+ self.vocab_counts_df = feather.read_feather(f)
487
+ # Handling for changes in how the index is saved.
488
+ self.vocab_counts_df = self._set_idx_col_names(self.vocab_counts_df)
489
+
490
+ def load_or_prepare_text_duplicates(self, save=True):
491
+ if self.use_cache and exists(self.dup_counts_df_fid):
492
+ with open(self.dup_counts_df_fid, "rb") as f:
493
+ self.dup_counts_df = feather.read_feather(f)
494
+ elif self.dup_counts_df is None:
495
+ if not self.live:
496
+ self.prepare_text_duplicates()
497
+ if save:
498
+ write_df(self.dup_counts_df, self.dup_counts_df_fid)
499
+ else:
500
+ if not self.live:
501
+ # self.dup_counts_df is already defined here; this happens when general
+ # statistics were calculated first, since they require the number of
+ # duplicates.
504
+ if save:
505
+ write_df(self.dup_counts_df, self.dup_counts_df_fid)
506
+
507
+ def load_general_stats(self):
508
+ self.general_stats_dict = json.load(open(self.general_stats_json_fid, encoding="utf-8"))
509
+ with open(self.sorted_top_vocab_df_fid, "rb") as f:
510
+ self.sorted_top_vocab_df = feather.read_feather(f)
511
+ self.text_nan_count = self.general_stats_dict[TEXT_NAN_CNT]
512
+ self.dedup_total = self.general_stats_dict[DEDUP_TOT]
513
+ self.total_words = self.general_stats_dict[TOT_WORDS]
514
+ self.total_open_words = self.general_stats_dict[TOT_OPEN_WORDS]
515
+
516
+ def prepare_general_stats(self):
517
+ if not self.live:
518
+ if self.tokenized_df is None:
519
+ logs.warning("Tokenized dataset not yet loaded; doing so.")
520
+ self.load_or_prepare_dataset()
521
+ if self.vocab_counts_df is None:
522
+ logs.warning("Vocab not yet loaded; doing so.")
523
+ self.load_or_prepare_vocab()
524
+ self.sorted_top_vocab_df = self.vocab_counts_filtered_df.sort_values(
525
+ "count", ascending=False
526
+ ).head(_TOP_N)
527
+ self.total_words = len(self.vocab_counts_df)
528
+ self.total_open_words = len(self.vocab_counts_filtered_df)
529
+ self.text_nan_count = int(self.tokenized_df.isnull().sum().sum())
530
+ self.prepare_text_duplicates()
531
+ self.dedup_total = sum(self.dup_counts_df[CNT])
532
+ self.general_stats_dict = {
533
+ TOT_WORDS: self.total_words,
534
+ TOT_OPEN_WORDS: self.total_open_words,
535
+ TEXT_NAN_CNT: self.text_nan_count,
536
+ DEDUP_TOT: self.dedup_total,
537
+ }
538
+
539
+ def prepare_text_duplicates(self):
540
+ if not self.live:
541
+ if self.tokenized_df is None:
542
+ self.load_or_prepare_tokenized_df()
543
+ dup_df = self.tokenized_df[
544
+ self.tokenized_df.duplicated([OUR_TEXT_FIELD])]
545
+ self.dup_counts_df = pd.DataFrame(
546
+ dup_df.pivot_table(
547
+ columns=[OUR_TEXT_FIELD], aggfunc="size"
548
+ ).sort_values(ascending=False),
549
+ columns=[CNT],
550
+ )
551
+ self.dup_counts_df[OUR_TEXT_FIELD] = self.dup_counts_df.index.copy()
552
+
553
+ def load_or_prepare_dataset(self, save=True):
554
+ """
555
+ Prepares the HF datasets and data frames containing the untokenized and
556
+ tokenized text as well as the label values.
557
+ self.tokenized_df is used further for calculating text lengths,
558
+ word counts, etc.
559
+ Args:
560
+ save: Store the calculated data to disk.
561
+
562
+ Returns:
563
+
564
+ """
565
+ logs.info("Doing text dset.")
566
+ self.load_or_prepare_text_dset(save)
567
+ logs.info("Doing tokenized dataframe")
568
+ self.load_or_prepare_tokenized_df(save)
569
+ logs.info("Doing dataset peek")
570
+ self.load_or_prepare_dset_peek(save)
571
+
572
+ def load_or_prepare_dset_peek(self, save=True):
573
+ if self.use_cache and exists(self.dset_peek_json_fid):
574
+ with open(self.dset_peek_json_fid, "r") as f:
575
+ self.dset_peek = json.load(f)["dset peek"]
576
+ else:
577
+ if self.dset is None:
578
+ self.get_base_dataset()
579
+ self.dset_peek = self.dset[:100]
580
+ if save:
581
+ write_json({"dset peek": self.dset_peek}, self.dset_peek_json_fid)
582
+
583
+ def load_or_prepare_tokenized_df(self, save=True):
584
+ if (self.use_cache and exists(self.tokenized_df_fid)):
585
+ self.tokenized_df = feather.read_feather(self.tokenized_df_fid)
586
+ else:
587
+ if not self.live:
588
+ # tokenize all text instances
589
+ self.tokenized_df = self.do_tokenization()
590
+ if save:
591
+ logs.warning("Saving tokenized dataset to disk")
592
+ # save tokenized text
593
+ write_df(self.tokenized_df, self.tokenized_df_fid)
594
+
595
+ def load_or_prepare_text_dset(self, save=True):
596
+ if (self.use_cache and exists(self.text_dset_fid)):
597
+ # load extracted text
598
+ self.text_dset = load_from_disk(self.text_dset_fid)
599
+ logs.warning("Loaded dataset from disk")
600
+ logs.info(self.text_dset)
601
+ # ...Or load it from the server and store it anew
602
+ else:
603
+ if not self.live:
604
+ self.prepare_text_dset()
605
+ if save:
606
+ # save extracted text instances
607
+ logs.warning("Saving dataset to disk")
608
+ self.text_dset.save_to_disk(self.text_dset_fid)
609
+
610
+ def prepare_text_dset(self):
611
+ if not self.live:
612
+ self.get_base_dataset()
613
+ # extract all text instances
614
+ self.text_dset = self.dset.map(
615
+ lambda examples: extract_field(
616
+ examples, self.text_field, OUR_TEXT_FIELD
617
+ ),
618
+ batched=True,
619
+ remove_columns=list(self.dset.features),
620
+ )
621
+
622
+ def do_tokenization(self):
623
+ """
624
+ Tokenizes the dataset
625
+ :return:
626
+ """
627
+ if self.text_dset is None:
628
+ self.load_or_prepare_text_dset()
629
+ sent_tokenizer = self.cvec.build_tokenizer()
630
+
631
+ def tokenize_batch(examples):
632
+ # TODO: lowercase should be an option
633
+ res = {
634
+ TOKENIZED_FIELD: [
635
+ tuple(sent_tokenizer(text.lower()))
636
+ for text in examples[OUR_TEXT_FIELD]
637
+ ]
638
+ }
639
+ res[LENGTH_FIELD] = [len(tok_text) for tok_text in res[TOKENIZED_FIELD]]
640
+ return res
641
+
642
+ tokenized_dset = self.text_dset.map(
643
+ tokenize_batch,
644
+ batched=True,
645
+ # remove_columns=[OUR_TEXT_FIELD], keep around to print
646
+ )
647
+ tokenized_df = pd.DataFrame(tokenized_dset)
648
+ return tokenized_df
649
+
650
+ def set_label_field(self, label_field="label"):
651
+ """
652
+ Setter for label_field. Used in the CLI when a user asks for information
653
+ about labels, but does not specify the field;
654
+ 'label' is assumed as a default.
655
+ """
656
+ self.label_field = label_field
657
+
658
+ def load_or_prepare_labels(self, save=True):
659
+ # TODO: This is in a transitory state for creating fig cache.
660
+ # Clean up to be caching and reading everything correctly.
661
+ """
662
+ Extracts labels from the Dataset
663
+ :return:
664
+ """
665
+ # extracted labels
666
+ if len(self.label_field) > 0:
667
+ if self.use_cache and exists(self.fig_labels_json_fid):
668
+ self.fig_labels = read_plotly(self.fig_labels_json_fid)
669
+ elif self.use_cache and exists(self.label_dset_fid):
670
+ # load extracted labels
671
+ self.label_dset = load_from_disk(self.label_dset_fid)
672
+ self.label_df = self.label_dset.to_pandas()
673
+ self.fig_labels = make_fig_labels(
674
+ self.label_df, self.label_names, OUR_LABEL_FIELD
675
+ )
676
+ if save:
677
+ write_plotly(self.fig_labels, self.fig_labels_json_fid)
678
+ else:
679
+ if not self.live:
680
+ self.prepare_labels()
681
+ if save:
682
+ # save extracted label instances
683
+ self.label_dset.save_to_disk(self.label_dset_fid)
684
+ write_plotly(self.fig_labels, self.fig_labels_json_fid)
685
+
686
+ def prepare_labels(self):
687
+ if not self.live:
688
+ self.get_base_dataset()
689
+ self.label_dset = self.dset.map(
690
+ lambda examples: extract_field(
691
+ examples, self.label_field, OUR_LABEL_FIELD
692
+ ),
693
+ batched=True,
694
+ remove_columns=list(self.dset.features),
695
+ )
696
+ self.label_df = self.label_dset.to_pandas()
697
+ self.fig_labels = make_fig_labels(
698
+ self.label_df, self.label_names, OUR_LABEL_FIELD
699
+ )
700
+
701
+ def load_or_prepare_npmi(self):
702
+ self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=self.use_cache)
703
+ self.npmi_stats.load_or_prepare_npmi_terms()
704
+
705
+ def load_or_prepare_zipf(self, save=True):
706
+ # TODO: Current UI only uses the fig, meaning the self.z here is irrelevant
707
+ # when only reading from cache. Either the UI should use it, or it should
708
+ # be removed when reading in cache
709
+ if self.use_cache and exists(self.zipf_fig_fid) and exists(self.zipf_fid):
710
+ with open(self.zipf_fid, "r") as f:
711
+ zipf_dict = json.load(f)
712
+ self.z = Zipf()
713
+ self.z.load(zipf_dict)
714
+ # TODO: Should this be cached?
715
+ self.zipf_counts = self.z.calc_zipf_counts(self.vocab_counts_df)
716
+ self.zipf_fig = read_plotly(self.zipf_fig_fid)
717
+ elif self.use_cache and exists(self.zipf_fid):
718
+ # TODO: Read zipf data so that the vocab is there.
719
+ with open(self.zipf_fid, "r") as f:
720
+ zipf_dict = json.load(f)
721
+ self.z = Zipf()
722
+ self.z.load(zipf_dict)
723
+ self.zipf_fig = make_zipf_fig(self.vocab_counts_df, self.z)
724
+ if save:
725
+ write_plotly(self.zipf_fig, self.zipf_fig_fid)
726
+ else:
727
+ self.z = Zipf(self.vocab_counts_df)
728
+ self.zipf_fig = make_zipf_fig(self.vocab_counts_df, self.z)
729
+ if save:
730
+ write_zipf_data(self.z, self.zipf_fid)
731
+ write_plotly(self.zipf_fig, self.zipf_fig_fid)
732
+
733
+ def _set_idx_col_names(self, input_vocab_df):
734
+ if input_vocab_df.index.name != VOCAB and VOCAB in input_vocab_df.columns:
735
+ input_vocab_df = input_vocab_df.set_index([VOCAB])
736
+ input_vocab_df[VOCAB] = input_vocab_df.index
737
+ return input_vocab_df
738
+
739
+
740
+ class nPMIStatisticsCacheClass:
741
+ """Class to interface between the app and the nPMI class
742
+ by calling the nPMI class with the user's selections."""
743
+
744
+ def __init__(self, dataset_stats, use_cache=False):
745
+ self.live = dataset_stats.live
746
+ self.dstats = dataset_stats
747
+ self.pmi_cache_path = pjoin(self.dstats.cache_path, "pmi_files")
748
+ if not isdir(self.pmi_cache_path):
749
+ logs.warning("Creating pmi cache directory %s." % self.pmi_cache_path)
750
+ # We need to preprocess everything.
751
+ mkdir(self.pmi_cache_path)
752
+ self.joint_npmi_df_dict = {}
753
+ # TODO: Users ideally can type in whatever words they want.
754
+ self.termlist = _IDENTITY_TERMS
755
+ # termlist terms that are available more than _MIN_VOCAB_COUNT times
756
+ self.available_terms = _IDENTITY_TERMS
757
+ logs.info(self.termlist)
758
+ self.use_cache = use_cache
759
+ # TODO: Let users specify
760
+ self.open_class_only = True
761
+ self.min_vocab_count = self.dstats.min_vocab_count
762
+ self.subgroup_files = {}
763
+ self.npmi_terms_fid = pjoin(self.dstats.cache_path, "npmi_terms.json")
764
+
765
+ def load_or_prepare_npmi_terms(self):
766
+ """
767
+ Figures out what identity terms the user can select, based on whether
768
+ they occur at least self.min_vocab_count times
769
+ :return: Identity terms occurring at least self.min_vocab_count times.
770
+ """
771
+ # TODO: Add the user's ability to select subgroups.
772
+ # TODO: Make min_vocab_count here value selectable by the user.
773
+ if (
774
+ self.use_cache
775
+ and exists(self.npmi_terms_fid)
776
+ and json.load(open(self.npmi_terms_fid))["available terms"] != []
777
+ ):
778
+ self.available_terms = json.load(open(self.npmi_terms_fid))["available terms"]
779
+ else:
780
+ if not self.live:
781
+ if self.dstats.vocab_counts_df is None:
782
+ self.dstats.load_or_prepare_vocab()
783
+
784
+ true_false = [
785
+ term in self.dstats.vocab_counts_df.index for term in self.termlist
786
+ ]
787
+ word_list_tmp = [x for x, y in zip(self.termlist, true_false) if y]
788
+ true_false_counts = [
789
+ self.dstats.vocab_counts_df.loc[word, CNT] >= self.min_vocab_count
790
+ for word in word_list_tmp
791
+ ]
792
+ available_terms = [
793
+ word for word, y in zip(word_list_tmp, true_false_counts) if y
794
+ ]
795
+ logs.info(available_terms)
796
+ with open(self.npmi_terms_fid, "w+") as f:
797
+ json.dump({"available terms": available_terms}, f)
798
+ self.available_terms = available_terms
799
+ return self.available_terms
800
+
801
+ def load_or_prepare_joint_npmi(self, subgroup_pair, save=True):
802
+ """
803
+ Run on the fly, while the app is already open,
804
+ as it depends on the subgroup terms that the user chooses
805
+ :param subgroup_pair:
806
+ :return:
807
+ """
808
+ # Canonical ordering for subgroup_list
809
+ subgroup_pair = sorted(subgroup_pair)
810
+ subgroup1 = subgroup_pair[0]
811
+ subgroup2 = subgroup_pair[1]
812
+ subgroups_str = "-".join(subgroup_pair)
813
+ if not isdir(self.pmi_cache_path):
814
+ logs.warning("Creating cache")
815
+ # We need to preprocess everything.
816
+ # This should eventually all go into a prepare_dataset CLI
817
+ mkdir(self.pmi_cache_path)
818
+ joint_npmi_fid = pjoin(self.pmi_cache_path, subgroups_str + "_npmi.csv")
819
+ subgroup_files = define_subgroup_files(subgroup_pair, self.pmi_cache_path)
820
+ # Defines the filenames for the cache files from the selected subgroups.
821
+ # Get as much precomputed data as we can.
822
+ if self.use_cache and exists(joint_npmi_fid):
823
+ # When everything is already computed for the selected subgroups.
824
+ logs.info("Loading cached joint npmi")
825
+ joint_npmi_df = self.load_joint_npmi_df(joint_npmi_fid)
826
+ npmi_display_cols = ['npmi-bias', subgroup1 + '-npmi', subgroup2 + '-npmi', subgroup1 + '-count', subgroup2 + '-count']
827
+ joint_npmi_df = joint_npmi_df[npmi_display_cols]
828
+ # When maybe some things have been computed for the selected subgroups.
829
+ else:
830
+ if not self.live:
831
+ logs.info("Preparing new joint npmi")
832
+ joint_npmi_df, subgroup_dict = self.prepare_joint_npmi_df(
833
+ subgroup_pair, subgroup_files
834
+ )
835
+ if save:
836
+ if joint_npmi_df is not None:
837
+ # Cache new results
838
+ logs.info("Writing out.")
839
+ for subgroup in subgroup_pair:
840
+ write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files)
841
+ with open(joint_npmi_fid, "w+") as f:
842
+ joint_npmi_df.to_csv(f)
843
+ else:
844
+ joint_npmi_df = pd.DataFrame()
845
+ logs.info("The joint npmi df is")
846
+ logs.info(joint_npmi_df)
847
+ return joint_npmi_df
848
+
849
+ def load_joint_npmi_df(self, joint_npmi_fid):
850
+ """
851
+ Reads in a saved dataframe with all of the paired results.
852
+ :param joint_npmi_fid:
853
+ :return: paired results
854
+ """
855
+ with open(joint_npmi_fid, "rb") as f:
856
+ joint_npmi_df = pd.read_csv(f)
857
+ joint_npmi_df = self._set_idx_cols_from_cache(joint_npmi_df)
858
+ return joint_npmi_df.dropna()
859
+
860
+ def prepare_joint_npmi_df(self, subgroup_pair, subgroup_files):
861
+ """
862
+ Computes the nPMI bias based on the given subgroups.
+ Handles cases where some of the selected subgroups have cached nPMI
+ computations but others don't, computing everything afresh if there
+ are no cached files.
866
+ :param subgroup_pair:
867
+ :return: Dataframe with nPMI for the words, nPMI bias between the words.
868
+ """
869
+ subgroup_dict = {}
870
+ # When npmi is computed for some (but not all) of subgroup_list
871
+ for subgroup in subgroup_pair:
872
+ logs.info("Load or failing...")
873
+ # When subgroup npmi has been computed in a prior session.
874
+ cached_results = self.load_or_fail_cached_npmi_scores(
875
+ subgroup, subgroup_files[subgroup]
876
+ )
877
+ # If the function did not return False (i.e., cached results were found), use them.
878
+ if cached_results:
879
+ # FYI: subgroup_cooc_df, subgroup_pmi_df, subgroup_npmi_df = cached_results
880
+ # Holds the previous sessions' data for use in this session.
881
+ subgroup_dict[subgroup] = cached_results
882
+ logs.info("Calculating for subgroup list")
883
+ joint_npmi_df, subgroup_dict = self.do_npmi(subgroup_pair, subgroup_dict)
884
+ return joint_npmi_df, subgroup_dict
885
+
886
+ # TODO: Update pairwise assumption
887
+ def do_npmi(self, subgroup_pair, subgroup_dict):
888
+ """
889
+ Calculates nPMI for given identity terms and the nPMI bias between.
890
+ :param subgroup_pair: List of identity terms to calculate the bias for
891
+ :return: Subset of data for the UI
892
+ :return: Selected identity term's co-occurrence counts with
893
+ other words, pmi per word, and nPMI per word.
894
+ """
895
+ no_results = False
896
+ logs.info("Initializing npmi class")
897
+ npmi_obj = self.set_npmi_obj()
898
+ # Canonical ordering used
899
+ subgroup_pair = tuple(sorted(subgroup_pair))
900
+ # Calculating nPMI statistics
901
+ for subgroup in subgroup_pair:
902
+ # If the subgroup data is already computed, grab it.
903
+ # TODO: Should we set idx and column names similarly to
904
+ # how we set them for cached files?
905
+ if subgroup not in subgroup_dict:
906
+ logs.info("Calculating statistics for %s" % subgroup)
907
+ vocab_cooc_df, pmi_df, npmi_df = npmi_obj.calc_metrics(subgroup)
908
+ if vocab_cooc_df is None:
909
+ no_results = True
910
+ else:
911
+ # Store the nPMI information for the current subgroups
912
+ subgroup_dict[subgroup] = (vocab_cooc_df, pmi_df, npmi_df)
913
+ if no_results:
914
+ logs.warning("Couldn't grab the npmi files -- Under construction")
915
+ return None, None
916
+ else:
917
+ # Pair the subgroups together, indexed by all words that
918
+ # co-occur between them.
919
+ logs.info("Computing pairwise npmi bias")
920
+ paired_results = npmi_obj.calc_paired_metrics(subgroup_pair, subgroup_dict)
921
+ UI_results = make_npmi_fig(paired_results, subgroup_pair)
922
+ return UI_results.dropna(), subgroup_dict
923
+
924
+ def set_npmi_obj(self):
925
+ """
926
+ Initializes the nPMI class with the given words and tokenized sentences.
927
+ :return:
928
+ """
929
+ npmi_obj = nPMI(self.dstats.vocab_counts_df, self.dstats.tokenized_df)
930
+ return npmi_obj
931
+
932
+ def load_or_fail_cached_npmi_scores(self, subgroup, subgroup_fids):
933
+ """
934
+ Reads cached scores from the specified subgroup files
935
+ :param subgroup: string of the selected identity term
936
+ :return:
937
+ """
938
+ # TODO: Ordering of npmi, pmi, vocab triple should be consistent
939
+ subgroup_npmi_fid, subgroup_pmi_fid, subgroup_cooc_fid = subgroup_fids
940
+ if (
941
+ exists(subgroup_npmi_fid)
942
+ and exists(subgroup_pmi_fid)
943
+ and exists(subgroup_cooc_fid)
944
+ ):
945
+ logs.info("Reading in pmi data....")
946
+ with open(subgroup_cooc_fid, "rb") as f:
947
+ subgroup_cooc_df = pd.read_csv(f)
948
+ logs.info("pmi")
949
+ with open(subgroup_pmi_fid, "rb") as f:
950
+ subgroup_pmi_df = pd.read_csv(f)
951
+ logs.info("npmi")
952
+ with open(subgroup_npmi_fid, "rb") as f:
953
+ subgroup_npmi_df = pd.read_csv(f)
954
+ subgroup_cooc_df = self._set_idx_cols_from_cache(
955
+ subgroup_cooc_df, subgroup, "count"
956
+ )
957
+ subgroup_pmi_df = self._set_idx_cols_from_cache(
958
+ subgroup_pmi_df, subgroup, "pmi"
959
+ )
960
+ subgroup_npmi_df = self._set_idx_cols_from_cache(
961
+ subgroup_npmi_df, subgroup, "npmi"
962
+ )
963
+ return subgroup_cooc_df, subgroup_pmi_df, subgroup_npmi_df
964
+ return False
965
+
966
+ def _set_idx_cols_from_cache(self, csv_df, subgroup=None, calc_str=None):
967
+ """
968
+ Helps make sure all of the read-in files can be accessed within code
969
+ via standardized indices and column names.
970
+ :param csv_df:
971
+ :param subgroup:
972
+ :param calc_str:
973
+ :return:
974
+ """
975
+ # The csv saves with this column instead of the index, so that's weird.
976
+ if "Unnamed: 0" in csv_df.columns:
977
+ csv_df = csv_df.set_index("Unnamed: 0")
978
+ csv_df.index.name = WORD
979
+ elif WORD in csv_df.columns:
980
+ csv_df = csv_df.set_index(WORD)
981
+ csv_df.index.name = WORD
982
+ elif VOCAB in csv_df.columns:
983
+ csv_df = csv_df.set_index(VOCAB)
984
+ csv_df.index.name = WORD
985
+ if subgroup and calc_str:
986
+ csv_df.columns = [subgroup + "-" + calc_str]
987
+ elif subgroup:
988
+ csv_df.columns = [subgroup]
989
+ elif calc_str:
990
+ csv_df.columns = [calc_str]
991
+ return csv_df
992
+
993
+ def get_available_terms(self):
994
+ return self.load_or_prepare_npmi_terms()
995
+
996
+ def dummy(doc):
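+ # Identity function: passed to CountVectorizer as both tokenizer and
+ # preprocessor so the already-tokenized text is used as-is.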
997
+ return doc
998
+
999
+ def count_vocab_frequencies(tokenized_df):
1000
+ """
1001
+ Based on an input pandas DataFrame with a tokenized-text column,
+ this function will count the occurrences of all words.
+ :return: DataFrame indexed by vocabulary word, with a single column giving
+ each word's total count across the dataset.
1005
+ """
1006
+
1007
+ cvec = CountVectorizer(
1008
+ tokenizer=dummy,
1009
+ preprocessor=dummy,
1010
+ )
1011
+ # We do this to calculate per-word statistics
1012
+ # Fast calculation of single word counts
1013
+ logs.info("Fitting dummy tokenization to make matrix using the previous tokenization")
1014
+ cvec.fit(tokenized_df[TOKENIZED_FIELD])
1015
+ document_matrix = cvec.transform(tokenized_df[TOKENIZED_FIELD])
1016
+ batches = np.linspace(0, tokenized_df.shape[0], _NUM_VOCAB_BATCHES).astype(int)
1017
+ i = 0
1018
+ tf = []
1019
+ while i < len(batches) - 1:
1020
+ logs.info("%s of %s vocab batches" % (str(i), str(len(batches))))
1021
+ batch_result = np.sum(
1022
+ document_matrix[batches[i] : batches[i + 1]].toarray(), axis=0
1023
+ )
1024
+ tf.append(batch_result)
1025
+ i += 1
1026
+ word_count_df = pd.DataFrame(
1027
+ [np.sum(tf, axis=0)], columns=cvec.get_feature_names()
1028
+ ).transpose()
1029
+ # Now organize everything into the dataframes
1030
+ word_count_df.columns = [CNT]
1031
+ word_count_df.index.name = WORD
1032
+ return word_count_df
1033
+
1034
+ def calc_p_word(word_count_df):
1035
+ # p(word)
1036
+ word_count_df[PROP] = word_count_df[CNT] / float(sum(word_count_df[CNT]))
1037
+ vocab_counts_df = pd.DataFrame(word_count_df.sort_values(by=CNT, ascending=False))
1038
+ vocab_counts_df[VOCAB] = vocab_counts_df.index
1039
+ return vocab_counts_df
1040
+
1041
+
1042
+ def filter_vocab(vocab_counts_df):
1043
+ # TODO: Add warnings (which words are missing) to log file?
1044
+ filtered_vocab_counts_df = vocab_counts_df.drop(_CLOSED_CLASS,
1045
+ errors="ignore")
1046
+ filtered_count = filtered_vocab_counts_df[CNT]
1047
+ filtered_count_denom = float(sum(filtered_vocab_counts_df[CNT]))
1048
+ filtered_vocab_counts_df[PROP] = filtered_count / filtered_count_denom
1049
+ return filtered_vocab_counts_df
1050
+
1051
+
1052
+ ## Figures ##
1053
+
1054
+ def write_plotly(fig, fid):
1055
+ write_json(plotly.io.to_json(fig), fid)
1056
+
1057
+ def read_plotly(fid):
1058
+ fig = plotly.io.from_json(json.load(open(fid, encoding="utf-8")))
1059
+ return fig
1060
+
1061
+ def make_fig_lengths(tokenized_df, length_field):
1062
+ fig_tok_length = px.histogram(
1063
+ tokenized_df, x=length_field, marginal="rug", hover_data=[length_field]
1064
+ )
1065
+ return fig_tok_length
1066
+
1067
+ def make_fig_labels(label_df, label_names, label_field):
1068
+ labels = label_df[label_field].unique()
1069
+ label_sums = [len(label_df[label_df[label_field] == label]) for label in labels]
1070
+ fig_labels = px.pie(label_df, values=label_sums, names=label_names)
1071
+ return fig_labels
1072
+
1073
+
1074
+ def make_zipf_fig_ranked_word_list(vocab_df, unique_counts, unique_ranks):
1075
+ ranked_words = {}
1076
+ for count, rank in zip(unique_counts, unique_ranks):
1077
+ vocab_df[vocab_df[CNT] == count]["rank"] = rank
1078
+ ranked_words[rank] = ",".join(
1079
+ vocab_df[vocab_df[CNT] == count].index.astype(str)
1080
+ ) # Use the hovertext kw argument for hover text
1081
+ ranked_words_list = [wrds for rank, wrds in sorted(ranked_words.items())]
1082
+ return ranked_words_list
1083
+
1084
+
1085
+ def make_npmi_fig(paired_results, subgroup_pair):
1086
+ subgroup1, subgroup2 = subgroup_pair
1087
+ UI_results = pd.DataFrame()
1088
+ if "npmi-bias" in paired_results:
1089
+ UI_results["npmi-bias"] = paired_results["npmi-bias"].astype(float)
1090
+ UI_results[subgroup1 + "-npmi"] = paired_results["npmi"][
1091
+ subgroup1 + "-npmi"
1092
+ ].astype(float)
1093
+ UI_results[subgroup1 + "-count"] = paired_results["count"][
1094
+ subgroup1 + "-count"
1095
+ ].astype(int)
1096
+ if subgroup1 != subgroup2:
1097
+ UI_results[subgroup2 + "-npmi"] = paired_results["npmi"][
1098
+ subgroup2 + "-npmi"
1099
+ ].astype(float)
1100
+ UI_results[subgroup2 + "-count"] = paired_results["count"][
1101
+ subgroup2 + "-count"
1102
+ ].astype(int)
1103
+ return UI_results.sort_values(by="npmi-bias", ascending=True)
1104
+
1105
+
1106
+ def make_zipf_fig(vocab_counts_df, z):
1107
+ zipf_counts = z.calc_zipf_counts(vocab_counts_df)
1108
+ unique_counts = z.uniq_counts
1109
+ unique_ranks = z.uniq_ranks
1110
+ ranked_words_list = make_zipf_fig_ranked_word_list(
1111
+ vocab_counts_df, unique_counts, unique_ranks
1112
+ )
1113
+ zmin = z.get_xmin()
1114
+ logs.info("zipf counts is")
1115
+ logs.info(zipf_counts)
1116
+ layout = go.Layout(xaxis=dict(range=[0, 100]))
1117
+ fig = go.Figure(
1118
+ data=[
1119
+ go.Bar(
1120
+ x=z.uniq_ranks,
1121
+ y=z.uniq_counts,
1122
+ hovertext=ranked_words_list,
1123
+ name="Word Rank Frequency",
1124
+ )
1125
+ ],
1126
+ layout=layout,
1127
+ )
1128
+ fig.add_trace(
1129
+ go.Scatter(
1130
+ x=z.uniq_ranks[zmin : len(z.uniq_ranks)],
1131
+ y=zipf_counts[zmin : len(z.uniq_ranks)],
1132
+ hovertext=ranked_words_list[zmin : len(z.uniq_ranks)],
1133
+ line=go.scatter.Line(color="crimson", width=3),
1134
+ name="Zipf Predicted Frequency",
1135
+ )
1136
+ )
1137
+ # Customize aspect
1138
+ # fig.update_traces(marker_color='limegreen',
1139
+ # marker_line_width=1.5, opacity=0.6)
1140
+ fig.update_layout(title_text="Word Counts, Observed and Predicted by Zipf")
1141
+ fig.update_layout(xaxis_title="Word Rank")
1142
+ fig.update_layout(yaxis_title="Frequency")
1143
+ fig.update_layout(legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.10))
1144
+ return fig
1145
+
1146
+
1147
+ def make_tree_plot(node_list, text_dset):
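+ # Builds a plotly figure of the hierarchical clustering tree: nodes are laid
+ # out by cluster weight and labeled with up to five example texts each.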
1148
+ nid_map = dict([(node["nid"], nid) for nid, node in enumerate(node_list)])
1149
+
1150
+ for nid, node in enumerate(node_list):
1151
+ node["label"] = node.get(
1152
+ "label",
1153
+ f"{nid:2d} - {node['weight']:5d} items <br>"
1154
+ + "<br>".join(
1155
+ [
1156
+ "> " + txt[:64] + ("..." if len(txt) >= 63 else "")
1157
+ for txt in list(
1158
+ set(text_dset.select(node["example_ids"])[OUR_TEXT_FIELD])
1159
+ )[:5]
1160
+ ]
1161
+ ),
1162
+ )
1163
+
1164
+ # make plot nodes
1165
+ # TODO: something more efficient than set to remove duplicates
1166
+ labels = [node["label"] for node in node_list]
1167
+
1168
+ root = node_list[0]
1169
+ root["X"] = 0
1170
+ root["Y"] = 0
1171
+
1172
+ def rec_make_coordinates(node):
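+ # Depth-first layout: each child sits one level below its parent (Y - 1),
+ # offset horizontally by the accumulated weight of its preceding siblings.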
1173
+ total_weight = 0
1174
+ add_weight = len(node["example_ids"]) - sum(
1175
+ [child["weight"] for child in node["children"]]
1176
+ )
1177
+ for child in node["children"]:
1178
+ child["X"] = node["X"] + total_weight
1179
+ child["Y"] = node["Y"] - 1
1180
+ total_weight += child["weight"] + add_weight / len(node["children"])
1181
+ rec_make_coordinates(child)
1182
+
1183
+ rec_make_coordinates(root)
1184
+
1185
+ E = [] # list of edges
1186
+ Xn = []
1187
+ Yn = []
1188
+ Xe = []
1189
+ Ye = []
1190
+ for nid, node in enumerate(node_list):
1191
+ Xn += [node["X"]]
1192
+ Yn += [node["Y"]]
1193
+ for child in node["children"]:
1194
+ E += [(nid, nid_map[child["nid"]])]
1195
+ Xe += [node["X"], child["X"], None]
1196
+ Ye += [node["Y"], child["Y"], None]
1197
+
1198
+ # make figure
1199
+ fig = go.Figure()
1200
+ fig.add_trace(
1201
+ go.Scatter(
1202
+ x=Xe,
1203
+ y=Ye,
1204
+ mode="lines",
1205
+ line=dict(color="rgb(210,210,210)", width=1),
1206
+ hoverinfo="none",
1207
+ )
1208
+ )
1209
+ fig.add_trace(
1210
+ go.Scatter(
1211
+ x=Xn,
1212
+ y=Yn,
1213
+ mode="markers",
1214
+ name="nodes",
1215
+ marker=dict(
1216
+ symbol="circle-dot",
1217
+ size=18,
1218
+ color="#6175c1",
1219
+ line=dict(color="rgb(50,50,50)", width=1)
1220
+ # '#DB4551',
1221
+ ),
1222
+ text=labels,
1223
+ hoverinfo="text",
1224
+ opacity=0.8,
1225
+ )
1226
+ )
1227
+ return fig
1228
+
1229
+
1230
+ ## Input/Output ###
1231
+
1232
+
1233
+ def define_subgroup_files(subgroup_list, pmi_cache_path):
1234
+ """
1235
+ Sets the file ids for the input identity terms
1236
+ :param subgroup_list: List of identity terms
1237
+ :return:
1238
+ """
1239
+ subgroup_files = {}
1240
+ for subgroup in subgroup_list:
1241
+ # TODO: Should the pmi, npmi, and count just be one file?
1242
+ subgroup_npmi_fid = pjoin(pmi_cache_path, subgroup + "_npmi.csv")
1243
+ subgroup_pmi_fid = pjoin(pmi_cache_path, subgroup + "_pmi.csv")
1244
+ subgroup_cooc_fid = pjoin(pmi_cache_path, subgroup + "_vocab_cooc.csv")
1245
+ subgroup_files[subgroup] = (
1246
+ subgroup_npmi_fid,
1247
+ subgroup_pmi_fid,
1248
+ subgroup_cooc_fid,
1249
+ )
1250
+ return subgroup_files
1251
+
1252
+
1253
+ ## Input/Output ##
1254
+
1255
+
1256
+ def intersect_dfs(df_dict):
1257
+ started = 0
1258
+ new_df = None
1259
+ for key, df in df_dict.items():
1260
+ if df is None:
1261
+ continue
1262
+ for key2, df2 in df_dict.items():
1263
+ if df2 is None:
1264
+ continue
1265
+ if key == key2:
1266
+ continue
1267
+ if started:
1268
+ new_df = new_df.join(df2, how="inner", lsuffix="1", rsuffix="2")
1269
+ else:
1270
+ new_df = df.join(df2, how="inner", lsuffix="1", rsuffix="2")
1271
+ started = 1
1272
+ return new_df.copy()
1273
+
1274
+
1275
+ def write_df(df, df_fid):
1276
+ feather.write_feather(df, df_fid)
1277
+
1278
+
1279
+ def write_json(json_dict, json_fid):
1280
+ with open(json_fid, "w", encoding="utf-8") as f:
1281
+ json.dump(json_dict, f)
1282
+
1283
+ def write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files):
1284
+ """
1285
+ Saves the calculated nPMI statistics to their output files.
1286
+ Includes the npmi scores for each identity term, the pmi scores, and the
1287
+ co-occurrence counts of the identity term with all the other words
1288
+ :param subgroup: Identity term
1289
+ :return:
1290
+ """
1291
+ subgroup_fids = subgroup_files[subgroup]
1292
+ subgroup_npmi_fid, subgroup_pmi_fid, subgroup_cooc_fid = subgroup_fids
1293
+ subgroup_dfs = subgroup_dict[subgroup]
1294
+ subgroup_cooc_df, subgroup_pmi_df, subgroup_npmi_df = subgroup_dfs
1295
+ with open(subgroup_npmi_fid, "w+") as f:
1296
+ subgroup_npmi_df.to_csv(f)
1297
+ with open(subgroup_pmi_fid, "w+") as f:
1298
+ subgroup_pmi_df.to_csv(f)
1299
+ with open(subgroup_cooc_fid, "w+") as f:
1300
+ subgroup_cooc_df.to_csv(f)
1301
+
1302
+ def write_zipf_data(z, zipf_fid):
1303
+ zipf_dict = {}
1304
+ zipf_dict["xmin"] = int(z.xmin)
1305
+ zipf_dict["xmax"] = int(z.xmax)
1306
+ zipf_dict["alpha"] = float(z.alpha)
1307
+ zipf_dict["ks_distance"] = float(z.distance)
1308
+ zipf_dict["p-value"] = float(z.ks_test.pvalue)
1309
+ zipf_dict["uniq_counts"] = [int(count) for count in z.uniq_counts]
1310
+ zipf_dict["uniq_ranks"] = [int(rank) for rank in z.uniq_ranks]
1311
+ with open(zipf_fid, "w+", encoding="utf-8") as f:
1312
+ json.dump(zipf_dict, f)
1313
+
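For orientation, a minimal usage sketch of the DatasetStatisticsCacheClass added above. The import path simply mirrors the file location, and the dataset, config, split, and field values are hypothetical placeholders rather than anything specified in this commit; how text_field and label_field are interpreted ultimately depends on dataset_utils.extract_field, which lives outside this file.

import os

# The module opens ./log_files/dataset_statistics.log when imported, so the
# directory needs to exist first (see the FileHandler setup above).
os.makedirs("log_files", exist_ok=True)
os.makedirs("cache_dir", exist_ok=True)

from data_measurements.dataset_statistics import DatasetStatisticsCacheClass

# Hypothetical example values; any Hugging Face dataset/config/split would do.
dstats = DatasetStatisticsCacheClass(
    cache_dir="cache_dir",
    dset_name="imdb",
    dset_config="plain_text",
    split_name="train",
    text_field="text",
    label_field="label",
    label_names=["neg", "pos"],
    use_cache=True,
)

dstats.load_or_prepare_dataset()        # text dset, tokenized df, dataset peek
dstats.load_or_prepare_vocab()          # word counts and filtered vocabulary
dstats.load_or_prepare_general_stats()  # totals, NaN count, duplicate count
dstats.load_or_prepare_text_lengths()   # length stats and histogram figure
dstats.load_or_prepare_labels()         # label distribution figure
dstats.load_or_prepare_zipf()           # Zipf fit and figure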