Spaces:
Runtime error
Runtime error
File size: 2,637 Bytes
8e8b7d6 0ad84e1 6efa95a 8e8b7d6 2486cd2 8e8b7d6 10d694b 8e8b7d6 b394295 2486cd2 b394295 dc21a0e 8e8b7d6 dc21a0e 8e8b7d6 dc21a0e 8e8b7d6 b394295 8e8b7d6 30b1966 2486cd2 8e8b7d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import streamlit as st
import nltk
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import cosine
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import tensorflow as tf
import tensorflow_hub as hub
def cluster_examples(messages, embed, nc=3):
    """Cluster *messages* by their embeddings with K-Means and render each cluster.

    Parameters
    ----------
    messages : sequence of str
        Texts to display; ``messages[i]`` is paired with embedding row ``embed[i]``.
    embed : array-like
        One embedding row per message, passed to ``KMeans.fit_predict``.
    nc : int, default 3
        Requested number of clusters.  K-Means cannot form more clusters than
        there are samples, so the value actually used is capped at
        ``len(embed)`` (and floored at 1).

    For every cluster, writes a ``CLUSTER : <n>`` line followed by the
    messages assigned to it to the Streamlit page.  Nothing is rendered when
    there are no embeddings.
    """
    n_samples = len(embed)
    if n_samples == 0:
        # KMeans.fit_predict raises on an empty sample set.
        return

    # KMeans raises ValueError when n_clusters > n_samples (e.g. slider at 5
    # with only 4 sentences), so clamp instead of crashing the app.
    n_clusters = max(1, min(nc, n_samples))

    labels = KMeans(
        n_clusters=n_clusters, init='random',
        n_init=10, max_iter=300,
        tol=1e-04, random_state=0,
    ).fit_predict(embed)

    for n in range(n_clusters):
        st.markdown("CLUSTER : %d" % n)
        for message, label in zip(messages, labels):
            if label == n:
                st.markdown(message)
def plot_heatmap(labels, heatmap, rotation=90):
    """Draw *heatmap* as a seaborn heatmap and render it in the Streamlit page.

    Parameters
    ----------
    labels : sequence of str
        Tick labels used for both axes, so *heatmap* is expected to be
        ``len(labels) x len(labels)``.
    heatmap : 2-D array-like
        Similarity values to plot; the colour scale is fixed to [-1, 1].
    rotation : float, default 90
        Rotation in degrees applied to the x-axis tick labels.
    """
    sns.set(font_scale=1.2)
    fig, ax = plt.subplots()
    g = sns.heatmap(
        heatmap,
        xticklabels=labels,
        yticklabels=labels,
        vmin=-1,
        vmax=1,
        cmap="coolwarm",
        ax=ax,
    )
    g.set_xticklabels(labels, rotation=rotation)
    g.set_title("Textual Similarity")
    st.pyplot(fig)
    # Streamlit re-runs the script on every interaction; pyplot keeps every
    # figure alive until it is closed, so release it once it has been rendered.
    plt.close(fig)
st.header("Sentence Similarity Demo")
st.markdown("This demo uses the sentence_transformers library to plot sentence similarity between a list of sentences. Change the text below and try for yourself!")
st.markdown("NOTE: this demo is public - please don't enter confidential text")

# Streamlit text boxes
text = st.text_area('Enter sentences:', value="The sun is hotter than the moon.\nThe sun is very bright.\nI hear that the universe is very large.\nToday is Tuesday.")
nc = st.slider('Select a number of clusters:', min_value=1, max_value=15, value=3)
model_type = st.radio("Choose model:", ('Sentence Transformer', 'Universal Sentence Encoder'), index=0)

# Streamlit re-executes this whole script on every widget interaction, so the
# one-time setup below (model loading, tokenizer download) is cached instead
# of being repeated on every rerun.  ``st.cache_resource`` is the current API;
# older Streamlit releases only provide ``st.cache``.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_resource
def _load_model(name):
    """Return the embedding model for the radio-button choice *name*."""
    if name == "Sentence Transformer":
        return SentenceTransformer('paraphrase-distilroberta-base-v1')
    if name == "Universal Sentence Encoder":
        model_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
        return hub.load(model_url)
    raise ValueError(f"Unknown model type: {name!r}")


@_cache_resource
def _ensure_sentence_tokenizer():
    """Download the NLTK Punkt data required by ``nltk.tokenize.sent_tokenize``."""
    # Newer NLTK releases load ``punkt_tab`` for sent_tokenize while older ones
    # use ``punkt``; fetching both keeps tokenization working across versions.
    # With nltk.download's default raise_on_error=False, an id unknown to the
    # installed NLTK is reported rather than raised.
    for resource in ("punkt", "punkt_tab"):
        nltk.download(resource)
    return True


def _cosine_similarity_matrix(vectors):
    """Return the pairwise cosine-similarity matrix of the rows of *vectors*.

    Entry ``[i, j]`` equals ``1 - scipy.spatial.distance.cosine(vectors[i],
    vectors[j])`` up to floating-point rounding, computed with one row
    normalisation and one matrix product instead of a Python double loop.
    """
    vectors = np.asarray(vectors, dtype=float)
    unit = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    # Clip rounding overshoot so values stay within the valid [-1, 1] range.
    return np.clip(unit @ unit.T, -1.0, 1.0)


# Model setup
model = _load_model(model_type)
_ensure_sentence_tokenizer()

# Run model
# Guard on the tokenized sentences, not the raw text: whitespace-only input
# yields no sentences, which would crash the heatmap and K-Means below.
sentences = nltk.tokenize.sent_tokenize(text)
if sentences:
    if model_type == "Sentence Transformer":
        embed = model.encode(sentences)
    else:  # "Universal Sentence Encoder" (_load_model rejects anything else)
        embed = model(sentences).numpy()
    sim = _cosine_similarity_matrix(embed)

    st.subheader("Similarity Heatmap")
    plot_heatmap(sentences, sim)

    st.subheader("Results from K-Means Clustering")
    # K-Means cannot form more clusters than there are sentences.
    n_clusters = min(nc, len(sentences))
    if n_clusters < nc:
        st.caption(f"Only {len(sentences)} sentence(s) entered, so {n_clusters} cluster(s) are shown.")
    cluster_examples(sentences, embed, n_clusters)
|