jharrison27 commited on
Commit
629f196
1 Parent(s): fe78a4f

fix decorator and words

Browse files
Files changed (1) hide show
  1. app.py +16 -19
app.py CHANGED
@@ -5,10 +5,10 @@ import numpy as np
5
 
6
  # Mock data
7
  mock_words = [
8
- "apple", "banana", "cherry", "date", # Fruits
9
- "car", "truck", "bus", "bicycle", # Vehicles
10
- "red", "blue", "green", "yellow", # Colors
11
- "cat", "dog", "rabbit", "hamster" # Pets
12
  ]
13
 
14
  # Define available models and load them
@@ -18,7 +18,7 @@ models = {
18
  'RoBERTa': 'roberta-base'
19
  }
20
 
21
- @st.cache(allow_output_mutation=True)
22
  def load_models():
23
  pipelines = {}
24
  for name, model_name in models.items():
@@ -28,6 +28,9 @@ def load_models():
28
  pipelines = load_models()
29
 
30
  def embed_words(words, model_name):
 
 
 
31
  embedder = pipelines[model_name]
32
  embeddings = embedder(words)
33
  return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
@@ -37,17 +40,8 @@ def cluster_words(words, model_name):
37
  kmeans = KMeans(n_clusters=4, random_state=0).fit(embeddings)
38
  clusters = {i: [] for i in range(4)}
39
  for word, label in zip(words, kmeans.labels_):
40
- clusters[label].append(word)
41
-
42
- # Ensure each group has exactly 4 words
43
- for i in range(4):
44
- if len(clusters[i]) > 4:
45
- st.warning(f"Group {i+1} has more than 4 words. Adjusting to show only 4 words.")
46
- clusters[i] = clusters[i][:4]
47
- elif len(clusters[i]) < 4:
48
- st.warning(f"Group {i+1} has less than 4 words. Adjusting by adding placeholder words.")
49
- clusters[i].extend(['N/A'] * (4 - len(clusters[i])))
50
-
51
  return clusters
52
 
53
  def display_clusters(clusters):
@@ -57,13 +51,16 @@ def display_clusters(clusters):
57
 
58
  def main():
59
  st.title("NYT Connections Solver")
60
-
 
 
61
  # Dropdown menu for selecting the embedding model
62
  model_name = st.selectbox("Select Embedding Model", list(models.keys()))
63
 
64
  if st.button("Generate Clusters"):
65
- clusters = cluster_words(mock_words, model_name)
 
66
  display_clusters(clusters)
67
 
68
  if __name__ == "__main__":
69
- main()
 
5
 
6
  # Mock data
7
  mock_words = [
8
+ "apple", "banana", "cherry", "date", # Fruits
9
+ "car", "truck", "bus", "bicycle", # Vehicles
10
+ "red", "blue", "green", "yellow", # Colors
11
+ "cat", "dog", "rabbit", "hamster" # Pets
12
  ]
13
 
14
  # Define available models and load them
 
18
  'RoBERTa': 'roberta-base'
19
  }
20
 
21
+ @st.cache_resource
22
  def load_models():
23
  pipelines = {}
24
  for name, model_name in models.items():
 
28
  pipelines = load_models()
29
 
30
  def embed_words(words, model_name):
31
+ """
32
+ Embed the given words using the specified model and return the averaged embeddings.
33
+ """
34
  embedder = pipelines[model_name]
35
  embeddings = embedder(words)
36
  return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
 
40
  kmeans = KMeans(n_clusters=4, random_state=0).fit(embeddings)
41
  clusters = {i: [] for i in range(4)}
42
  for word, label in zip(words, kmeans.labels_):
43
+ if len(clusters[label]) < 4:
44
+ clusters[label].append(word)
 
 
 
 
 
 
 
 
 
45
  return clusters
46
 
47
  def display_clusters(clusters):
 
51
 
52
  def main():
53
  st.title("NYT Connections Solver")
54
+ st.write("This app demonstrates solving the NYT Connections game using word embeddings and clustering.")
55
+ st.write("Select an embedding model from the dropdown menu and click 'Generate Clusters' to see the grouped words.")
56
+
57
  # Dropdown menu for selecting the embedding model
58
  model_name = st.selectbox("Select Embedding Model", list(models.keys()))
59
 
60
  if st.button("Generate Clusters"):
61
+ with st.spinner("Generating clusters..."):
62
+ clusters = cluster_words(mock_words, model_name)
63
  display_clusters(clusters)
64
 
65
  if __name__ == "__main__":
66
+ main()