meg-huggingface committed
Commit: 4b53042
Parent: 66693d5
Begins modularizing so that each widget can be loaded independently, without depending on the order of the load_or_prepare calls in app.py. Each function corresponding to a widget now checks whether the variables it depends on have been calculated yet; if not, it calls back to calculate them. Because of the messiness this caused when passing the use_cache variable around, use_cache is now set once, when the DatasetStatisticsCacheClass is initialized, and the use_cache arguments that appeared in nearly every function have been removed.
Files changed:
- app.py (+45 -12)
- data_measurements/dataset_statistics.py (+70 -53)
- data_measurements/streamlit_utils.py (+3 -3)
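At the heart of the change is a lazy load-or-prepare pattern: every load_or_prepare_* method consults self.use_cache, and computes any missing upstream state on demand instead of relying on app.py to have called things in the right order. A minimal sketch of the pattern, as a simplified stand-in for DatasetStatisticsCacheClass rather than the repo's exact code (the cache filename, column names, and duplicate-counting helper here are illustrative):

from os.path import exists, join

import pandas as pd


class StatsCache:
    def __init__(self, cache_dir, use_cache=False):
        # use_cache is set once at init; the per-method use_cache
        # arguments are gone.
        self.use_cache = use_cache
        self.dup_counts_fid = join(cache_dir, "dup_counts.feather")
        self.tokenized_df = None
        self.dup_counts_df = None

    def load_or_prepare_tokenized_df(self):
        # Stand-in for the real tokenization step.
        self.tokenized_df = pd.DataFrame({"text": ["a", "b", "a"]})

    def load_or_prepare_text_duplicates(self, save=True):
        # 1) Prefer the on-disk cache when allowed.
        if self.use_cache and exists(self.dup_counts_fid):
            self.dup_counts_df = pd.read_feather(self.dup_counts_fid)
        # 2) Otherwise compute, first calling back for any missing
        #    dependency, so the widgets can be loaded in any order.
        elif self.dup_counts_df is None:
            if self.tokenized_df is None:
                self.load_or_prepare_tokenized_df()
            dups = self.tokenized_df[self.tokenized_df.duplicated(["text"])]
            self.dup_counts_df = (
                dups.value_counts("text").rename("count").reset_index()
            )
            if save:
                self.dup_counts_df.to_feather(self.dup_counts_fid)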
app.py CHANGED
@@ -100,30 +100,63 @@ def load_or_prepare(ds_args, show_embeddings, use_cache=False):
         mkdir(CACHE_DIR)
     if use_cache:
         logs.warning("Using cache")
-    dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args)
+    dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args, use_cache=use_cache)
     logs.warning("Loading Dataset")
-    dstats.load_or_prepare_dataset(
+    dstats.load_or_prepare_dataset()
     logs.warning("Extracting Labels")
-    dstats.load_or_prepare_labels(
+    dstats.load_or_prepare_labels()
     logs.warning("Computing Text Lengths")
-    dstats.load_or_prepare_text_lengths(
+    dstats.load_or_prepare_text_lengths()
+    logs.warning("Computing Duplicates")
+    dstats.load_or_prepare_text_duplicates()
     logs.warning("Extracting Vocabulary")
-    dstats.load_or_prepare_vocab(
+    dstats.load_or_prepare_vocab()
     logs.warning("Calculating General Statistics...")
-    dstats.load_or_prepare_general_stats(
+    dstats.load_or_prepare_general_stats()
     logs.warning("Completed Calculation.")
     logs.warning("Calculating Fine-Grained Statistics...")
     if show_embeddings:
         logs.warning("Loading Embeddings")
-        dstats.load_or_prepare_embeddings(
+        dstats.load_or_prepare_embeddings()
     print(dstats.fig_tree)
     # TODO: This has now been moved to calculation when the npmi widget is loaded.
     logs.warning("Loading Terms for nPMI")
-    dstats.load_or_prepare_npmi_terms(
+    dstats.load_or_prepare_npmi_terms()
     logs.warning("Loading Zipf")
-    dstats.load_or_prepare_zipf(
+    dstats.load_or_prepare_zipf()
     return dstats
 
+def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
+    """
+    Loader specifically for the widgets used in the app.
+    Args:
+        ds_args:
+        show_embeddings:
+        use_cache:
+
+    Returns:
+
+    """
+    if not isdir(CACHE_DIR):
+        logs.warning("Creating cache")
+        # We need to preprocess everything.
+        # This should eventually all go into a prepare_dataset CLI
+        mkdir(CACHE_DIR)
+    if use_cache:
+        logs.warning("Using cache")
+    dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args, use_cache=use_cache)
+    # Header widget
+    dstats.load_or_prepare_dset_peek()
+    # General stats widget
+    dstats.load_or_prepare_general_stats()
+    # Labels widget
+    dstats.load_or_prepare_labels()
+    # Text lengths widget
+    dstats.load_or_prepare_text_lengths()
+    if show_embeddings:
+        # Embeddings widget
+        dstats.load_or_prepare_embeddings()
+    dstats.load_or_prepare_text_duplicates()
 
 def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=True):
     """
@@ -144,7 +177,7 @@ def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=True):
     st_utils.expander_header(dstats, ds_name_to_dict, column_id)
     logs.info("showing general stats")
     st_utils.expander_general_stats(dstats, column_id)
-    st_utils.expander_label_distribution(dstats.
+    st_utils.expander_label_distribution(dstats.fig_labels, column_id)
     st_utils.expander_text_lengths(
         dstats.tokenized_df,
         dstats.fig_tok_length,
@@ -163,7 +196,7 @@ def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=True):
     npmi_stats = dataset_statistics.nPMIStatisticsCacheClass(
         dstats, use_cache=use_cache
     )
-    available_terms = npmi_stats.get_available_terms(
+    available_terms = npmi_stats.get_available_terms()
     st_utils.npmi_widget(
         column_id, available_terms, npmi_stats, _MIN_VOCAB_COUNT, use_cache=use_cache
     )
@@ -190,7 +223,7 @@ def main():
     compare_mode = st.sidebar.checkbox("Comparison mode")
 
     # When not doing new development, use the cache.
-    use_cache =
+    use_cache = False
     show_embeddings = st.sidebar.checkbox("Show embeddings")
     # List of datasets for which embeddings are hard to compute:
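For orientation, a hypothetical call into the new loader looks like this; the ds_args keys are illustrative stand-ins for the constructor arguments of DatasetStatisticsCacheClass. Note that, as committed, load_or_prepare_widgets populates the cache for its side effects rather than returning the dstats object:

# Hypothetical invocation; the argument values are illustrative.
ds_args = {
    "dset_name": "imdb",
    "dset_config": "plain_text",
    "split_name": "train",
}
load_or_prepare_widgets(ds_args, show_embeddings=False, use_cache=True)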
data_measurements/dataset_statistics.py CHANGED
@@ -159,6 +159,7 @@ class DatasetStatisticsCacheClass:
         label_field,
         label_names,
         calculation=None,
+        use_cache=False,
     ):
         # This is only used for standalone runs for each kind of measurement.
         self.calculation = calculation
@@ -168,6 +169,8 @@ class DatasetStatisticsCacheClass:
         self.our_tokenized_field = TOKENIZED_FIELD
         self.our_embedding_field = EMBEDDING_FIELD
         self.cache_dir = cache_dir
+        # Use stored data if there; otherwise calculate afresh
+        self.use_cache = use_cache
         ### What are we analyzing?
         # name of the Hugging Face dataset
         self.dset_name = dset_name
@@ -285,20 +288,19 @@ class DatasetStatisticsCacheClass:
             use_streaming=True,
         )
 
-    def load_or_prepare_general_stats(self,
+    def load_or_prepare_general_stats(self, save=True):
         """
         Content for expander_general_stats widget.
         Provides statistics for total words, total open words,
         the sorted top vocab, the NaN count, and the duplicate count.
         Args:
-            use_cache:
 
         Returns:
 
         """
         # General statistics
         if (
-            use_cache
+            self.use_cache
             and exists(self.general_stats_fid)
             and exists(self.dup_counts_df_fid)
             and exists(self.sorted_top_vocab_df_fid)
@@ -320,10 +322,10 @@ class DatasetStatisticsCacheClass:
             write_json(self.general_stats_dict, self.general_stats_fid)
 
 
-    def load_or_prepare_text_lengths(self,
+    def load_or_prepare_text_lengths(self, save=True):
         # TODO: Everything here can be read from cache; it's in a transitory
         # state atm where just the fig is cached. Clean up.
-        if use_cache and exists(self.fig_tok_length_fid):
+        if self.use_cache and exists(self.fig_tok_length_fid):
             self.fig_tok_length = read_plotly(self.fig_tok_length_fid)
         if self.tokenized_df is None:
             self.tokenized_df = self.do_tokenization()
@@ -340,18 +342,18 @@ class DatasetStatisticsCacheClass:
         if save:
             write_plotly(self.fig_tok_length, self.fig_tok_length_fid)
 
-    def load_or_prepare_embeddings(self,
+    def load_or_prepare_embeddings(self, save=True):
-        if use_cache and exists(self.node_list_fid) and exists(self.fig_tree_fid):
+        if self.use_cache and exists(self.node_list_fid) and exists(self.fig_tree_fid):
             self.node_list = torch.load(self.node_list_fid)
             self.fig_tree = read_plotly(self.fig_tree_fid)
-        elif use_cache and exists(self.node_list_fid):
+        elif self.use_cache and exists(self.node_list_fid):
             self.node_list = torch.load(self.node_list_fid)
             self.fig_tree = make_tree_plot(self.node_list,
                                            self.text_dset)
             if save:
                 write_plotly(self.fig_tree, self.fig_tree_fid)
         else:
-            self.embeddings = Embeddings(self, use_cache=use_cache)
+            self.embeddings = Embeddings(self, use_cache=self.use_cache)
             self.embeddings.make_hierarchical_clustering()
             self.node_list = self.embeddings.node_list
             self.fig_tree = make_tree_plot(self.node_list,
@@ -361,15 +363,15 @@ class DatasetStatisticsCacheClass:
             write_plotly(self.fig_tree, self.fig_tree_fid)
 
     # get vocab with word counts
-    def load_or_prepare_vocab(self,
+    def load_or_prepare_vocab(self, save=True):
         """
         Calculates the vocabulary count from the tokenized text.
         The resulting dataframes may be used in nPMI calculations, zipf, etc.
-        :param
+        :param
         :return:
         """
         if (
-            use_cache
+            self.use_cache
             and exists(self.vocab_counts_df_fid)
         ):
             logs.info("Reading vocab from cache")
@@ -400,10 +402,23 @@ class DatasetStatisticsCacheClass:
         # Handling for changes in how the index is saved.
         self.vocab_counts_df = self._set_idx_col_names(self.vocab_counts_df)
 
+    def load_or_prepare_text_duplicates(self, save=True):
+        if self.use_cache and exists(self.dup_counts_df_fid):
+            with open(self.dup_counts_df_fid, "rb") as f:
+                self.dup_counts_df = feather.read_feather(f)
+        elif self.dup_counts_df is None:
+            self.prepare_text_duplicates()
+            if save:
+                write_df(self.dup_counts_df, self.dup_counts_df_fid)
+        else:
+            # This happens when self.dup_counts_df is already defined;
+            # This happens when general_statistics were calculated first,
+            # since general statistics requires the number of duplicates
+            if save:
+                write_df(self.dup_counts_df, self.dup_counts_df_fid)
+
     def load_general_stats(self):
         self.general_stats_dict = json.load(open(self.general_stats_fid, encoding="utf-8"))
-        with open(self.dup_counts_df_fid, "rb") as f:
-            self.dup_counts_df = feather.read_feather(f)
         with open(self.sorted_top_vocab_df_fid, "rb") as f:
             self.sorted_top_vocab_df = feather.read_feather(f)
         self.text_nan_count = self.general_stats_dict[TEXT_NAN_CNT]
@@ -421,20 +436,10 @@ class DatasetStatisticsCacheClass:
         self.sorted_top_vocab_df = self.vocab_counts_filtered_df.sort_values(
             "count", ascending=False
         ).head(_TOP_N)
-        print('basics')
         self.total_words = len(self.vocab_counts_df)
         self.total_open_words = len(self.vocab_counts_filtered_df)
         self.text_nan_count = int(self.tokenized_df.isnull().sum().sum())
-
-        print('dup df')
-        self.dup_counts_df = pd.DataFrame(
-            dup_df.pivot_table(
-                columns=[OUR_TEXT_FIELD], aggfunc="size"
-            ).sort_values(ascending=False),
-            columns=[CNT],
-        )
-        print('deddup df')
-        self.dup_counts_df[OUR_TEXT_FIELD] = self.dup_counts_df.index.copy()
+        self.prepare_text_duplicates()
         self.dedup_total = sum(self.dup_counts_df[CNT])
         self.general_stats_dict = {
             TOT_WORDS: self.total_words,
@@ -443,28 +448,40 @@ class DatasetStatisticsCacheClass:
             DEDUP_TOT: self.dedup_total,
         }
 
-    def
+    def prepare_text_duplicates(self):
+        if self.tokenized_df is None:
+            self.load_or_prepare_tokenized_df()
+        dup_df = self.tokenized_df[
+            self.tokenized_df.duplicated([OUR_TEXT_FIELD])]
+        self.dup_counts_df = pd.DataFrame(
+            dup_df.pivot_table(
+                columns=[OUR_TEXT_FIELD], aggfunc="size"
+            ).sort_values(ascending=False),
+            columns=[CNT],
+        )
+        self.dup_counts_df[OUR_TEXT_FIELD] = self.dup_counts_df.index.copy()
+
+    def load_or_prepare_dataset(self, save=True):
         """
         Prepares the HF datasets and data frames containing the untokenized and
         tokenized text as well as the label values.
         self.tokenized_df is used further for calculating text lengths,
         word counts, etc.
         Args:
-            use_cache: Used stored data if there; otherwise calculate afresh
             save: Store the calculated data to disk.
 
         Returns:
 
         """
         logs.info("Doing text dset.")
-        self.load_or_prepare_text_dset(
+        self.load_or_prepare_text_dset(save)
         logs.info("Doing tokenized dataframe")
-        self.load_or_prepare_tokenized_df(
+        self.load_or_prepare_tokenized_df(save)
         logs.info("Doing dataset peek")
-        self.load_or_prepare_dset_peek(save
+        self.load_or_prepare_dset_peek(save)
 
-    def load_or_prepare_dset_peek(self, save
+    def load_or_prepare_dset_peek(self, save=True):
-        if use_cache and exists(self.dset_peek_fid):
+        if self.use_cache and exists(self.dset_peek_fid):
             with open(self.dset_peek_fid, "r") as f:
                 self.dset_peek = json.load(f)["dset peek"]
         else:
@@ -472,10 +489,10 @@ class DatasetStatisticsCacheClass:
             self.get_base_dataset()
             self.dset_peek = self.dset[:100]
             if save:
-                write_json({"
+                write_json({"dset peek": self.dset_peek}, self.dset_peek_fid)
 
-    def load_or_prepare_tokenized_df(self,
+    def load_or_prepare_tokenized_df(self, save=True):
-        if (use_cache and exists(self.tokenized_df_fid)):
+        if (self.use_cache and exists(self.tokenized_df_fid)):
             self.tokenized_df = feather.read_feather(self.tokenized_df_fid)
         else:
             # tokenize all text instances
@@ -485,8 +502,8 @@ class DatasetStatisticsCacheClass:
             # save tokenized text
             write_df(self.tokenized_df, self.tokenized_df_fid)
 
-    def load_or_prepare_text_dset(self,
+    def load_or_prepare_text_dset(self, save=True):
-        if (use_cache and exists(self.text_dset_fid)):
+        if (self.use_cache and exists(self.text_dset_fid)):
             # load extracted text
             self.text_dset = load_from_disk(self.text_dset_fid)
             logs.warning("Loaded dataset from disk")
@@ -515,6 +532,8 @@ class DatasetStatisticsCacheClass:
         Tokenizes the dataset
         :return:
         """
+        if self.text_dset is None:
+            self.load_or_prepare_text_dset()
         sent_tokenizer = self.cvec.build_tokenizer()
 
         def tokenize_batch(examples):
@@ -544,19 +563,18 @@ class DatasetStatisticsCacheClass:
         """
         self.label_field = label_field
 
-    def load_or_prepare_labels(self,
+    def load_or_prepare_labels(self, save=True):
         # TODO: This is in a transitory state for creating fig cache.
         # Clean up to be caching and reading everything correctly.
         """
         Extracts labels from the Dataset
-        :param use_cache:
         :return:
         """
         # extracted labels
         if len(self.label_field) > 0:
-            if use_cache and exists(self.fig_labels_fid):
+            if self.use_cache and exists(self.fig_labels_fid):
                 self.fig_labels = read_plotly(self.fig_labels_fid)
-            elif use_cache and exists(self.label_dset_fid):
+            elif self.use_cache and exists(self.label_dset_fid):
                 # load extracted labels
                 self.label_dset = load_from_disk(self.label_dset_fid)
                 self.label_df = self.label_dset.to_pandas()
@@ -583,21 +601,21 @@ class DatasetStatisticsCacheClass:
             self.label_dset.save_to_disk(self.label_dset_fid)
             write_plotly(self.fig_labels, self.fig_labels_fid)
 
-    def load_or_prepare_npmi_terms(self
+    def load_or_prepare_npmi_terms(self):
-        self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=use_cache)
+        self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=self.use_cache)
         self.npmi_stats.load_or_prepare_npmi_terms()
 
-    def load_or_prepare_zipf(self,
+    def load_or_prepare_zipf(self, save=True):
         # TODO: Current UI only uses the fig, meaning the self.z here is irrelevant
         # when only reading from cache. Either the UI should use it, or it should
         # be removed when reading in cache
-        if use_cache and exists(self.zipf_fig_fid) and exists(self.zipf_fid):
+        if self.use_cache and exists(self.zipf_fig_fid) and exists(self.zipf_fid):
             with open(self.zipf_fid, "r") as f:
                 zipf_dict = json.load(f)
             self.z = Zipf()
             self.z.load(zipf_dict)
             self.zipf_fig = read_plotly(self.zipf_fig_fid)
-        elif use_cache and exists(self.zipf_fid):
+        elif self.use_cache and exists(self.zipf_fid):
            # TODO: Read zipf data so that the vocab is there.
            with open(self.zipf_fid, "r") as f:
                zipf_dict = json.load(f)
@@ -643,17 +661,16 @@ class nPMIStatisticsCacheClass:
         self.available_terms = self.dstats.available_terms
         logs.info(self.available_terms)
 
-    def load_or_prepare_npmi_terms(self
+    def load_or_prepare_npmi_terms(self):
         """
         Figures out what identity terms the user can select, based on whether
         they occur more than self.min_vocab_count times
-        :param use_cache:
         :return: Identity terms occurring at least self.min_vocab_count times.
         """
         # TODO: Add the user's ability to select subgroups.
         # TODO: Make min_vocab_count here value selectable by the user.
         if (
-            use_cache
+            self.use_cache
             and exists(self.npmi_terms_fid)
             and json.load(open(self.npmi_terms_fid))["available terms"] != []
         ):
@@ -676,7 +693,7 @@ class nPMIStatisticsCacheClass:
         self.available_terms = available_terms
         return available_terms
 
-    def load_or_prepare_joint_npmi(self, subgroup_pair
+    def load_or_prepare_joint_npmi(self, subgroup_pair):
         """
         Run on-the fly, while the app is already open,
         as it depends on the subgroup terms that the user chooses
@@ -695,7 +712,7 @@ class nPMIStatisticsCacheClass:
         subgroup_files = define_subgroup_files(subgroup_pair, self.pmi_cache_path)
         # Defines the filenames for the cache files from the selected subgroups.
         # Get as much precomputed data as we can.
-        if use_cache and exists(joint_npmi_fid):
+        if self.use_cache and exists(joint_npmi_fid):
             # When everything is already computed for the selected subgroups.
             logs.info("Loading cached joint npmi")
             joint_npmi_df = self.load_joint_npmi_df(joint_npmi_fid)
@@ -850,8 +867,8 @@ class nPMIStatisticsCacheClass:
         csv_df.columns = [calc_str]
         return csv_df
 
-    def get_available_terms(self
+    def get_available_terms(self):
-        return self.load_or_prepare_npmi_terms(
+        return self.load_or_prepare_npmi_terms()
 
 def dummy(doc):
     return doc
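The duplicate counting that prepare_text_duplicates factors out is easy to sanity-check in isolation. Here is a self-contained rerun of the same pivot_table idiom, with plain strings standing in for the OUR_TEXT_FIELD and CNT constants. Because duplicated() marks every occurrence after the first, the counts come out as occurrences minus one, which is what the "adding 1" comment in streamlit_utils.py is compensating for:

import pandas as pd

# Toy stand-in for self.tokenized_df; "text" plays OUR_TEXT_FIELD.
tokenized_df = pd.DataFrame({"text": ["a", "b", "a", "c", "a", "b"]})

# duplicated() marks every occurrence after the first, so dup_df holds
# only the repeated rows: two extra "a"s and one extra "b".
dup_df = tokenized_df[tokenized_df.duplicated(["text"])]

# aggfunc="size" counts those repeated rows per text value.
dup_counts_df = pd.DataFrame(
    dup_df.pivot_table(columns=["text"], aggfunc="size").sort_values(
        ascending=False
    ),
    columns=["count"],
)
dup_counts_df["text"] = dup_counts_df.index.copy()
print(dup_counts_df)  # "a" -> 2, "b" -> 1; "c" has no duplicates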
data_measurements/streamlit_utils.py CHANGED
@@ -136,12 +136,12 @@ def expander_general_stats(dstats, column_id):
 
 
 ### Show the label distribution from the datasets
-def expander_label_distribution(
+def expander_label_distribution(fig_labels, column_id):
     with st.expander(f"Label Distribution{column_id}", expanded=False):
         st.caption(
             "Use this widget to see how balanced the labels in your dataset are."
         )
-        if
+        if fig_labels is not None:
             st.plotly_chart(fig_labels, use_container_width=True)
         else:
             st.markdown("No labels were found in the dataset")
@@ -285,7 +285,7 @@ def expander_text_duplicates(dstats, column_id):
             "### Here is the list of all the duplicated items and their counts in your dataset:"
         )
         # Eh...adding 1 because otherwise it looks too weird for duplicate counts when the value is just 1.
-        if
+        if dstats.dup_counts_df is None:
            st.write("There are no duplicates in this dataset! 🥳")
         else:
             gb = GridOptionsBuilder.from_dataframe(dstats.dup_counts_df)
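Both widget functions now take only the state they need and guard against it being absent. A minimal sketch of that convention, using a hypothetical widget rather than one from the repo:

import streamlit as st

# Hypothetical widget following the same None-guard convention as
# expander_label_distribution and expander_text_duplicates.
def expander_example_fig(fig, column_id):
    with st.expander(f"Example{column_id}", expanded=False):
        if fig is not None:
            st.plotly_chart(fig, use_container_width=True)
        else:
            st.markdown("Nothing to show yet")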