Spaces:
Runtime error
Runtime error
jharrison27
commited on
Commit
•
629f196
1
Parent(s):
fe78a4f
fix decorator and words
Browse files
app.py
CHANGED
@@ -5,10 +5,10 @@ import numpy as np
|
|
5 |
|
6 |
# Mock data
|
7 |
mock_words = [
|
8 |
-
"apple", "banana", "cherry", "date",
|
9 |
-
"car", "truck", "bus", "bicycle",
|
10 |
-
"red", "blue", "green", "yellow",
|
11 |
-
"cat", "dog", "rabbit", "hamster"
|
12 |
]
|
13 |
|
14 |
# Define available models and load them
|
@@ -18,7 +18,7 @@ models = {
|
|
18 |
'RoBERTa': 'roberta-base'
|
19 |
}
|
20 |
|
21 |
-
@st.
|
22 |
def load_models():
|
23 |
pipelines = {}
|
24 |
for name, model_name in models.items():
|
@@ -28,6 +28,9 @@ def load_models():
|
|
28 |
pipelines = load_models()
|
29 |
|
30 |
def embed_words(words, model_name):
|
|
|
|
|
|
|
31 |
embedder = pipelines[model_name]
|
32 |
embeddings = embedder(words)
|
33 |
return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
|
@@ -37,17 +40,8 @@ def cluster_words(words, model_name):
|
|
37 |
kmeans = KMeans(n_clusters=4, random_state=0).fit(embeddings)
|
38 |
clusters = {i: [] for i in range(4)}
|
39 |
for word, label in zip(words, kmeans.labels_):
|
40 |
-
clusters[label]
|
41 |
-
|
42 |
-
# Ensure each group has exactly 4 words
|
43 |
-
for i in range(4):
|
44 |
-
if len(clusters[i]) > 4:
|
45 |
-
st.warning(f"Group {i+1} has more than 4 words. Adjusting to show only 4 words.")
|
46 |
-
clusters[i] = clusters[i][:4]
|
47 |
-
elif len(clusters[i]) < 4:
|
48 |
-
st.warning(f"Group {i+1} has less than 4 words. Adjusting by adding placeholder words.")
|
49 |
-
clusters[i].extend(['N/A'] * (4 - len(clusters[i])))
|
50 |
-
|
51 |
return clusters
|
52 |
|
53 |
def display_clusters(clusters):
|
@@ -57,13 +51,16 @@ def display_clusters(clusters):
|
|
57 |
|
58 |
def main():
|
59 |
st.title("NYT Connections Solver")
|
60 |
-
|
|
|
|
|
61 |
# Dropdown menu for selecting the embedding model
|
62 |
model_name = st.selectbox("Select Embedding Model", list(models.keys()))
|
63 |
|
64 |
if st.button("Generate Clusters"):
|
65 |
-
|
|
|
66 |
display_clusters(clusters)
|
67 |
|
68 |
if __name__ == "__main__":
|
69 |
-
main()
|
|
|
5 |
|
6 |
# Mock data
|
7 |
mock_words = [
|
8 |
+
"apple", "banana", "cherry", "date", # Fruits
|
9 |
+
"car", "truck", "bus", "bicycle", # Vehicles
|
10 |
+
"red", "blue", "green", "yellow", # Colors
|
11 |
+
"cat", "dog", "rabbit", "hamster" # Pets
|
12 |
]
|
13 |
|
14 |
# Define available models and load them
|
|
|
18 |
'RoBERTa': 'roberta-base'
|
19 |
}
|
20 |
|
21 |
+
@st.cache_resource
|
22 |
def load_models():
|
23 |
pipelines = {}
|
24 |
for name, model_name in models.items():
|
|
|
28 |
pipelines = load_models()
|
29 |
|
30 |
def embed_words(words, model_name):
|
31 |
+
"""
|
32 |
+
Embed the given words using the specified model and return the averaged embeddings.
|
33 |
+
"""
|
34 |
embedder = pipelines[model_name]
|
35 |
embeddings = embedder(words)
|
36 |
return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
|
|
|
40 |
kmeans = KMeans(n_clusters=4, random_state=0).fit(embeddings)
|
41 |
clusters = {i: [] for i in range(4)}
|
42 |
for word, label in zip(words, kmeans.labels_):
|
43 |
+
if len(clusters[label]) < 4:
|
44 |
+
clusters[label].append(word)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
return clusters
|
46 |
|
47 |
def display_clusters(clusters):
|
|
|
51 |
|
52 |
def main():
|
53 |
st.title("NYT Connections Solver")
|
54 |
+
st.write("This app demonstrates solving the NYT Connections game using word embeddings and clustering.")
|
55 |
+
st.write("Select an embedding model from the dropdown menu and click 'Generate Clusters' to see the grouped words.")
|
56 |
+
|
57 |
# Dropdown menu for selecting the embedding model
|
58 |
model_name = st.selectbox("Select Embedding Model", list(models.keys()))
|
59 |
|
60 |
if st.button("Generate Clusters"):
|
61 |
+
with st.spinner("Generating clusters..."):
|
62 |
+
clusters = cluster_words(mock_words, model_name)
|
63 |
display_clusters(clusters)
|
64 |
|
65 |
if __name__ == "__main__":
|
66 |
+
main()
|