Spaces:
Runtime error
Runtime error
import streamlit as st | |
import nltk | |
from transformers import pipeline | |
from sentence_transformers import SentenceTransformer | |
from scipy.spatial.distance import cosine | |
import numpy as np | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
def plot_heatmap(labels, heatmap, rotation=90):
    """Render a labelled similarity heatmap in the Streamlit app.

    Args:
        labels: Tick labels used for both the x and the y axis.
        heatmap: Square matrix of values to draw; colours are scaled to the
            fixed range [-1, 1] with the "coolwarm" colormap.
        rotation: Rotation, in degrees, applied to the x-axis tick labels.
    """
    sns.set(font_scale=1.2)
    fig, ax = plt.subplots()
    try:
        # Draw on the axes we just created instead of relying on pyplot's
        # implicit "current axes" state.
        g = sns.heatmap(
            heatmap,
            xticklabels=labels,
            yticklabels=labels,
            vmin=-1,
            vmax=1,
            cmap="coolwarm",
            ax=ax,
        )
        g.set_xticklabels(labels, rotation=rotation)
        g.set_title("Textual Similarity")
        st.pyplot(fig)
    finally:
        # The script is re-executed on every interaction; without closing,
        # each run would leave another figure registered with pyplot.
        plt.close(fig)
st.header("Sentence Similarity Demo")
st.markdown("This demo uses the sentence_transformers library to plot sentence similarity between a list of sentences. Change the text below and try for yourself!")

# Streamlit text box pre-filled with a few example sentences.
text = st.text_area('Enter sentences:', value="The sun is hotter than the moon.\nThe sun is very bright.\nI hear that the universe is very large.\nToday is Tuesday.")

# Model setup
# NOTE(review): the model is constructed on every script run; consider
# wrapping this in the Streamlit caching decorator available in the deployed
# Streamlit version — confirm which API that is before changing.
model = SentenceTransformer('paraphrase-distilroberta-base-v1')
# Tokenizer data needed by nltk.tokenize.sent_tokenize; quiet keeps the
# download chatter out of the app logs.
nltk.download('punkt', quiet=True)

# Run model
if text:
    sentences = nltk.tokenize.sent_tokenize(text)
    if sentences:
        # One embedding vector per sentence, as a float (n, dim) array.
        embed = np.asarray(model.encode(sentences), dtype=float)
        # Cosine similarity of every pair at once: scale each row to unit
        # length, then a single matrix product gives all pairwise dot
        # products (replaces the per-pair Python loop).
        norms = np.linalg.norm(embed, axis=1, keepdims=True)
        unit = embed / norms
        sim = unit @ unit.T
        plot_heatmap(sentences, sim)
    else:
        # Whitespace-only input is truthy but yields no sentences; avoid
        # handing an empty matrix to the plotting code.
        st.warning("Please enter at least one sentence.")