Spaces:
Build error
Build error
taskswithcode
committed on
Commit
•
5fe6115
1
Parent(s):
e4cf805
Fixes
Browse files- app.py +21 -14
- clus_app_clustypes.json +4 -0
- twc_clustering.py +82 -22
app.py
CHANGED
@@ -103,16 +103,16 @@ def load_model(model_name,model_class,load_model_name):
|
|
103 |
|
104 |
|
105 |
@st.experimental_memo
|
106 |
-
def cached_compute_similarity(sentences,_model,model_name,threshold,_cluster):
|
107 |
texts,embeddings = _model.compute_embeddings(sentences,is_file=False)
|
108 |
-
results = _cluster.cluster(None,texts,embeddings,threshold)
|
109 |
return results
|
110 |
|
111 |
|
112 |
-
def uncached_compute_similarity(sentences,_model,model_name,threshold,cluster):
|
113 |
with st.spinner('Computing vectors for sentences'):
|
114 |
texts,embeddings = _model.compute_embeddings(sentences,is_file=False)
|
115 |
-
results = cluster.cluster(None,texts,embeddings,threshold)
|
116 |
#st.success("Similarity computation complete")
|
117 |
return results
|
118 |
|
@@ -124,7 +124,7 @@ def get_model_info(model_names,model_name):
|
|
124 |
return get_model_info(model_names,DEFAULT_HF_MODEL)
|
125 |
|
126 |
|
127 |
-
def run_test(model_names,model_name,sentences,display_area,threshold,user_uploaded,custom_model):
|
128 |
display_area.text("Loading model:" + model_name)
|
129 |
#Note. model_name may get mapped to new name in the call below for custom models
|
130 |
orig_model_name = model_name
|
@@ -140,10 +140,10 @@ def run_test(model_names,model_name,sentences,display_area,threshold,user_upload
|
|
140 |
display_area.text("Model " + model_name + " load complete")
|
141 |
try:
|
142 |
if (user_uploaded):
|
143 |
-
results = uncached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"])
|
144 |
else:
|
145 |
display_area.text("Computing vectors for sentences")
|
146 |
-
results = cached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"])
|
147 |
display_area.text("Similarity computation complete")
|
148 |
return results
|
149 |
|
@@ -193,16 +193,19 @@ def init_session():
|
|
193 |
st.session_state["model_name"] = "ss_test"
|
194 |
st.session_state["threshold"] = 1.5
|
195 |
st.session_state["file_name"] = "default"
|
|
|
196 |
st.session_state["cluster"] = TWCClustering()
|
197 |
else:
|
198 |
print("Skipping init session")
|
199 |
|
200 |
-
def app_main(app_mode,example_files,model_name_files):
|
201 |
init_session()
|
202 |
with open(example_files) as fp:
|
203 |
example_file_names = json.load(fp)
|
204 |
with open(model_name_files) as fp:
|
205 |
model_names = json.load(fp)
|
|
|
|
|
206 |
curr_use_case = use_case[app_mode].split(".")[0]
|
207 |
st.markdown("<h5 style='text-align: center;'>Compare popular/state-of-the-art models for tasks using sentence embeddings</h5>", unsafe_allow_html=True)
|
208 |
st.markdown(f"<p style='font-size:14px; color: #4f4f4f; text-align: center'><i>Or compare your own model with state-of-the-art/popular models</p>", unsafe_allow_html=True)
|
@@ -215,7 +218,7 @@ def app_main(app_mode,example_files,model_name_files):
|
|
215 |
|
216 |
with st.form('twc_form'):
|
217 |
|
218 |
-
step1_line = "
|
219 |
if (app_mode == DOC_RETRIEVAL):
|
220 |
step1_line += ". The first line is treated as the query"
|
221 |
uploaded_file = st.file_uploader(step1_line, type=".txt")
|
@@ -224,14 +227,17 @@ def app_main(app_mode,example_files,model_name_files):
|
|
224 |
options = list(dict.keys(example_file_names)), index=0, key = "twc_file")
|
225 |
st.write("")
|
226 |
options_arr,markdown_str = construct_model_info_for_display(model_names)
|
227 |
-
selection_label = '
|
228 |
selected_model = st.selectbox(label=selection_label,
|
229 |
options = options_arr, index=0, key = "twc_model")
|
230 |
st.write("")
|
231 |
custom_model_selection = st.text_input("Model not listed above? Type any Huggingface sentence embedding model name ", "",key="custom_model")
|
232 |
hf_link_str = "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><a href='https://huggingface.co/models?pipeline_tag=sentence-similarity' target = '_blank'>List of Huggingface sentence embedding models</a><br/><br/><br/></div>"
|
233 |
st.markdown(hf_link_str, unsafe_allow_html=True)
|
234 |
-
threshold = st.number_input('
|
|
|
|
|
|
|
235 |
st.write("")
|
236 |
submit_button = st.form_submit_button('Run')
|
237 |
|
@@ -256,7 +262,8 @@ def app_main(app_mode,example_files,model_name_files):
|
|
256 |
run_model = selected_model
|
257 |
st.session_state["model_name"] = selected_model
|
258 |
st.session_state["threshold"] = threshold
|
259 |
-
|
|
|
260 |
display_area.empty()
|
261 |
with display_area.container():
|
262 |
device = 'GPU' if torch.cuda.is_available() else 'CPU'
|
@@ -269,7 +276,7 @@ def app_main(app_mode,example_files,model_name_files):
|
|
269 |
label="Download results as json",
|
270 |
data= st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
|
271 |
disabled = False if st.session_state["download_ready"] != None else True,
|
272 |
-
file_name= (st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" + '_'.join(st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/","_"),
|
273 |
mime='text/json',
|
274 |
key ="download"
|
275 |
)
|
@@ -288,5 +295,5 @@ if __name__ == "__main__":
|
|
288 |
#print("comand line input:",len(sys.argv),str(sys.argv))
|
289 |
#app_main(sys.argv[1],sys.argv[2],sys.argv[3])
|
290 |
#app_main("1","sim_app_examples.json","sim_app_models.json")
|
291 |
-
app_main("3","clus_app_examples.json","clus_app_models.json")
|
292 |
|
|
|
103 |
|
104 |
|
105 |
@st.experimental_memo
def cached_compute_similarity(sentences,_model,model_name,threshold,_cluster,clustering_type):
    """Embed ``sentences`` and cluster them, memoised by Streamlit.

    The embeddings come from ``_model.compute_embeddings`` and are handed,
    together with ``threshold`` and ``clustering_type``, to
    ``_cluster.cluster``; whatever that call returns is returned unchanged.

    NOTE(review): the leading underscore on ``_model`` and ``_cluster`` is
    presumably there so Streamlit leaves them out of the cache key -- confirm
    against the Streamlit version in use.
    """
    sentence_texts, sentence_vectors = _model.compute_embeddings(sentences, is_file=False)
    # First positional argument of cluster() is the output file; none is written here.
    return _cluster.cluster(None, sentence_texts, sentence_vectors, threshold, clustering_type)
