Spaces:
Build error
Build error
meg-huggingface
committed on
Commit
•
0b7eeeb
1
Parent(s):
f9936fb
Updating from rollback
Browse files- data_measurements/embeddings.py +322 -219
- data_measurements/streamlit_utils.py +8 -23
data_measurements/embeddings.py
CHANGED
@@ -20,12 +20,14 @@ import plotly.graph_objects as go
|
|
20 |
import torch
|
21 |
import transformers
|
22 |
from datasets import load_from_disk
|
|
|
23 |
from tqdm import tqdm
|
24 |
|
25 |
-
from .dataset_utils import EMBEDDING_FIELD
|
26 |
|
27 |
|
28 |
def sentence_mean_pooling(model_output, attention_mask):
|
|
|
29 |
token_embeddings = model_output[
|
30 |
0
|
31 |
] # First element of model_output contains all token embeddings
|
@@ -38,46 +40,46 @@ def sentence_mean_pooling(model_output, attention_mask):
|
|
38 |
|
39 |
|
40 |
class Embeddings:
|
41 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
"""Item embeddings and clustering"""
|
43 |
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
self.node_list = None
|
45 |
self.nid_map = None
|
46 |
-
self.
|
47 |
self.fig_tree = None
|
48 |
self.cached_clusters = {}
|
49 |
-
self.dstats = dstats
|
50 |
-
self.cache_path = dstats.cache_path
|
51 |
-
self.node_list_fid = pjoin(self.cache_path, "node_list.th")
|
52 |
self.use_cache = use_cache
|
53 |
-
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
|
54 |
-
"sentence-transformers/all-mpnet-base-v2"
|
55 |
-
)
|
56 |
-
self.model = transformers.AutoModel.from_pretrained(
|
57 |
-
"sentence-transformers/all-mpnet-base-v2"
|
58 |
-
).to(self.device)
|
59 |
-
|
60 |
-
def make_text_embeddings(self):
|
61 |
-
embeddings_dset_fid = pjoin(self.cache_path, "embeddings_dset")
|
62 |
-
if self.use_cache and exists(embeddings_dset_fid):
|
63 |
-
self.embeddings_dset = load_from_disk(embeddings_dset_fid)
|
64 |
-
else:
|
65 |
-
self.embeddings_dset = self.make_embeddings()
|
66 |
-
self.embeddings_dset.save_to_disk(embeddings_dset_fid)
|
67 |
-
|
68 |
-
def make_hierarchical_clustering(self):
|
69 |
-
if self.use_cache and exists(self.node_list_fid):
|
70 |
-
self.node_list = torch.load(self.node_list_fid)
|
71 |
-
else:
|
72 |
-
self.make_text_embeddings()
|
73 |
-
self.node_list = self.fast_cluster(self.embeddings_dset, EMBEDDING_FIELD)
|
74 |
-
torch.save(self.node_list, self.node_list_fid)
|
75 |
-
self.nid_map = dict(
|
76 |
-
[(node["nid"], nid) for nid, node in enumerate(self.node_list)]
|
77 |
-
)
|
78 |
-
self.fig_tree = make_tree_plot(self.node_list, self.dstats.text_dset)
|
79 |
|
80 |
def compute_sentence_embeddings(self, sentences):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
batch = self.tokenizer(
|
82 |
sentences, padding=True, truncation=True, return_tensors="pt"
|
83 |
)
|
@@ -91,212 +93,70 @@ class Embeddings:
|
|
91 |
return sentence_embeds
|
92 |
|
93 |
def make_embeddings(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
def batch_embed_sentences(sentences):
|
95 |
return {
|
96 |
EMBEDDING_FIELD: [
|
97 |
embed.tolist()
|
98 |
for embed in self.compute_sentence_embeddings(
|
99 |
-
sentences[
|
100 |
)
|
101 |
]
|
102 |
}
|
103 |
|
104 |
-
|
105 |
batch_embed_sentences,
|
106 |
batched=True,
|
107 |
batch_size=32,
|
108 |
-
remove_columns=[self.
|
109 |
-
)
|
110 |
-
|
111 |
-
return text_dset_embeds
|
112 |
-
|
113 |
-
@staticmethod
|
114 |
-
def prepare_merges(embeddings, batch_size, low_thres=0.5):
|
115 |
-
top_idx_pre = torch.cat(
|
116 |
-
[torch.LongTensor(range(embeddings.shape[0]))[:, None]] * batch_size, dim=1
|
117 |
-
)
|
118 |
-
top_val_all = torch.Tensor(0, batch_size)
|
119 |
-
top_idx_all = torch.LongTensor(0, batch_size)
|
120 |
-
n_batches = math.ceil(len(embeddings) / batch_size)
|
121 |
-
for b in tqdm(range(n_batches)):
|
122 |
-
cos_scores = torch.mm(
|
123 |
-
embeddings[b * batch_size : (b + 1) * batch_size], embeddings.t()
|
124 |
-
)
|
125 |
-
for i in range(cos_scores.shape[0]):
|
126 |
-
cos_scores[i, (b * batch_size) + i :] = -1
|
127 |
-
top_val_large, top_idx_large = cos_scores.topk(
|
128 |
-
k=batch_size, dim=-1, largest=True
|
129 |
-
)
|
130 |
-
top_val_all = torch.cat([top_val_all, top_val_large], dim=0)
|
131 |
-
top_idx_all = torch.cat([top_idx_all, top_idx_large], dim=0)
|
132 |
-
|
133 |
-
all_merges = torch.cat(
|
134 |
-
[
|
135 |
-
top_idx_pre[top_val_all > low_thres][:, None],
|
136 |
-
top_idx_all[top_val_all > low_thres][:, None],
|
137 |
-
],
|
138 |
-
dim=1,
|
139 |
)
|
140 |
-
all_merge_scores = top_val_all[top_val_all > low_thres]
|
141 |
-
return (all_merges, all_merge_scores)
|
142 |
|
143 |
-
|
144 |
-
def merge_nodes(nodes, current_thres, previous_thres, all_merges, all_merge_scores):
|
145 |
-
merge_ids = (all_merge_scores <= previous_thres) * (
|
146 |
-
all_merge_scores > current_thres
|
147 |
-
)
|
148 |
-
merges = all_merges[merge_ids]
|
149 |
-
for a, b in merges.tolist():
|
150 |
-
node_a = nodes[a]
|
151 |
-
while node_a["parent_id"] != -1:
|
152 |
-
node_a = nodes[node_a["parent_id"]]
|
153 |
-
node_b = nodes[b]
|
154 |
-
while node_b["parent_id"] != -1:
|
155 |
-
node_b = nodes[node_b["parent_id"]]
|
156 |
-
if node_a["nid"] == node_b["nid"]:
|
157 |
-
continue
|
158 |
-
else:
|
159 |
-
# merge if threshold allows
|
160 |
-
if (node_a["depth"] + node_b["depth"]) > 0 and min(
|
161 |
-
node_a["merge_threshold"], node_b["merge_threshold"]
|
162 |
-
) == current_thres:
|
163 |
-
merge_to = None
|
164 |
-
merge_from = None
|
165 |
-
if node_a["nid"] < node_b["nid"]:
|
166 |
-
merge_from = node_a
|
167 |
-
merge_to = node_b
|
168 |
-
if node_a["nid"] > node_b["nid"]:
|
169 |
-
merge_from = node_b
|
170 |
-
merge_to = node_a
|
171 |
-
merge_to["depth"] = max(merge_to["depth"], merge_from["depth"])
|
172 |
-
merge_to["weight"] += merge_from["weight"]
|
173 |
-
merge_to["children_ids"] += (
|
174 |
-
merge_from["children_ids"]
|
175 |
-
if merge_from["depth"] > 0
|
176 |
-
else [merge_from["nid"]]
|
177 |
-
)
|
178 |
-
for cid in merge_from["children_ids"]:
|
179 |
-
nodes[cid]["parent_id"] = merge_to["nid"]
|
180 |
-
merge_from["parent_id"] = merge_to["nid"]
|
181 |
-
# else new node
|
182 |
-
else:
|
183 |
-
new_nid = len(nodes)
|
184 |
-
new_node = {
|
185 |
-
"nid": new_nid,
|
186 |
-
"parent_id": -1,
|
187 |
-
"depth": max(node_a["depth"], node_b["depth"]) + 1,
|
188 |
-
"weight": node_a["weight"] + node_b["weight"],
|
189 |
-
"children": [],
|
190 |
-
"children_ids": [node_a["nid"], node_b["nid"]],
|
191 |
-
"example_ids": [],
|
192 |
-
"merge_threshold": current_thres,
|
193 |
-
}
|
194 |
-
node_a["parent_id"] = new_nid
|
195 |
-
node_b["parent_id"] = new_nid
|
196 |
-
nodes += [new_node]
|
197 |
-
return nodes
|
198 |
|
199 |
-
def
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
reverse=True,
|
207 |
-
)
|
208 |
-
if node["depth"] > 0:
|
209 |
-
node["example_ids"] = [
|
210 |
-
eid for child in node["children"] for eid in child["example_ids"]
|
211 |
-
]
|
212 |
-
node["children"] = [
|
213 |
-
child for child in node["children"] if child["weight"] >= min_cluster_size
|
214 |
-
]
|
215 |
-
assert node["weight"] == len(node["example_ids"]), print(node)
|
216 |
-
return node
|
217 |
|
218 |
-
def
|
219 |
self,
|
220 |
-
text_dset_embeds,
|
221 |
-
embedding_field,
|
222 |
batch_size=1000,
|
|
|
223 |
min_cluster_size=10,
|
224 |
-
low_thres=0.5,
|
225 |
):
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
{
|
234 |
-
"nid": nid,
|
235 |
-
"parent_id": -1,
|
236 |
-
"depth": 0,
|
237 |
-
"weight": 1,
|
238 |
-
"children": [],
|
239 |
-
"children_ids": [],
|
240 |
-
"example_ids": [nid],
|
241 |
-
"merge_threshold": 1.0,
|
242 |
-
}
|
243 |
-
for nid in range(embeddings.shape[0])
|
244 |
-
]
|
245 |
-
# one level per threshold range
|
246 |
-
for i in range(10):
|
247 |
-
p_thres = 1 - i * 0.05
|
248 |
-
c_thres = 0.95 - i * 0.05
|
249 |
-
nodes = self.merge_nodes(
|
250 |
-
nodes, c_thres, p_thres, all_merges, all_merge_scores
|
251 |
)
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
"depth": max([node["depth"] for node in root_children]) + 1,
|
262 |
-
"weight": sum([node["weight"] for node in root_children]),
|
263 |
-
"children": [],
|
264 |
-
"children_ids": [node["nid"] for node in root_children],
|
265 |
-
"example_ids": [],
|
266 |
-
"merge_threshold": -1.0,
|
267 |
-
}
|
268 |
-
nodes += [root]
|
269 |
-
for node in root_children:
|
270 |
-
node["parent_id"] = root["nid"]
|
271 |
-
# finalize tree
|
272 |
-
tree = self.finalize_node(root, nodes, min_cluster_size)
|
273 |
-
node_list = []
|
274 |
-
|
275 |
-
def rec_map_nodes(node, node_list):
|
276 |
-
node_list += [node]
|
277 |
-
for child in node["children"]:
|
278 |
-
rec_map_nodes(child, node_list)
|
279 |
-
|
280 |
-
rec_map_nodes(tree, node_list)
|
281 |
-
# get centroids and distances
|
282 |
-
for node in node_list:
|
283 |
-
node_embeds = embeddings[node["example_ids"]]
|
284 |
-
node["centroid"] = node_embeds.sum(dim=0)
|
285 |
-
node["centroid"] /= node["centroid"].norm()
|
286 |
-
node["centroid_dot_prods"] = torch.mv(node_embeds, node["centroid"])
|
287 |
-
node["sorted_examples_centroid"] = sorted(
|
288 |
-
[
|
289 |
-
(eid, edp.item())
|
290 |
-
for eid, edp in zip(node["example_ids"], node["centroid_dot_prods"])
|
291 |
-
],
|
292 |
-
key=lambda x: x[1],
|
293 |
-
reverse=True,
|
294 |
)
|
295 |
-
|
296 |
|
297 |
def find_cluster_beam(self, sentence, beam_size=20):
|
298 |
"""
|
299 |
-
This function finds the `beam_size`
|
300 |
proposed sentence and returns the full path from the root to the cluster
|
301 |
along with the dot product between the sentence embedding and the
|
302 |
cluster centroid
|
@@ -365,25 +225,268 @@ class Embeddings:
|
|
365 |
)[:beam_size]
|
366 |
|
367 |
|
368 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
369 |
nid_map = dict([(node["nid"], nid) for nid, node in enumerate(node_list)])
|
370 |
|
371 |
for nid, node in enumerate(node_list):
|
|
|
|
|
|
|
|
|
|
|
|
|
372 |
node["label"] = node.get(
|
373 |
"label",
|
374 |
f"{nid:2d} - {node['weight']:5d} items <br>"
|
375 |
+ "<br>".join(
|
376 |
[
|
377 |
-
"
|
378 |
-
for txt in
|
379 |
-
set(text_dset.select(node["example_ids"])[OUR_TEXT_FIELD])
|
380 |
-
)[:5]
|
381 |
]
|
382 |
),
|
383 |
)
|
384 |
|
385 |
# make plot nodes
|
386 |
-
# TODO: something more efficient than set to remove duplicates
|
387 |
labels = [node["label"] for node in node_list]
|
388 |
|
389 |
root = node_list[0]
|
|
|
20 |
import torch
|
21 |
import transformers
|
22 |
from datasets import load_from_disk
|
23 |
+
from plotly.io import read_json
|
24 |
from tqdm import tqdm
|
25 |
|
26 |
+
from .dataset_utils import EMBEDDING_FIELD
|
27 |
|
28 |
|
29 |
def sentence_mean_pooling(model_output, attention_mask):
|
30 |
+
"""Mean pooling of token embeddings for a sentence."""
|
31 |
token_embeddings = model_output[
|
32 |
0
|
33 |
] # First element of model_output contains all token embeddings
|
|
|
40 |
|
41 |
|
42 |
class Embeddings:
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
dstats=None,
|
46 |
+
text_dset=None,
|
47 |
+
text_field_name="text",
|
48 |
+
cache_path="",
|
49 |
+
use_cache=False,
|
50 |
+
):
|
51 |
"""Item embeddings and clustering"""
|
52 |
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
53 |
+
self.model_name = "sentence-transformers/all-mpnet-base-v2"
|
54 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
|
55 |
+
self.model = transformers.AutoModel.from_pretrained(self.model_name).to(
|
56 |
+
self.device
|
57 |
+
)
|
58 |
+
self.text_dset = text_dset if dstats is None else dstats.text_dset
|
59 |
+
self.text_field_name = (
|
60 |
+
text_field_name if dstats is None else dstats.our_text_field
|
61 |
+
)
|
62 |
+
self.cache_path = cache_path if dstats is None else dstats.cache_path
|
63 |
+
self.embeddings_dset_fid = pjoin(self.cache_path, "embeddings_dset")
|
64 |
+
self.embeddings_dset = None
|
65 |
+
self.node_list_fid = pjoin(self.cache_path, "node_list.th")
|
66 |
self.node_list = None
|
67 |
self.nid_map = None
|
68 |
+
self.fig_tree_fid = pjoin(self.cache_path, "node_figure.json")
|
69 |
self.fig_tree = None
|
70 |
self.cached_clusters = {}
|
|
|
|
|
|
|
71 |
self.use_cache = use_cache
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
def compute_sentence_embeddings(self, sentences):
|
74 |
+
"""
|
75 |
+
Takes a list of sentences and computes their embeddings
|
76 |
+
using self.tokenizer and self.model (with output dimension D)
|
77 |
+
followed by mean pooling of the token representations and normalization
|
78 |
+
Args:
|
79 |
+
sentences ([string]): list of N input sentences
|
80 |
+
Returns:
|
81 |
+
torch.Tensor: sentence embeddings, dimension NxD
|
82 |
+
"""
|
83 |
batch = self.tokenizer(
|
84 |
sentences, padding=True, truncation=True, return_tensors="pt"
|
85 |
)
|
|
|
93 |
return sentence_embeds
|
94 |
|
95 |
def make_embeddings(self):
|
96 |
+
"""
|
97 |
+
Batch computes the embeddings of the Dataset self.text_dset,
|
98 |
+
using the field self.text_field_name as input.
|
99 |
+
Returns:
|
100 |
+
Dataset: HF dataset object with a single EMBEDDING_FIELD field
|
101 |
+
corresponding to the embeddings (list of floats)
|
102 |
+
"""
|
103 |
+
|
104 |
def batch_embed_sentences(sentences):
|
105 |
return {
|
106 |
EMBEDDING_FIELD: [
|
107 |
embed.tolist()
|
108 |
for embed in self.compute_sentence_embeddings(
|
109 |
+
sentences[self.text_field_name]
|
110 |
)
|
111 |
]
|
112 |
}
|
113 |
|
114 |
+
self.embeddings_dset = self.text_dset.map(
|
115 |
batch_embed_sentences,
|
116 |
batched=True,
|
117 |
batch_size=32,
|
118 |
+
remove_columns=[self.text_field_name],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
)
|
|
|
|
|
120 |
|
121 |
+
return self.embeddings_dset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
+
def make_text_embeddings(self):
|
124 |
+
"""Load embeddings dataset from cache or compute it."""
|
125 |
+
if self.use_cache and exists(self.embeddings_dset_fid):
|
126 |
+
self.embeddings_dset = load_from_disk(self.embeddings_dset_fid)
|
127 |
+
else:
|
128 |
+
self.embeddings_dset = self.make_embeddings()
|
129 |
+
self.embeddings_dset.save_to_disk(self.embeddings_dset_fid)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
+
def make_hierarchical_clustering(
|
132 |
self,
|
|
|
|
|
133 |
batch_size=1000,
|
134 |
+
approx_neighbors=1000,
|
135 |
min_cluster_size=10,
|
|
|
136 |
):
|
137 |
+
if self.use_cache and exists(self.node_list_fid):
|
138 |
+
self.node_list, self.nid_map = torch.load(self.node_list_fid)
|
139 |
+
else:
|
140 |
+
self.make_text_embeddings()
|
141 |
+
embeddings = torch.Tensor(self.embeddings_dset[EMBEDDING_FIELD])
|
142 |
+
self.node_list = fast_cluster(
|
143 |
+
embeddings, batch_size, approx_neighbors, min_cluster_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
)
|
145 |
+
self.nid_map = dict(
|
146 |
+
[(node["nid"], nid) for nid, node in enumerate(self.node_list)]
|
147 |
+
)
|
148 |
+
torch.save((self.node_list, self.nid_map), self.node_list_fid)
|
149 |
+
if self.use_cache and exists(self.fig_tree_fid):
|
150 |
+
self.fig_tree = read_json(self.fig_tree_fid)
|
151 |
+
else:
|
152 |
+
self.fig_tree = make_tree_plot(
|
153 |
+
self.node_list, self.text_dset, self.text_field_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
)
|
155 |
+
self.fig_tree.write_json(self.fig_tree_fid)
|
156 |
|
157 |
def find_cluster_beam(self, sentence, beam_size=20):
|
158 |
"""
|
159 |
+
This function finds the `beam_size` leaf clusters that are closest to the
|
160 |
proposed sentence and returns the full path from the root to the cluster
|
161 |
along with the dot product between the sentence embedding and the
|
162 |
cluster centroid
|
|
|
225 |
)[:beam_size]
|
226 |
|
227 |
|
228 |
+
def prepare_merges(embeddings, batch_size=1000, approx_neighbors=1000, low_thres=0.5):
|
229 |
+
"""
|
230 |
+
Prepares an initial list of merges for hierarchical
|
231 |
+
clustering. First compute the `approx_neighbors` nearest neighbors,
|
232 |
+
then propose a merge for any two points that are closer than `low_thres`
|
233 |
+
|
234 |
+
Note that if a point has more than `approx_neighbors` neighbors
|
235 |
+
closer than `low_thres`, this approach will miss some of those merges
|
236 |
+
|
237 |
+
Args:
|
238 |
+
embeddings (torch.Tensor): Tensor of sentence embeddings - dimension NxD
|
239 |
+
batch_size (int): compute nearest neighbors of `batch_size` points at a time
|
240 |
+
approx_neighbors (int): only keep `approx_neighbors` nearest neighbors of a point
|
241 |
+
low_thres (float): only return merges where the dot product is greater than `low_thres`
|
242 |
+
Returns:
|
243 |
+
torch.LongTensor: proposed merges ([i, j] with i>j) - dimension: Mx2
|
244 |
+
torch.Tensor: merge scores - dimension M
|
245 |
+
"""
|
246 |
+
top_idx_pre = torch.cat(
|
247 |
+
[torch.LongTensor(range(embeddings.shape[0]))[:, None]] * batch_size, dim=1
|
248 |
+
)
|
249 |
+
top_val_all = torch.Tensor(0, approx_neighbors)
|
250 |
+
top_idx_all = torch.LongTensor(0, approx_neighbors)
|
251 |
+
n_batches = math.ceil(len(embeddings) / batch_size)
|
252 |
+
for b in tqdm(range(n_batches)):
|
253 |
+
# TODO: batch across second dimension
|
254 |
+
cos_scores = torch.mm(
|
255 |
+
embeddings[b * batch_size : (b + 1) * batch_size], embeddings.t()
|
256 |
+
)
|
257 |
+
for i in range(cos_scores.shape[0]):
|
258 |
+
cos_scores[i, (b * batch_size) + i :] = -1
|
259 |
+
top_val_large, top_idx_large = cos_scores.topk(
|
260 |
+
k=approx_neighbors, dim=-1, largest=True
|
261 |
+
)
|
262 |
+
top_val_all = torch.cat([top_val_all, top_val_large], dim=0)
|
263 |
+
top_idx_all = torch.cat([top_idx_all, top_idx_large], dim=0)
|
264 |
+
max_neighbor_dist = top_val_large[:, -1].max().item()
|
265 |
+
if max_neighbor_dist > low_thres:
|
266 |
+
print(
|
267 |
+
f"WARNING: with the current set of neireast neighbor, the farthest is {max_neighbor_dist}"
|
268 |
+
)
|
269 |
+
|
270 |
+
all_merges = torch.cat(
|
271 |
+
[
|
272 |
+
top_idx_pre[top_val_all > low_thres][:, None],
|
273 |
+
top_idx_all[top_val_all > low_thres][:, None],
|
274 |
+
],
|
275 |
+
dim=1,
|
276 |
+
)
|
277 |
+
all_merge_scores = top_val_all[top_val_all > low_thres]
|
278 |
+
|
279 |
+
return (all_merges, all_merge_scores)
|
280 |
+
|
281 |
+
|
282 |
+
def merge_nodes(nodes, current_thres, previous_thres, all_merges, all_merge_scores):
|
283 |
+
"""
|
284 |
+
Merge all nodes if the max dot product between any of their descendants
|
285 |
+
is greater than current_thres.
|
286 |
+
|
287 |
+
Args:
|
288 |
+
nodes ([dict]): list of dicts representing the current set of nodes
|
289 |
+
current_thres (float): merge all nodes closer than current_thres
|
290 |
+
previous_thres (float): nodes closer than previous_thres are already merged
|
291 |
+
all_merges (torch.LongTensor): proposed merges ([i, j] with i>j) - dimension: Mx2
|
292 |
+
all_merge_scores (torch.Tensor): merge scores - dimension M
|
293 |
+
Returns:
|
294 |
+
[dict]: extended list with the newly created internal nodes
|
295 |
+
"""
|
296 |
+
merge_ids = (all_merge_scores <= previous_thres) * (
|
297 |
+
all_merge_scores > current_thres
|
298 |
+
)
|
299 |
+
if merge_ids.sum().item() > 0:
|
300 |
+
merges = all_merges[merge_ids]
|
301 |
+
for a, b in merges.tolist():
|
302 |
+
node_a = nodes[a]
|
303 |
+
while node_a["parent_id"] != -1:
|
304 |
+
node_a = nodes[node_a["parent_id"]]
|
305 |
+
node_b = nodes[b]
|
306 |
+
while node_b["parent_id"] != -1:
|
307 |
+
node_b = nodes[node_b["parent_id"]]
|
308 |
+
if node_a["nid"] == node_b["nid"]:
|
309 |
+
continue
|
310 |
+
else:
|
311 |
+
# merge if threshold allows
|
312 |
+
if (node_a["depth"] + node_b["depth"]) > 0 and min(
|
313 |
+
node_a["merge_threshold"], node_b["merge_threshold"]
|
314 |
+
) == current_thres:
|
315 |
+
merge_to = None
|
316 |
+
merge_from = None
|
317 |
+
if node_a["nid"] < node_b["nid"]:
|
318 |
+
merge_from = node_a
|
319 |
+
merge_to = node_b
|
320 |
+
if node_a["nid"] > node_b["nid"]:
|
321 |
+
merge_from = node_b
|
322 |
+
merge_to = node_a
|
323 |
+
merge_to["depth"] = max(merge_to["depth"], merge_from["depth"])
|
324 |
+
merge_to["weight"] += merge_from["weight"]
|
325 |
+
merge_to["children_ids"] += (
|
326 |
+
merge_from["children_ids"]
|
327 |
+
if merge_from["depth"] > 0
|
328 |
+
else [merge_from["nid"]]
|
329 |
+
)
|
330 |
+
for cid in merge_from["children_ids"]:
|
331 |
+
nodes[cid]["parent_id"] = merge_to["nid"]
|
332 |
+
merge_from["parent_id"] = merge_to["nid"]
|
333 |
+
# else new node
|
334 |
+
else:
|
335 |
+
new_nid = len(nodes)
|
336 |
+
new_node = {
|
337 |
+
"nid": new_nid,
|
338 |
+
"parent_id": -1,
|
339 |
+
"depth": max(node_a["depth"], node_b["depth"]) + 1,
|
340 |
+
"weight": node_a["weight"] + node_b["weight"],
|
341 |
+
"children": [],
|
342 |
+
"children_ids": [node_a["nid"], node_b["nid"]],
|
343 |
+
"example_ids": [],
|
344 |
+
"merge_threshold": current_thres,
|
345 |
+
}
|
346 |
+
node_a["parent_id"] = new_nid
|
347 |
+
node_b["parent_id"] = new_nid
|
348 |
+
nodes += [new_node]
|
349 |
+
return nodes
|
350 |
+
|
351 |
+
|
352 |
+
def finalize_node(node, nodes, min_cluster_size):
|
353 |
+
"""Post-process nodes to sort children by descending weight,
|
354 |
+
get full list of leaves in the sub-tree, and direct links
|
355 |
+
to the children nodes, then recurses to all children.
|
356 |
+
|
357 |
+
Nodes with fewer than `min_cluster_size` descendants are collapsed
|
358 |
+
into a single leaf.
|
359 |
+
"""
|
360 |
+
node["children"] = sorted(
|
361 |
+
[
|
362 |
+
finalize_node(nodes[cid], nodes, min_cluster_size)
|
363 |
+
for cid in node["children_ids"]
|
364 |
+
],
|
365 |
+
key=lambda x: x["weight"],
|
366 |
+
reverse=True,
|
367 |
+
)
|
368 |
+
if node["depth"] > 0:
|
369 |
+
node["example_ids"] = [
|
370 |
+
eid for child in node["children"] for eid in child["example_ids"]
|
371 |
+
]
|
372 |
+
node["children"] = [
|
373 |
+
child for child in node["children"] if child["weight"] >= min_cluster_size
|
374 |
+
]
|
375 |
+
assert node["weight"] == len(node["example_ids"]), print(node)
|
376 |
+
return node
|
377 |
+
|
378 |
+
|
379 |
+
def fast_cluster(
|
380 |
+
embeddings,
|
381 |
+
batch_size=1000,
|
382 |
+
approx_neighbors=1000,
|
383 |
+
min_cluster_size=10,
|
384 |
+
low_thres=0.5,
|
385 |
+
):
|
386 |
+
"""
|
387 |
+
Computes an approximate hierarchical clustering based on example
|
388 |
+
embeddings. The join criterion is min clustering, i.e. two clusters
|
389 |
+
are joined if any pair of their descendants are closer than a threshold
|
390 |
+
|
391 |
+
The approximation comes from the fact that only the `approx_neighbors` nearest
|
392 |
+
neighbors of an example are considered for merges
|
393 |
+
"""
|
394 |
+
batch_size = min(embeddings.shape[0], batch_size)
|
395 |
+
all_merges, all_merge_scores = prepare_merges(
|
396 |
+
embeddings, batch_size, approx_neighbors, low_thres
|
397 |
+
)
|
398 |
+
# prepare leaves
|
399 |
+
nodes = [
|
400 |
+
{
|
401 |
+
"nid": nid,
|
402 |
+
"parent_id": -1,
|
403 |
+
"depth": 0,
|
404 |
+
"weight": 1,
|
405 |
+
"children": [],
|
406 |
+
"children_ids": [],
|
407 |
+
"example_ids": [nid],
|
408 |
+
"merge_threshold": 1.0,
|
409 |
+
}
|
410 |
+
for nid in range(embeddings.shape[0])
|
411 |
+
]
|
412 |
+
# one level per threshold range
|
413 |
+
for i in range(10):
|
414 |
+
p_thres = 1 - i * 0.05
|
415 |
+
c_thres = 0.95 - i * 0.05
|
416 |
+
nodes = merge_nodes(nodes, c_thres, p_thres, all_merges, all_merge_scores)
|
417 |
+
# make root
|
418 |
+
root_children = [
|
419 |
+
node
|
420 |
+
for node in nodes
|
421 |
+
if node["parent_id"] == -1 and node["weight"] >= min_cluster_size
|
422 |
+
]
|
423 |
+
root = {
|
424 |
+
"nid": len(nodes),
|
425 |
+
"parent_id": -1,
|
426 |
+
"depth": max([node["depth"] for node in root_children]) + 1,
|
427 |
+
"weight": sum([node["weight"] for node in root_children]),
|
428 |
+
"children": [],
|
429 |
+
"children_ids": [node["nid"] for node in root_children],
|
430 |
+
"example_ids": [],
|
431 |
+
"merge_threshold": -1.0,
|
432 |
+
}
|
433 |
+
nodes += [root]
|
434 |
+
for node in root_children:
|
435 |
+
node["parent_id"] = root["nid"]
|
436 |
+
# finalize tree
|
437 |
+
tree = finalize_node(root, nodes, min_cluster_size)
|
438 |
+
node_list = []
|
439 |
+
|
440 |
+
def rec_map_nodes(node, node_list):
|
441 |
+
node_list += [node]
|
442 |
+
for child in node["children"]:
|
443 |
+
rec_map_nodes(child, node_list)
|
444 |
+
|
445 |
+
rec_map_nodes(tree, node_list)
|
446 |
+
# get centroids and distances
|
447 |
+
for node in node_list:
|
448 |
+
node_embeds = embeddings[node["example_ids"]]
|
449 |
+
node["centroid"] = node_embeds.sum(dim=0)
|
450 |
+
node["centroid"] /= node["centroid"].norm()
|
451 |
+
node["centroid_dot_prods"] = torch.mv(node_embeds, node["centroid"])
|
452 |
+
node["sorted_examples_centroid"] = sorted(
|
453 |
+
[
|
454 |
+
(eid, edp.item())
|
455 |
+
for eid, edp in zip(node["example_ids"], node["centroid_dot_prods"])
|
456 |
+
],
|
457 |
+
key=lambda x: x[1],
|
458 |
+
reverse=True,
|
459 |
+
)
|
460 |
+
return node_list
|
461 |
+
|
462 |
+
|
463 |
+
def make_tree_plot(node_list, text_dset, text_field_name):
|
464 |
+
"""
|
465 |
+
Makes a graphical representation of the tree encoded
|
466 |
+
in `node_list`. The hover label for each node shows the number
|
467 |
+
of descendants and the 5 examples that are closest to the centroid
|
468 |
+
"""
|
469 |
nid_map = dict([(node["nid"], nid) for nid, node in enumerate(node_list)])
|
470 |
|
471 |
for nid, node in enumerate(node_list):
|
472 |
+
# collect up to 5 distinct example texts, closest to the centroid first
|
473 |
+
node_examples = {}
|
474 |
+
for sid, score in node["sorted_examples_centroid"]:
|
475 |
+
node_examples[text_dset[sid][text_field_name]] = score
|
476 |
+
if len(node_examples) >= 5:
|
477 |
+
break
|
478 |
node["label"] = node.get(
|
479 |
"label",
|
480 |
f"{nid:2d} - {node['weight']:5d} items <br>"
|
481 |
+ "<br>".join(
|
482 |
[
|
483 |
+
f" {score:.2f} > {txt[:64]}" + ("..." if len(txt) >= 63 else "")
|
484 |
+
for txt, score in node_examples.items()
|
|
|
|
|
485 |
]
|
486 |
),
|
487 |
)
|
488 |
|
489 |
# make plot nodes
|
|
|
490 |
labels = [node["label"] for node in node_list]
|
491 |
|
492 |
root = node_list[0]
|
data_measurements/streamlit_utils.py
CHANGED
@@ -21,7 +21,6 @@ from st_aggrid import AgGrid, GridOptionsBuilder
|
|
21 |
|
22 |
from .dataset_utils import HF_DESC_FIELD, HF_FEATURE_FIELD, HF_LABEL_FIELD
|
23 |
|
24 |
-
|
25 |
def sidebar_header():
|
26 |
st.sidebar.markdown(
|
27 |
"""
|
@@ -167,7 +166,11 @@ def expander_text_lengths(dstats, column_id):
|
|
167 |
st.markdown(
|
168 |
"### Here is the relative frequency of different text lengths in your dataset:"
|
169 |
)
|
170 |
-
|
|
|
|
|
|
|
|
|
171 |
st.markdown(
|
172 |
"The average length of text instances is **"
|
173 |
+ str(dstats.avg_length)
|
@@ -175,19 +178,11 @@ def expander_text_lengths(dstats, column_id):
|
|
175 |
+ str(dstats.std_length)
|
176 |
+ "**."
|
177 |
)
|
178 |
-
|
179 |
-
start_id_show_lengths = st.slider(
|
180 |
-
f"Show the shortest sentences{column_id} starting at:",
|
181 |
-
0,
|
182 |
-
dstats.num_uniq_lengths,
|
183 |
-
value=0,
|
184 |
-
step=1,
|
185 |
-
)
|
186 |
-
|
187 |
# This is quite a large file and is breaking our ability to navigate the app development.
|
188 |
# Just passing if it's not already there for launch v0
|
189 |
if dstats.length_df is not None:
|
190 |
-
st.
|
|
|
191 |
|
192 |
|
193 |
### Third, use a sentence embedding model
|
@@ -285,17 +280,7 @@ def expander_text_duplicates(dstats, column_id):
|
|
285 |
if dstats.dup_counts_df is None or dstats.dup_counts_df.empty:
|
286 |
st.write("There are no duplicates in this dataset! 🥳")
|
287 |
else:
|
288 |
-
|
289 |
-
gb.configure_column(
|
290 |
-
f"text{column_id}",
|
291 |
-
wrapText=True,
|
292 |
-
resizable=True,
|
293 |
-
autoHeight=True,
|
294 |
-
min_column_width=85,
|
295 |
-
use_container_width=True,
|
296 |
-
)
|
297 |
-
go = gb.build()
|
298 |
-
AgGrid(dstats.dup_counts_df, gridOptions=go)
|
299 |
|
300 |
|
301 |
def expander_npmi_description(min_vocab):
|
|
|
21 |
|
22 |
from .dataset_utils import HF_DESC_FIELD, HF_FEATURE_FIELD, HF_LABEL_FIELD
|
23 |
|
|
|
24 |
def sidebar_header():
|
25 |
st.sidebar.markdown(
|
26 |
"""
|
|
|
166 |
st.markdown(
|
167 |
"### Here is the relative frequency of different text lengths in your dataset:"
|
168 |
)
|
169 |
+
# TODO: figure out a more elegant way to do this:
|
170 |
+
try:
|
171 |
+
st.image(dstats.fig_tok_length_png)
|
172 |
+
except:
|
173 |
+
st.pyplot(dstats.fig_tok_length, use_container_width=True)
|
174 |
st.markdown(
|
175 |
"The average length of text instances is **"
|
176 |
+ str(dstats.avg_length)
|
|
|
178 |
+ str(dstats.std_length)
|
179 |
+ "**."
|
180 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
# This is quite a large file and is breaking our ability to navigate the app development.
|
182 |
# Just passing if it's not already there for launch v0
|
183 |
if dstats.length_df is not None:
|
184 |
+
start_id_show_lengths= st.selectbox("Show examples of length:", sorted(dstats.length_df["length"].unique().tolist()))
|
185 |
+
st.table(dstats.length_df[dstats.length_df["length"] == start_id_show_lengths].set_index("length"))
|
186 |
|
187 |
|
188 |
### Third, use a sentence embedding model
|
|
|
280 |
if dstats.dup_counts_df is None or dstats.dup_counts_df.empty:
|
281 |
st.write("There are no duplicates in this dataset! 🥳")
|
282 |
else:
|
283 |
+
st.dataframe(dstats.dup_counts_df.reset_index(drop=True))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
|
285 |
|
286 |
def expander_npmi_description(min_vocab):
|