Catherine Breslin committed on
Commit
dc21a0e
1 Parent(s): 30b1966

Adding in model choice

Browse files
Files changed (2) hide show
  1. app.py +12 -3
  2. requirements.txt +2 -0
app.py CHANGED
@@ -47,15 +47,24 @@ text = st.text_area('Enter sentences:', value="The sun is hotter than the moon.\
47
 
48
  nc = st.slider('Select a number of clusters', min_value=1, max_value=15, value=3)
49
 
 
 
50
  # Model setup
51
- model = SentenceTransformer('paraphrase-distilroberta-base-v1')
 
 
 
 
 
52
  nltk.download('punkt')
53
 
54
  # Run model
55
  if text:
56
  sentences = nltk.tokenize.sent_tokenize(text)
57
- embed = model.encode(sentences)
58
-
 
 
59
  sim = np.zeros([len(embed), len(embed)])
60
  for i,em in enumerate(embed):
61
  for j,ea in enumerate(embed):
 
47
 
48
  nc = st.slider('Select a number of clusters', min_value=1, max_value=15, value=3)
49
 
50
+ model_type = st.radio("Choose model", ('Sentence Transformer', 'Universal Sentence Encoder'), index=0)
51
+
52
  # Model setup
53
+ if model_type == "Sentence Transformer":
54
+ model = SentenceTransformer('paraphrase-distilroberta-base-v1')
55
+ elif model_type == "Universal Sentence Encoder":
56
+ model_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
57
+ model = hub.load(model_url)
58
+
59
  nltk.download('punkt')
60
 
61
  # Run model
62
  if text:
63
  sentences = nltk.tokenize.sent_tokenize(text)
64
+ if model_type == "Sentence Transformer":
65
+ embed = model.encode(sentences)
66
+ elif model_type == "Universal Sentence Encoder":
67
+ embed = model(sentences).numpy()
68
  sim = np.zeros([len(embed), len(embed)])
69
  for i,em in enumerate(embed):
70
  for j,ea in enumerate(embed):
requirements.txt CHANGED
@@ -7,3 +7,5 @@ numpy
7
  seaborn
8
  matplotlib
9
  sklearn
 
 
 
7
  seaborn
8
  matplotlib
9
  sklearn
10
+ tensorflow_hub
11
+ tensorflow