jharrison27's picture
revert changes
8fd248e
raw
history blame
No virus
3.34 kB
import streamlit as st
from transformers import pipeline
from sklearn.cluster import KMeans
import numpy as np
# Sample puzzle: sixteen words drawn from four hidden categories of four words.
mock_words = [
    # Fruits
    "apple", "banana", "cherry", "date",
    # Vehicles
    "car", "truck", "bus", "bicycle",
    # Colors
    "red", "blue", "green", "yellow",
    # Pets
    "cat", "dog", "rabbit", "hamster",
]

# Display name shown in the UI -> model identifier passed to `pipeline(model=...)`.
# Insertion order is preserved, so the dropdown lists them in this order.
models = dict(
    DistilBERT='distilbert-base-uncased',
    BERT='bert-base-uncased',
    RoBERTa='roberta-base',
)
@st.cache_resource
def load_models():
    """Build one feature-extraction pipeline per entry in ``models``.

    Wrapped in ``st.cache_resource`` so Streamlit keeps the constructed
    pipelines between script reruns instead of rebuilding them.

    Returns:
        Dict mapping each display name in ``models`` to its pipeline.
    """
    return {
        display_name: pipeline('feature-extraction', model=checkpoint)
        for display_name, checkpoint in models.items()
    }


# Executed on every script run; after the first, st.cache_resource hands back
# the cached pipelines.
pipelines = load_models()
def embed_words(words, model_name):
    """Return one mean-pooled vector per word, stacked into a NumPy array.

    Runs the pipeline registered under ``model_name`` on ``words``. For each
    per-word output it averages the rows of ``output[0]`` (presumably the
    per-token hidden states, i.e. nesting [1][tokens][hidden] — confirm against
    the pipeline version in use) and stacks the results row-wise.
    """
    extractor = pipelines[model_name]
    pooled = []
    for features in extractor(words):
        token_vectors = features[0]
        pooled.append(np.mean(token_vectors, axis=0))
    return np.array(pooled)
def iterative_clustering(words, model_name):
    """Repeatedly peel off the most cohesive group of four words.

    Each round does the following:
    - Clusters the still-unassigned words with KMeans, using ``len // 4``
      clusters (at most four).
    - Keeps up to four words per cluster, in input order.
    - Removes the group picked by ``select_most_cohesive_cluster``.

    Args:
        words: Sequence of words to group.
        model_name: Key into ``pipelines`` selecting the embedding model.

    Returns:
        List of four-word groups in the order they were found. Fewer than four
        leftover words are never grouped. The loop also stops early if no
        cluster holds exactly four words.
    """
    if len(words) < 4:
        # Nothing can be grouped; skip running the model entirely.
        return []

    # Embed every word once and re-cluster subsets of these rows, instead of
    # re-running the transformer each round. This assumes the pipeline embeds
    # each word independently (the per-item default), so a word's vector does
    # not depend on which other words are still remaining.
    all_embeddings = embed_words(words, model_name)
    remaining = list(range(len(words)))  # row indices of still-unassigned words
    grouped_words = []

    while len(remaining) >= 4:
        remaining_words = [words[i] for i in remaining]
        embeddings = all_embeddings[remaining]
        kmeans = KMeans(n_clusters=min(4, len(remaining) // 4), random_state=0).fit(embeddings)

        # Fill each cluster in input order, capped at four words.
        clusters = {i: [] for i in range(kmeans.n_clusters)}
        for word, label in zip(remaining_words, kmeans.labels_):
            if len(clusters[label]) < 4:
                clusters[label].append(word)

        best_cluster, _ = select_most_cohesive_cluster(clusters, kmeans, embeddings)
        if best_cluster is None:
            # No full group this round; stop rather than fail on `in None`.
            break

        grouped_words.append(best_cluster)
        # Same removal rule as before: drop every occurrence of a chosen word.
        remaining = [i for i in remaining if words[i] not in best_cluster]

    return grouped_words
def select_most_cohesive_cluster(clusters, kmeans_model, embeddings):
    """Pick the four-word cluster whose returned words sit closest to its centroid.

    Args:
        clusters: Mapping of KMeans label -> list of words for that label. The
            caller fills each list in row order and caps it at four words, so
            a list holds the *first* rows carrying that label.
        kmeans_model: Fitted KMeans model; ``labels_`` and ``cluster_centers_``
            are read.
        embeddings: 2-D array whose rows align with ``kmeans_model.labels_``.

    Returns:
        ``(best_cluster, best_idx)``: the winning word list and its label.
        Returns ``(None, -1)`` when no cluster holds exactly four words.
    """
    labels = np.asarray(kmeans_model.labels_)
    min_distance = float('inf')
    best_cluster = None
    best_idx = -1
    for idx, cluster in clusters.items():
        if len(cluster) != 4:
            continue
        # Score only the rows of the words actually returned. When KMeans put
        # more than four words under this label, the extra rows were not kept,
        # so they should not affect this group's cohesion score.
        rows = np.flatnonzero(labels == idx)[:len(cluster)]
        centroid = kmeans_model.cluster_centers_[idx]
        distance = np.mean(np.linalg.norm(embeddings[rows] - centroid, axis=1))
        if distance < min_distance:  # strict '<': first cluster wins ties
            min_distance = distance
            best_cluster = cluster
            best_idx = idx
    return best_cluster, best_idx
def display_clusters(clusters):
    """Render each group as a numbered Markdown heading followed by its comma-joined words."""
    for number, group in enumerate(clusters, start=1):
        st.markdown(f"### Group {number}")
        st.write(", ".join(group))
def main():
    """Streamlit entry point: choose an embedding model and show the grouped mock words."""
    st.title("NYT Connections Solver")
    st.write("This app demonstrates solving the NYT Connections game using word embeddings and clustering.")
    st.write("Select an embedding model from the dropdown menu and click 'Generate Clusters' to see the grouped words.")

    # The chosen display name selects which cached pipeline embeds the words.
    model_name = st.selectbox("Select Embedding Model", list(models))

    if not st.button("Generate Clusters"):
        return

    with st.spinner("Generating clusters..."):
        clusters = iterative_clustering(mock_words, model_name)
    display_clusters(clusters)


if __name__ == "__main__":
    main()