|
110 |
|
111 |
|
112 |
+
def uncached_compute_similarity(sentences,_model,model_name,threshold,cluster,clustering_type):
    """Embed ``sentences`` and cluster them without memoisation.

    Same pipeline as the cached variant, but a Streamlit spinner is shown
    while the work runs. Returns whatever ``cluster.cluster`` returns.

    NOTE(review): the nesting under ``with`` was reconstructed from a
    listing whose indentation was lost -- confirm that the clustering call
    belongs inside the spinner block.
    """
    with st.spinner('Computing vectors for sentences'):
        sentence_texts, sentence_vectors = _model.compute_embeddings(sentences, is_file=False)
        # First positional argument of cluster() is the output file; none is written here.
        outcome = cluster.cluster(None, sentence_texts, sentence_vectors, threshold, clustering_type)
    #st.success("Similarity computation complete")
    return outcome
|
118 |
|
|
|
124 |
return get_model_info(model_names,DEFAULT_HF_MODEL)
|
125 |
|
126 |
|
127 |
+
def run_test(model_names,model_name,sentences,display_area,threshold,user_uploaded,custom_model,clustering_type):
|
128 |
display_area.text("Loading model:" + model_name)
|
129 |
#Note. model_name may get mapped to new name in the call below for custom models
|
130 |
orig_model_name = model_name
|
|
|
140 |
display_area.text("Model " + model_name + " load complete")
|
141 |
try:
|
142 |
if (user_uploaded):
|
143 |
+
results = uncached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
|
144 |
else:
|
145 |
display_area.text("Computing vectors for sentences")
|
146 |
+
results = cached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
|
147 |
display_area.text("Similarity computation complete")
|
148 |
return results
|
149 |
|
|
|
193 |
st.session_state["model_name"] = "ss_test"
|
194 |
st.session_state["threshold"] = 1.5
|
195 |
st.session_state["file_name"] = "default"
|
196 |
+
st.session_state["overlapped"] = "overlapped"
|
197 |
st.session_state["cluster"] = TWCClustering()
|
198 |
else:
|
199 |
print("Skipping init session")
|
200 |
|
201 |
+
def app_main(app_mode,example_files,model_name_files,clus_types):
|
202 |
init_session()
|
203 |
with open(example_files) as fp:
|
204 |
example_file_names = json.load(fp)
|
205 |
with open(model_name_files) as fp:
|
206 |
model_names = json.load(fp)
|
207 |
+
with open(clus_types) as fp:
|
208 |
+
cluster_types = json.load(fp)
|
209 |
curr_use_case = use_case[app_mode].split(".")[0]
|
210 |
st.markdown("<h5 style='text-align: center;'>Compare popular/state-of-the-art models for tasks using sentence embeddings</h5>", unsafe_allow_html=True)
|
211 |
st.markdown(f"<p style='font-size:14px; color: #4f4f4f; text-align: center'><i>Or compare your own model with state-of-the-art/popular models</p>", unsafe_allow_html=True)
|
|
|
218 |
|
219 |
with st.form('twc_form'):
|
220 |
|
221 |
+
step1_line = "Upload text file(one sentence in a line) or choose an example text file below"
|
222 |
if (app_mode == DOC_RETRIEVAL):
|
223 |
step1_line += ". The first line is treated as the query"
|
224 |
uploaded_file = st.file_uploader(step1_line, type=".txt")
|
|
|
227 |
options = list(dict.keys(example_file_names)), index=0, key = "twc_file")
|
228 |
st.write("")
|
229 |
options_arr,markdown_str = construct_model_info_for_display(model_names)
|
230 |
+
selection_label = 'Select Model'
|
231 |
selected_model = st.selectbox(label=selection_label,
|
232 |
options = options_arr, index=0, key = "twc_model")
|
233 |
st.write("")
|
234 |
custom_model_selection = st.text_input("Model not listed above? Type any Huggingface sentence embedding model name ", "",key="custom_model")
|
235 |
hf_link_str = "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><a href='https://huggingface.co/models?pipeline_tag=sentence-similarity' target = '_blank'>List of Huggingface sentence embedding models</a><br/><br/><br/></div>"
|
236 |
st.markdown(hf_link_str, unsafe_allow_html=True)
|
237 |
+
threshold = st.number_input('Choose a zscore threshold (number of std devs from mean)',value=st.session_state["threshold"],min_value = 0.0,step=.01)
|
238 |
+
st.write("")
|
239 |
+
clustering_type = st.selectbox(label=f'Select type of clustering',
|
240 |
+
options = list(dict.keys(cluster_types)), index=0, key = "twc_cluster_types")
|
241 |
st.write("")
|
242 |
submit_button = st.form_submit_button('Run')
|
243 |
|
|
|
262 |
run_model = selected_model
|
263 |
st.session_state["model_name"] = selected_model
|
264 |
st.session_state["threshold"] = threshold
|
265 |
+
st.session_state["overlapped"] = cluster_types[clustering_type]["type"]
|
266 |
+
results = run_test(model_names,run_model,sentences,display_area,threshold,(uploaded_file is not None),(len(custom_model_selection) != 0),cluster_types[clustering_type]["type"])
|
267 |
display_area.empty()
|
268 |
with display_area.container():
|
269 |
device = 'GPU' if torch.cuda.is_available() else 'CPU'
|
|
|
276 |
label="Download results as json",
|
277 |
data= st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
|
278 |
disabled = False if st.session_state["download_ready"] != None else True,
|
279 |
+
file_name= (st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" + st.session_state["overlapped"] + "_" + '_'.join(st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/","_"),
|
280 |
mime='text/json',
|
281 |
key ="download"
|
282 |
)
|
|
|
295 |
#print("comand line input:",len(sys.argv),str(sys.argv))
|
296 |
#app_main(sys.argv[1],sys.argv[2],sys.argv[3])
|
297 |
#app_main("1","sim_app_examples.json","sim_app_models.json")
|
298 |
+
app_main("3","clus_app_examples.json","clus_app_models.json","clus_app_clustypes.json")
|
299 |
|
clus_app_clustypes.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"Overlapped clustering (cluster size determined by zscore)": {"type":"overlapped"},
|
3 |
+
"Non-overlapped clustering (overlapped clusters aggregated)":{"type":"non-overlapped"}
|
4 |
+
}
|
twc_clustering.py
CHANGED
@@ -31,27 +31,30 @@ class TWCClustering:
|
|
31 |
picked_arr = []
|
32 |
while (run_index < len(embeddings)):
|
33 |
if (matrix[pivot_index][run_index] >= threshold):
|
34 |
-
|
35 |
-
picked_arr.append({"index":run_index})
|
36 |
run_index += 1
|
37 |
return picked_arr
|
38 |
|
|
|
|
|
|
|
|
|
39 |
def update_picked_dict(self,picked_dict,in_dict):
|
40 |
for key in in_dict:
|
41 |
picked_dict[key] = 1
|
42 |
|
43 |
-
def find_pivot_subgraph(self,pivot_index,arr,matrix,threshold):
|
44 |
center_index = pivot_index
|
45 |
center_score = 0
|
46 |
center_dict = {}
|
47 |
for i in range(len(arr)):
|
48 |
-
node_i_index = arr[i]
|
49 |
running_score = 0
|
50 |
temp_dict = {}
|
51 |
for j in range(len(arr)):
|
52 |
-
node_j_index = arr[j]
|
53 |
cosine_dist = matrix[node_i_index][node_j_index]
|
54 |
-
if (cosine_dist < threshold):
|
55 |
continue
|
56 |
running_score += cosine_dist
|
57 |
temp_dict[node_j_index] = cosine_dist
|
@@ -80,8 +83,76 @@ class TWCClustering:
|
|
80 |
bucket_dict[overlap_dict[key]] += 1
|
81 |
sorted_d = OrderedDict(sorted(bucket_dict.items(), key=lambda kv: kv[1], reverse=False))
|
82 |
return sorted_d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
-
def cluster(self,output_file,texts,embeddings,threshold
|
|
|
85 |
matrix = self.compute_matrix(embeddings)
|
86 |
mean = np.mean(matrix)
|
87 |
std = np.std(matrix)
|
@@ -95,22 +166,11 @@ class TWCClustering:
|
|
95 |
#print("In clustering:",round(std,2),zscores)
|
96 |
cluster_dict = {}
|
97 |
cluster_dict["clusters"] = []
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
if (i in picked_dict):
|
103 |
-
continue
|
104 |
-
zscore = mean + threshold*std
|
105 |
-
arr = self.get_terms_above_threshold(matrix,embeddings,i,zscore)
|
106 |
-
cluster_info = self.find_pivot_subgraph(i,arr,matrix,zscore)
|
107 |
-
self.update_picked_dict(picked_dict,cluster_info["neighs"])
|
108 |
-
self.update_overlap_stats(overlap_dict,cluster_info)
|
109 |
-
cluster_dict["clusters"].append(cluster_info)
|
110 |
curr_threshold = f"{threshold} (cosine:{mean+threshold*std:.2f})"
|
111 |
-
sorted_d = OrderedDict(sorted(overlap_dict.items(), key=lambda kv: kv[1], reverse=True))
|
112 |
-
#print(sorted_d)
|
113 |
-
sorted_d = self.bucket_overlap(overlap_dict)
|
114 |
cluster_dict["info"] ={"mean":mean,"std":std,"current_threshold":curr_threshold,"zscores":zscores,"overlap":list(sorted_d.items())}
|
115 |
return cluster_dict
|
116 |
|
|
|
31 |
picked_arr = []
|
32 |
while (run_index < len(embeddings)):
|
33 |
if (matrix[pivot_index][run_index] >= threshold):
|
34 |
+
picked_arr.append(run_index)
|
|
|
35 |
run_index += 1
|
36 |
return picked_arr
|
37 |
|
38 |
+
def update_picked_dict_arr(self, picked_dict, arr):
    """Mark every element of ``arr`` as picked.

    Args:
        picked_dict: Dict used as a "seen" set; it is mutated in place and
            each element of ``arr`` becomes a key mapped to ``1``.
        arr: Iterable of hashable items (sentence indices in this class).
            Any iterable is accepted, not only indexable sequences.

    Returns:
        None.
    """
    # Iterate the items directly instead of indexing through range(len(arr)).
    for item in arr:
        picked_dict[item] = 1
|
41 |
+
|
42 |
def update_picked_dict(self, picked_dict, in_dict):
    """Flag every key of ``in_dict`` as picked inside ``picked_dict``.

    ``picked_dict`` is mutated in place: each key found in ``in_dict`` is
    stored with the value ``1`` (values of ``in_dict`` are ignored).
    Nothing is returned.
    """
    # One bulk update instead of a key-by-key loop; same keys, same value.
    picked_dict.update(dict.fromkeys(in_dict, 1))
|
45 |
|
46 |
+
def find_pivot_subgraph(self,pivot_index,arr,matrix,threshold,strict_cluster = True):
|
47 |
center_index = pivot_index
|
48 |
center_score = 0
|
49 |
center_dict = {}
|
50 |
for i in range(len(arr)):
|
51 |
+
node_i_index = arr[i]
|
52 |
running_score = 0
|
53 |
temp_dict = {}
|
54 |
for j in range(len(arr)):
|
55 |
+
node_j_index = arr[j]
|
56 |
cosine_dist = matrix[node_i_index][node_j_index]
|
57 |
+
if ((cosine_dist < threshold) and strict_cluster):
|
58 |
continue
|
59 |
running_score += cosine_dist
|
60 |
temp_dict[node_j_index] = cosine_dist
|
|
|
83 |
bucket_dict[overlap_dict[key]] += 1
|
84 |
sorted_d = OrderedDict(sorted(bucket_dict.items(), key=lambda kv: kv[1], reverse=False))
|
85 |
return sorted_d
|
86 |
+
|
87 |
+
def merge_clusters(self, ref_cluster, curr_cluster):
    """Append to ``ref_cluster`` the members of ``curr_cluster`` it lacks.

    Args:
        ref_cluster: List that is extended in place; existing order is kept
            and new members are appended in the order they appear in
            ``curr_cluster``.
        curr_cluster: Iterable of hashable members (sentence indices in
            this class) to fold into ``ref_cluster``.

    Returns:
        None.
    """
    # Track membership in a set that grows with every append. The earlier
    # version compared against a snapshot taken before the loop, so a value
    # repeated inside curr_cluster was appended more than once.
    present = set(ref_cluster)
    for member in curr_cluster:
        if member not in present:
            present.add(member)
            ref_cluster.append(member)
|
92 |
+
|
93 |
+
|
94 |
+
def non_overlapped_clustering(self,matrix,embeddings,threshold,mean,std,cluster_dict):
    """Build disjoint clusters and append them to ``cluster_dict["clusters"]``.

    Each sentence not yet covered by an earlier group acts as a pivot and
    collects the indices returned by ``get_terms_above_threshold`` for the
    cutoff ``mean + threshold*std``. Groups that share at least one index
    are then merged until all groups are disjoint, and every merged group
    is summarised with ``find_pivot_subgraph(..., strict_cluster=False)``,
    using the group's first member as the pivot.

    Args:
        matrix: Pairwise score matrix indexed as ``matrix[i][j]``.
        embeddings: Sequence whose length is the number of sentences.
        threshold: Number of standard deviations above ``mean`` used as
            the cutoff.
        mean: Mean of ``matrix`` as supplied by the caller.
        std: Standard deviation of ``matrix`` as supplied by the caller.
        cluster_dict: Result dict; its ``"clusters"`` list is extended in
            place.

    Returns:
        An empty dict: no overlap statistics are gathered in this mode, but
        the caller still calls ``.items()`` on the result.
    """
    # The cutoff does not depend on the pivot, so compute it once.
    cutoff = mean + threshold * std

    covered = set()
    candidates = []
    for pivot in range(len(embeddings)):
        if pivot in covered:
            continue
        members = self.get_terms_above_threshold(matrix, embeddings, pivot, cutoff)
        if not members:
            # No index met the cutoff for this pivot. Keep the sentence as a
            # singleton group; an empty group would make members[0] below
            # raise IndexError.
            members = [pivot]
        candidates.append(members)
        covered.update(members)

    # Merge groups that share an index until all groups are disjoint.
    # Once group i overlaps nothing after it, later merges cannot touch it
    # (they only unite groups that are each disjoint from group i), so the
    # outer index never needs to restart from zero.
    i = 0
    while i < len(candidates):
        ref_cluster = candidates[i]
        ref_members = set(ref_cluster)
        j = i + 1
        while j < len(candidates):
            curr_cluster = candidates[j]
            if ref_members.isdisjoint(curr_cluster):
                j += 1
                continue
            self.merge_clusters(ref_cluster, curr_cluster)
            ref_members.update(curr_cluster)
            candidates.pop(j)
            # ref_cluster grew, so groups skipped earlier may overlap now.
            j = i + 1
        i += 1

    for members in candidates:
        cluster_info = self.find_pivot_subgraph(members[0], members, matrix, cutoff, strict_cluster=False)
        cluster_dict["clusters"].append(cluster_info)
    return {}
|
136 |
+
|
137 |
+
def overlapped_clustering(self,matrix,embeddings,threshold,mean,std,cluster_dict):
    """Build possibly-overlapping clusters and report overlap statistics.

    Sentences already recorded as a neighbour of an earlier cluster are not
    used as pivots, but they may still be collected again by a later pivot,
    which is why clusters can overlap. Each cluster description produced by
    ``find_pivot_subgraph`` (strict mode) is appended to
    ``cluster_dict["clusters"]``.

    Returns:
        The result of ``bucket_overlap`` applied to the overlap statistics
        accumulated through ``update_overlap_stats``.
    """
    cutoff = mean + threshold * std
    covered = {}
    overlap_counts = {}

    for pivot in range(len(embeddings)):
        if pivot not in covered:
            neighbours = self.get_terms_above_threshold(matrix, embeddings, pivot, cutoff)
            info = self.find_pivot_subgraph(pivot, neighbours, matrix, cutoff, strict_cluster=True)
            # Everything this cluster claimed is barred from becoming a pivot.
            self.update_picked_dict(covered, info["neighs"])
            self.update_overlap_stats(overlap_counts, info)
            cluster_dict["clusters"].append(info)

    return self.bucket_overlap(overlap_counts)
|
152 |
+
|
153 |
|
154 |
+
def cluster(self,output_file,texts,embeddings,threshold,clustering_type):
|
155 |
+
is_overlapped = True if clustering_type == "overlapped" else False
|
156 |
matrix = self.compute_matrix(embeddings)
|
157 |
mean = np.mean(matrix)
|
158 |
std = np.std(matrix)
|
|
|
166 |
#print("In clustering:",round(std,2),zscores)
|
167 |
cluster_dict = {}
|
168 |
cluster_dict["clusters"] = []
|
169 |
+
if (is_overlapped):
|
170 |
+
sorted_d = self.overlapped_clustering(matrix,embeddings,threshold,mean,std,cluster_dict)
|
171 |
+
else:
|
172 |
+
sorted_d = self.non_overlapped_clustering(matrix,embeddings,threshold,mean,std,cluster_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
curr_threshold = f"{threshold} (cosine:{mean+threshold*std:.2f})"
|
|
|
|
|
|
|
174 |
cluster_dict["info"] ={"mean":mean,"std":std,"current_threshold":curr_threshold,"zscores":zscores,"overlap":list(sorted_d.items())}
|
175 |
return cluster_dict
|
176 |
|