Spaces:
Build error
Build error
meg-huggingface
commited on
Commit
•
a2ae370
1
Parent(s):
335424f
More modularizing; npmi and labels
Browse files- app.py +5 -12
- data_measurements/dataset_statistics.py +20 -20
- data_measurements/streamlit_utils.py +4 -5
app.py
CHANGED
@@ -118,9 +118,8 @@ def load_or_prepare(ds_args, show_embeddings, use_cache=False):
|
|
118 |
if show_embeddings:
|
119 |
logs.warning("Loading Embeddings")
|
120 |
dstats.load_or_prepare_embeddings()
|
121 |
-
|
122 |
-
|
123 |
-
dstats.load_or_prepare_npmi_terms()
|
124 |
logs.warning("Loading Zipf")
|
125 |
dstats.load_or_prepare_zipf()
|
126 |
return dstats
|
@@ -156,6 +155,8 @@ def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
|
|
156 |
# Embeddings widget
|
157 |
dstats.load_or_prepare_embeddings()
|
158 |
dstats.load_or_prepare_text_duplicates()
|
|
|
|
|
159 |
|
160 |
def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=True):
|
161 |
"""
|
@@ -179,17 +180,9 @@ def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=T
|
|
179 |
st_utils.expander_label_distribution(dstats.fig_labels, column_id)
|
180 |
st_utils.expander_text_lengths(dstats, column_id)
|
181 |
st_utils.expander_text_duplicates(dstats, column_id)
|
182 |
-
|
183 |
-
# We do the loading of these after the others in order to have some time
|
184 |
-
# to compute while the user works with the details above.
|
185 |
# Uses an interaction; handled a bit differently than other widgets.
|
186 |
logs.info("showing npmi widget")
|
187 |
-
npmi_stats
|
188 |
-
dstats, use_cache=use_cache
|
189 |
-
)
|
190 |
-
available_terms = npmi_stats.get_available_terms()
|
191 |
-
st_utils.npmi_widget(
|
192 |
-
column_id, available_terms, npmi_stats, _MIN_VOCAB_COUNT)
|
193 |
logs.info("showing zipf")
|
194 |
st_utils.expander_zipf(dstats.z, dstats.zipf_fig, column_id)
|
195 |
if show_embeddings:
|
|
|
118 |
if show_embeddings:
|
119 |
logs.warning("Loading Embeddings")
|
120 |
dstats.load_or_prepare_embeddings()
|
121 |
+
logs.warning("Loading nPMI")
|
122 |
+
dstats.load_or_prepare_npmi()
|
|
|
123 |
logs.warning("Loading Zipf")
|
124 |
dstats.load_or_prepare_zipf()
|
125 |
return dstats
|
|
|
155 |
# Embeddings widget
|
156 |
dstats.load_or_prepare_embeddings()
|
157 |
dstats.load_or_prepare_text_duplicates()
|
158 |
+
dstats.load_or_prepare_npmi()
|
159 |
+
dstats.load_or_prepare_zipf()
|
160 |
|
161 |
def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=True):
|
162 |
"""
|
|
|
180 |
st_utils.expander_label_distribution(dstats.fig_labels, column_id)
|
181 |
st_utils.expander_text_lengths(dstats, column_id)
|
182 |
st_utils.expander_text_duplicates(dstats, column_id)
|
|
|
|
|
|
|
183 |
# Uses an interaction; handled a bit differently than other widgets.
|
184 |
logs.info("showing npmi widget")
|
185 |
+
st_utils.npmi_widget(dstats.npmi_stats, _MIN_VOCAB_COUNT, column_id)
|
|
|
|
|
|
|
|
|
|
|
186 |
logs.info("showing zipf")
|
187 |
st_utils.expander_zipf(dstats.z, dstats.zipf_fig, column_id)
|
188 |
if show_embeddings:
|
data_measurements/dataset_statistics.py
CHANGED
@@ -231,10 +231,6 @@ class DatasetStatisticsCacheClass:
|
|
231 |
# nPMI
|
232 |
# Holds a nPMIStatisticsCacheClass object
|
233 |
self.npmi_stats = None
|
234 |
-
# TODO: Users ideally can type in whatever words they want.
|
235 |
-
self.termlist = _IDENTITY_TERMS
|
236 |
-
# termlist terms that are available more than _MIN_VOCAB_COUNT times
|
237 |
-
self.available_terms = _IDENTITY_TERMS
|
238 |
# TODO: Have lowercase be an option for a user to set.
|
239 |
self.to_lowercase = True
|
240 |
# The minimum amount of times a word should occur to be included in
|
@@ -627,24 +623,27 @@ class DatasetStatisticsCacheClass:
|
|
627 |
if save:
|
628 |
write_plotly(self.fig_labels, self.fig_labels_fid)
|
629 |
else:
|
630 |
-
self.
|
631 |
-
self.label_dset = self.dset.map(
|
632 |
-
lambda examples: extract_field(
|
633 |
-
examples, self.label_field, OUR_LABEL_FIELD
|
634 |
-
),
|
635 |
-
batched=True,
|
636 |
-
remove_columns=list(self.dset.features),
|
637 |
-
)
|
638 |
-
self.label_df = self.label_dset.to_pandas()
|
639 |
-
self.fig_labels = make_fig_labels(
|
640 |
-
self.label_df, self.label_names, OUR_LABEL_FIELD
|
641 |
-
)
|
642 |
if save:
|
643 |
# save extracted label instances
|
644 |
self.label_dset.save_to_disk(self.label_dset_fid)
|
645 |
write_plotly(self.fig_labels, self.fig_labels_fid)
|
646 |
|
647 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
648 |
self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=self.use_cache)
|
649 |
self.npmi_stats.load_or_prepare_npmi_terms()
|
650 |
|
@@ -693,7 +692,10 @@ class nPMIStatisticsCacheClass:
|
|
693 |
# We need to preprocess everything.
|
694 |
mkdir(self.pmi_cache_path)
|
695 |
self.joint_npmi_df_dict = {}
|
696 |
-
|
|
|
|
|
|
|
697 |
logs.info(self.termlist)
|
698 |
self.use_cache = use_cache
|
699 |
# TODO: Let users specify
|
@@ -701,8 +703,6 @@ class nPMIStatisticsCacheClass:
|
|
701 |
self.min_vocab_count = self.dstats.min_vocab_count
|
702 |
self.subgroup_files = {}
|
703 |
self.npmi_terms_fid = pjoin(self.dstats.cache_path, "npmi_terms.json")
|
704 |
-
self.available_terms = self.dstats.available_terms
|
705 |
-
logs.info(self.available_terms)
|
706 |
|
707 |
def load_or_prepare_npmi_terms(self):
|
708 |
"""
|
|
|
231 |
# nPMI
|
232 |
# Holds a nPMIStatisticsCacheClass object
|
233 |
self.npmi_stats = None
|
|
|
|
|
|
|
|
|
234 |
# TODO: Have lowercase be an option for a user to set.
|
235 |
self.to_lowercase = True
|
236 |
# The minimum amount of times a word should occur to be included in
|
|
|
623 |
if save:
|
624 |
write_plotly(self.fig_labels, self.fig_labels_fid)
|
625 |
else:
|
626 |
+
self.prepare_labels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
627 |
if save:
|
628 |
# save extracted label instances
|
629 |
self.label_dset.save_to_disk(self.label_dset_fid)
|
630 |
write_plotly(self.fig_labels, self.fig_labels_fid)
|
631 |
|
632 |
+
def prepare_labels(self):
|
633 |
+
self.get_base_dataset()
|
634 |
+
self.label_dset = self.dset.map(
|
635 |
+
lambda examples: extract_field(
|
636 |
+
examples, self.label_field, OUR_LABEL_FIELD
|
637 |
+
),
|
638 |
+
batched=True,
|
639 |
+
remove_columns=list(self.dset.features),
|
640 |
+
)
|
641 |
+
self.label_df = self.label_dset.to_pandas()
|
642 |
+
self.fig_labels = make_fig_labels(
|
643 |
+
self.label_df, self.label_names, OUR_LABEL_FIELD
|
644 |
+
)
|
645 |
+
|
646 |
+
def load_or_prepare_npmi(self):
|
647 |
self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=self.use_cache)
|
648 |
self.npmi_stats.load_or_prepare_npmi_terms()
|
649 |
|
|
|
692 |
# We need to preprocess everything.
|
693 |
mkdir(self.pmi_cache_path)
|
694 |
self.joint_npmi_df_dict = {}
|
695 |
+
# TODO: Users ideally can type in whatever words they want.
|
696 |
+
self.termlist = _IDENTITY_TERMS
|
697 |
+
# termlist terms that are available more than _MIN_VOCAB_COUNT times
|
698 |
+
self.available_terms = _IDENTITY_TERMS
|
699 |
logs.info(self.termlist)
|
700 |
self.use_cache = use_cache
|
701 |
# TODO: Let users specify
|
|
|
703 |
self.min_vocab_count = self.dstats.min_vocab_count
|
704 |
self.subgroup_files = {}
|
705 |
self.npmi_terms_fid = pjoin(self.dstats.cache_path, "npmi_terms.json")
|
|
|
|
|
706 |
|
707 |
def load_or_prepare_npmi_terms(self):
|
708 |
"""
|
data_measurements/streamlit_utils.py
CHANGED
@@ -273,7 +273,6 @@ def expander_text_duplicates(dstats, column_id):
|
|
273 |
st.write(
|
274 |
"### Here is the list of all the duplicated items and their counts in your dataset:"
|
275 |
)
|
276 |
-
# Eh...adding 1 because otherwise it looks too weird for duplicate counts when the value is just 1.
|
277 |
if dstats.dup_counts_df is None:
|
278 |
st.write("There are no duplicates in this dataset! 🥳")
|
279 |
else:
|
@@ -393,7 +392,7 @@ with an ideal α value of 1."""
|
|
393 |
|
394 |
|
395 |
### Finally finally finally, show nPMI stuff.
|
396 |
-
def npmi_widget(
|
397 |
"""
|
398 |
Part of the main app, but uses a user interaction so pulled out as its own f'n.
|
399 |
:param use_cache:
|
@@ -403,16 +402,16 @@ def npmi_widget(column_id, available_terms, npmi_stats, min_vocab):
|
|
403 |
:return:
|
404 |
"""
|
405 |
with st.expander(f"Word Association{column_id}: nPMI", expanded=False):
|
406 |
-
if len(available_terms) > 0:
|
407 |
expander_npmi_description(min_vocab)
|
408 |
st.markdown("-----")
|
409 |
term1 = st.selectbox(
|
410 |
f"What is the first term you want to select?{column_id}",
|
411 |
-
available_terms,
|
412 |
)
|
413 |
term2 = st.selectbox(
|
414 |
f"What is the second term you want to select?{column_id}",
|
415 |
-
reversed(available_terms),
|
416 |
)
|
417 |
# We calculate/grab nPMI data based on a canonical (alphabetic)
|
418 |
# subgroup ordering.
|
|
|
273 |
st.write(
|
274 |
"### Here is the list of all the duplicated items and their counts in your dataset:"
|
275 |
)
|
|
|
276 |
if dstats.dup_counts_df is None:
|
277 |
st.write("There are no duplicates in this dataset! 🥳")
|
278 |
else:
|
|
|
392 |
|
393 |
|
394 |
### Finally finally finally, show nPMI stuff.
|
395 |
+
def npmi_widget(npmi_stats, min_vocab, column_id):
|
396 |
"""
|
397 |
Part of the main app, but uses a user interaction so pulled out as its own f'n.
|
398 |
:param use_cache:
|
|
|
402 |
:return:
|
403 |
"""
|
404 |
with st.expander(f"Word Association{column_id}: nPMI", expanded=False):
|
405 |
+
if len(npmi_stats.available_terms) > 0:
|
406 |
expander_npmi_description(min_vocab)
|
407 |
st.markdown("-----")
|
408 |
term1 = st.selectbox(
|
409 |
f"What is the first term you want to select?{column_id}",
|
410 |
+
npmi_stats.available_terms,
|
411 |
)
|
412 |
term2 = st.selectbox(
|
413 |
f"What is the second term you want to select?{column_id}",
|
414 |
+
reversed(npmi_stats.available_terms),
|
415 |
)
|
416 |
# We calculate/grab nPMI data based on a canonical (alphabetic)
|
417 |
# subgroup ordering.
|