jharrison27 commited on
Commit
67cc9a7
1 Parent(s): 40dd6aa

add other embedding models

Browse files
Files changed (1) hide show
  1. app.py +32 -9
app.py CHANGED
@@ -11,28 +11,51 @@ mock_words = [
11
  "cat", "dog", "rabbit", "hamster" # Pets
12
  ]
13
 
14
- # Embedding model
15
- embedder = pipeline('feature-extraction', model='distilbert-base-uncased')
 
 
 
 
16
 
17
- def embed_words(words):
 
18
  embeddings = embedder(words)
19
  return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
20
 
21
- def cluster_words(words):
22
- embeddings = embed_words(words)
23
  kmeans = KMeans(n_clusters=4, random_state=0).fit(embeddings)
24
  clusters = {i: [] for i in range(4)}
25
  for word, label in zip(words, kmeans.labels_):
26
  clusters[label].append(word)
 
 
 
 
 
 
 
 
 
 
27
  return clusters
28
 
 
 
 
 
 
29
  def main():
30
  st.title("NYT Connections Solver")
31
-
 
 
 
 
32
  if st.button("Generate Clusters"):
33
- clusters = cluster_words(mock_words)
34
- for i, words in clusters.items():
35
- st.write(f"Group {i+1}: {', '.join(words)}")
36
 
37
  if __name__ == "__main__":
38
  main()
 
11
  "cat", "dog", "rabbit", "hamster" # Pets
12
  ]
13
 
14
+ # Define available models
15
+ models = {
16
+ 'DistilBERT': 'distilbert-base-uncased',
17
+ 'BERT': 'bert-base-uncased',
18
+ 'RoBERTa': 'roberta-base'
19
+ }
20
 
21
+ def embed_words(words, model_name):
22
+ embedder = pipeline('feature-extraction', model=model_name)
23
  embeddings = embedder(words)
24
  return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
25
 
26
+ def cluster_words(words, model_name):
27
+ embeddings = embed_words(words, model_name)
28
  kmeans = KMeans(n_clusters=4, random_state=0).fit(embeddings)
29
  clusters = {i: [] for i in range(4)}
30
  for word, label in zip(words, kmeans.labels_):
31
  clusters[label].append(word)
32
+
33
+ # Ensure each group has exactly 4 words
34
+ for i in range(4):
35
+ if len(clusters[i]) > 4:
36
+ st.warning(f"Group {i+1} has more than 4 words. Adjusting to show only 4 words.")
37
+ clusters[i] = clusters[i][:4]
38
+ elif len(clusters[i]) < 4:
39
+ st.warning(f"Group {i+1} has less than 4 words. Adjusting by adding placeholder words.")
40
+ clusters[i].extend(['N/A'] * (4 - len(clusters[i])))
41
+
42
  return clusters
43
 
44
+ def display_clusters(clusters):
45
+ for i, words in clusters.items():
46
+ st.markdown(f"### Group {i+1}")
47
+ st.write(", ".join(words))
48
+
49
  def main():
50
  st.title("NYT Connections Solver")
51
+
52
+ # Dropdown menu for selecting the embedding model
53
+ model_name = st.selectbox("Select Embedding Model", list(models.keys()))
54
+ selected_model = models[model_name]
55
+
56
  if st.button("Generate Clusters"):
57
+ clusters = cluster_words(mock_words, selected_model)
58
+ display_clusters(clusters)
 
59
 
60
  if __name__ == "__main__":
61
  main()