somewheresy commited on
Commit
4fa8c7b
1 Parent(s): a5f4da5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +225 -48
app.py CHANGED
@@ -8,14 +8,56 @@ from sklearn.cluster import KMeans
8
  import plotly.graph_objects as go
9
  import time
10
  import logging
11
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Additional libraries for querying
14
  from FlagEmbedding import FlagModel
15
 
16
  # Global variables and dataset loading
17
  global dataset_name
18
- dataset_name = 'somewheresystems/dataclysm-arxiv'
 
 
 
 
 
 
 
19
  st.session_state.dataclysm_arxiv = load_dataset(dataset_name, split="train")
20
  total_samples = len(st.session_state.dataclysm_arxiv)
21
 
@@ -83,20 +125,69 @@ def perform_tsne(embeddings):
83
 
84
  def perform_clustering(df, tsne_results):
85
  start_time = time.time()
86
- # Perform KMeans clustering
87
- logging.info('Performing k-means clustering...')
88
  # Step 3: Visualization with Plotly
89
- df['tsne-3d-one'] = tsne_results[:,0]
90
- df['tsne-3d-two'] = tsne_results[:,1]
91
- df['tsne-3d-three'] = tsne_results[:,2]
92
-
93
- # Perform KMeans clustering
94
- kmeans = KMeans(n_clusters=16) # Change the number of clusters as needed
95
- df['cluster'] = kmeans.fit_predict(df[['tsne-3d-one', 'tsne-3d-two', 'tsne-3d-three']])
 
 
96
  end_time = time.time() # End timing
97
- st.sidebar.text(f'k-means clustering completed in {end_time - start_time:.3f} seconds')
98
  return df
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def main():
101
  # Custom CSS
102
  custom_css = """
@@ -112,47 +203,126 @@ def main():
112
  color: #F8F8F8; /* Set the font color to F8F8F8 */
113
  }
114
  /* Add your CSS styles here */
 
 
 
 
 
115
  h1 {
116
  text-align: center;
117
  }
118
  h2,h3,h4 {
119
  text-align: justify;
120
- font-size: 8px
 
 
 
121
  }
122
  body {
123
- text-align: justify;
 
124
  }
 
125
  .stSlider .css-1cpxqw2 {
126
  background: #202020;
 
 
 
 
 
127
  }
128
  .stButton > button {
129
  background-color: #202020;
130
- width: 100%;
131
- border: none;
 
 
132
  padding: 10px 24px;
133
- border-radius: 5px;
134
  font-size: 16px;
135
  font-weight: bold;
 
 
 
 
 
 
 
 
 
 
136
  }
137
  .reportview-container .main .block-container {
138
- padding: 2rem;
139
  background-color: #202020;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  }
 
141
  </style>
142
  """
143
 
144
  # Inject custom CSS with markdown
145
  st.markdown(custom_css, unsafe_allow_html=True)
 
146
  st.sidebar.markdown(
147
- f'<img src="https://www.somewhere.systems/S2-white-logo.png" style="float: bottom-left; width: 32px; height: 32px; opacity: 1.0; animation: fadein 2s;">',
148
  unsafe_allow_html=True
149
  )
150
- st.sidebar.title('Spatial Search Engine')
151
 
152
  # Check if data needs to be loaded
153
  if 'data_loaded' not in st.session_state or not st.session_state.data_loaded:
154
  # User input for number of samples
155
- num_samples = st.sidebar.slider('Select number of samples', 1000, total_samples, 1000)
156
 
157
  if st.sidebar.button('Initialize'):
158
  st.sidebar.text('Initializing data pipeline...')
@@ -171,8 +341,6 @@ def main():
171
  print(f"FAISS index for {column_name} added.")
172
 
173
  return dataset
174
-
175
-
176
 
177
  # Load data and perform t-SNE and clustering
178
  df, embeddings = load_data(num_samples)
@@ -209,21 +377,21 @@ def main():
209
  marker=dict(
210
  size=1,
211
  color=df['cluster'],
212
- colorscale='Viridis',
213
- opacity=0.8
214
  )
215
  )])
 
 
 
 
216
 
217
  fig.update_layout(
218
- plot_bgcolor='#202020',
 
219
  height=800,
220
  margin=dict(l=0, r=0, b=0, t=0),
221
- scene=dict(
222
- xaxis=dict(showbackground=True, backgroundcolor="#000000"),
223
- yaxis=dict(showbackground=True, backgroundcolor="#000000"),
224
- zaxis=dict(showbackground=True, backgroundcolor="#000000"),
225
- ),
226
- scene_camera=dict(eye=dict(x=0.001, y=0.001, z=0.001))
227
  )
228
  st.session_state.fig = fig
229
 
@@ -236,8 +404,19 @@ def main():
236
  if 'df' in st.session_state:
237
  # Sidebar for querying
238
  with st.sidebar:
239
- st.sidebar.markdown("### Query Embeddings")
240
- query = st.text_input("Enter your query:")
 
 
 
 
 
 
 
 
 
 
 
241
  if st.button("Search"):
242
  # Define the model
243
  print("Initializing model...")
@@ -248,7 +427,7 @@ def main():
248
 
249
  query_embedding = model.encode([query])
250
  # Retrieve examples by title similarity (or abstract, depending on your preference)
251
- scores_title, retrieved_examples_title = st.session_state.dataclysm_title_indexed.get_nearest_examples('title_embedding', query_embedding, k=10)
252
  df_query = pd.DataFrame(retrieved_examples_title)
253
  df_query['proximity'] = scores_title
254
  df_query = df_query.sort_values(by='proximity', ascending=True)
@@ -257,19 +436,17 @@ def main():
257
  # Fix the <a href link> to display properly
258
  df_query['URL'] = df_query['id'].apply(lambda x: f'<a href="https://arxiv.org/abs/{x}" target="_blank">Link</a>')
259
  st.sidebar.markdown(df_query[['title', 'proximity', 'id']].to_html(escape=False), unsafe_allow_html=True)
260
- st.sidebar.markdown("# Detailed View")
261
- selected_index = st.sidebar.selectbox("Select Key", st.session_state.df.id)
262
 
263
- # Display metadata for the selected article
264
- selected_row = st.session_state.df[st.session_state.df['id'] == selected_index].iloc[0]
265
- st.markdown(f"### Title\n{selected_row['title']}", unsafe_allow_html=True)
266
- st.markdown(f"### Abstract\n{selected_row['abstract']}", unsafe_allow_html=True)
267
- st.markdown(f"[Read the full paper](https://arxiv.org/abs/{selected_row['id']})", unsafe_allow_html=True)
268
- st.markdown(f"[Download PDF](https://arxiv.org/pdf/{selected_row['id']})", unsafe_allow_html=True)
269
-
270
-
271
-
272
- if __name__ == "__main__":
273
- main()
274
 
 
 
 
 
 
275
 
 
 
 
8
  import plotly.graph_objects as go
9
  import time
10
  import logging
11
+ from sklearn.cluster import HDBSCAN
12
+
13
+
14
+ BACKGROUND_COLOR = 'black'
15
+ COLOR = 'white'
16
+
17
+ def set_page_container_style(
18
+ max_width: int = 10000, max_width_100_percent: bool = False,
19
+ padding_top: int = 1, padding_right: int = 10, padding_left: int = 1, padding_bottom: int = 10,
20
+ color: str = COLOR, background_color: str = BACKGROUND_COLOR,
21
+ ):
22
+ if max_width_100_percent:
23
+ max_width_str = f'max-width: 100%;'
24
+ else:
25
+ max_width_str = f'max-width: {max_width}px;'
26
+ st.markdown(
27
+ f'''
28
+ <style>
29
+ .reportview-container .css-1lcbmhc .css-1outpf7 {{
30
+ padding-top: 35px;
31
+ }}
32
+ .reportview-container .main .block-container {{
33
+ {max_width_str}
34
+ padding-top: {padding_top}rem;
35
+ padding-right: {padding_right}rem;
36
+ padding-left: {padding_left}rem;
37
+ padding-bottom: {padding_bottom}rem;
38
+ }}
39
+ .reportview-container .main {{
40
+ color: {color};
41
+ background-color: {background_color};
42
+ }}
43
+ </style>
44
+ ''',
45
+ unsafe_allow_html=True,
46
+ )
47
 
48
  # Additional libraries for querying
49
  from FlagEmbedding import FlagModel
50
 
51
  # Global variables and dataset loading
52
  global dataset_name
53
+ st.set_page_config(layout="wide")
54
+
55
+ dataset_name = "somewheresystems/dataclysm-arxiv"
56
+
57
+ set_page_container_style(
58
+ max_width = 1600, max_width_100_percent = True,
59
+ padding_top = 0, padding_right = 10, padding_left = 5, padding_bottom = 10
60
+ )
61
  st.session_state.dataclysm_arxiv = load_dataset(dataset_name, split="train")
62
  total_samples = len(st.session_state.dataclysm_arxiv)
63
 
 
125
 
126
  def perform_clustering(df, tsne_results):
127
  start_time = time.time()
128
+ # Perform DBSCAN clustering
129
+ logging.info('Performing HDBSCAN clustering...')
130
  # Step 3: Visualization with Plotly
131
+ # Normalize the t-SNE results between 0 and 1
132
+ df['tsne-3d-one'] = (tsne_results[:,0] - tsne_results[:,0].min()) / (tsne_results[:,0].max() - tsne_results[:,0].min())
133
+ df['tsne-3d-two'] = (tsne_results[:,1] - tsne_results[:,1].min()) / (tsne_results[:,1].max() - tsne_results[:,1].min())
134
+ df['tsne-3d-three'] = (tsne_results[:,2] - tsne_results[:,2].min()) / (tsne_results[:,2].max() - tsne_results[:,2].min())
135
+
136
+ # Perform DBSCAN clustering
137
+ hdbscan = HDBSCAN(min_cluster_size=10, min_samples=50)
138
+ cluster_labels = hdbscan.fit_predict(df[['tsne-3d-one', 'tsne-3d-two', 'tsne-3d-three']])
139
+ df['cluster'] = cluster_labels
140
  end_time = time.time() # End timing
141
+ st.sidebar.text(f'HDBSCAN clustering completed in {end_time - start_time:.3f} seconds')
142
  return df
143
 
144
+ def update_camera_position(fig, df, df_query, result_id, K=10):
145
+ # Focus the camera on the closest result
146
+ top_K_ids = df_query.sort_values(by='proximity', ascending=True).head(K)['id'].tolist()
147
+ top_K_proximity = df_query['proximity'].tolist()
148
+ top_results = df[df['id'].isin(top_K_ids)]
149
+ camera_focus = dict(
150
+ eye=dict(x=top_results.iloc[0]['tsne-3d-one']*0.1, y=top_results.iloc[0]['tsne-3d-two']*0.1, z=top_results.iloc[0]['tsne-3d-three']*0.1)
151
+ )
152
+ # Normalize the proximity values to range between 1 and 10
153
+ normalized_proximity = [10 - (10 * (prox - min(top_K_proximity)) / (max(top_K_proximity) - min(top_K_proximity))) for prox in top_K_proximity]
154
+ # Create a dictionary mapping id to normalized proximity
155
+ id_to_proximity = dict(zip(top_K_ids, normalized_proximity))
156
+ # Set marker sizes based on proximity for top K ids, all other points stay the same
157
+ marker_sizes = [id_to_proximity[id] if id in top_K_ids else 1 for id in df['id']]
158
+ # Store the original colors in a separate column
159
+ df['color'] = df['cluster']
160
+
161
+ fig = go.Figure(data=[go.Scatter3d(
162
+ x=df['tsne-3d-one'],
163
+ y=df['tsne-3d-two'],
164
+ z=df['tsne-3d-three'],
165
+ mode='markers',
166
+ marker=dict(size=marker_sizes, color=df['color'], colorscale='Viridis', opacity=0.8, line_width=0),
167
+ hovertext=df['hovertext'],
168
+ hoverinfo='text',
169
+ )])
170
+ # Set grid opacity to 10%
171
+ fig.update_layout(scene = dict(xaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)'),
172
+ yaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)'),
173
+ zaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)')))
174
+
175
+ # Add lines stemming from the top result to all other points in the top K
176
+ for i in range(0, K): # there are K-1 lines from the top result to the other K-1 points
177
+ fig.add_trace(go.Scatter3d(
178
+ x=[top_results.iloc[0]['tsne-3d-one'], top_results.iloc[i]['tsne-3d-one']],
179
+ y=[top_results.iloc[0]['tsne-3d-two'], top_results.iloc[i]['tsne-3d-two']],
180
+ z=[top_results.iloc[0]['tsne-3d-three'], top_results.iloc[i]['tsne-3d-three']],
181
+ mode='lines',
182
+ line=dict(color='white',width=0.4), # Set line opacity to 50%
183
+ showlegend=False,
184
+ hoverinfo='none',
185
+ ))
186
+ fig.update_layout(plot_bgcolor='rgba(0,0,0,0)',
187
+ paper_bgcolor='rgba(0,0,0,0)',
188
+ scene_camera=camera_focus)
189
+ return fig
190
+
191
  def main():
192
  # Custom CSS
193
  custom_css = """
 
203
  color: #F8F8F8; /* Set the font color to F8F8F8 */
204
  }
205
  /* Add your CSS styles here */
206
+ .stPlotlyChart {
207
+ width: 100%;
208
+ height: 100%;
209
+ /* Other styles... */
210
+ }
211
  h1 {
212
  text-align: center;
213
  }
214
  h2,h3,h4 {
215
  text-align: justify;
216
+ font-size: 8px;
217
+ }
218
+ st-emotion-cache-1wmy9hl {
219
+ font-size: 8px;
220
  }
221
  body {
222
+ color: #fff;
223
+ background-color: #202020;
224
  }
225
+
226
  .stSlider .css-1cpxqw2 {
227
  background: #202020;
228
+ color: #fd5137;
229
+ }
230
+ .stSlider .text {
231
+ background: #202020;
232
+ color: #fd5137;
233
  }
234
  .stButton > button {
235
  background-color: #202020;
236
+ width: 60%;
237
+ margin-left: auto;
238
+ margin-right: auto;
239
+ display: block;
240
  padding: 10px 24px;
 
241
  font-size: 16px;
242
  font-weight: bold;
243
+ border: 1px solid #f8f8f8;
244
+ }
245
+ .stButton > button:hover {
246
+ color: #Fd5137
247
+ border: 1px solid #fd5137;
248
+ }
249
+ .stButton > button:active {
250
+ color: #F8F8F8;
251
+ border: 1px solid #fd5137;
252
+ background-color: #fd5137;
253
  }
254
  .reportview-container .main .block-container {
255
+ padding: 0;
256
  background-color: #202020;
257
+ width: 100%; /* Make the plotly graph take up full width */
258
+ }
259
+ .sidebar .sidebar-content {
260
+ background-image: linear-gradient(#202020,#202020);
261
+ color: white;
262
+ size: 0.2em; /* Make the text in the sidebar smaller */
263
+ padding: 0;
264
+ }
265
+ .reportview-container .main .block-container {
266
+ background-color: #000000;
267
+ }
268
+ .stText {
269
+ padding: 0;
270
+ }
271
+ /* Set the main background color to #202020 */
272
+ .appview-container {
273
+ background-color: #000000;
274
+ padding: 0;
275
+ }
276
+ .stVerticalBlockBorderWrapper{
277
+ padding: 0;
278
+ margin-left: 0px;
279
+ }
280
+ .st-emotion-cache-1cypcdb {
281
+ background-color: #202020;
282
+ background-image: none;
283
+ color: #000000;
284
+ padding: 0;
285
+ }
286
+ .stPlotlyChart {
287
+ background-color: #000000;
288
+ background-image: none;
289
+ color: #000000;
290
+ padding: 0;
291
+ }
292
+ .reportview-container .css-1lcbmhc .css-1outpf7 {
293
+ padding-top: 35px;
294
+ }
295
+ .reportview-container .main .block-container {
296
+ max-width: 100%;
297
+ padding-top: 0rem;
298
+ padding-right: 0rem;
299
+ padding-left: 0rem;
300
+ padding-bottom: 10rem;
301
+ }
302
+ .reportview-container .main {
303
+ color: white;
304
+ background-color: black;
305
+ }
306
+ .stHeader {
307
+ color: black;
308
+ background-color: black;
309
  }
310
+
311
  </style>
312
  """
313
 
314
  # Inject custom CSS with markdown
315
  st.markdown(custom_css, unsafe_allow_html=True)
316
+ st.sidebar.title('Spatial Search Engine')
317
  st.sidebar.markdown(
318
+ '<a href="http://dataclysm.xyz" target="_blank" style="display: flex; justify-content: center; padding: 10px;">dataclysm.xyz <img src="https://www.somewhere.systems/S2-white-logo.png" style="width: 8px; height: 8px;"></a>',
319
  unsafe_allow_html=True
320
  )
 
321
 
322
  # Check if data needs to be loaded
323
  if 'data_loaded' not in st.session_state or not st.session_state.data_loaded:
324
  # User input for number of samples
325
+ num_samples = st.sidebar.slider('Select number of samples', 1000, int(round(total_samples/10)), 1000)
326
 
327
  if st.sidebar.button('Initialize'):
328
  st.sidebar.text('Initializing data pipeline...')
 
341
  print(f"FAISS index for {column_name} added.")
342
 
343
  return dataset
 
 
344
 
345
  # Load data and perform t-SNE and clustering
346
  df, embeddings = load_data(num_samples)
 
377
  marker=dict(
378
  size=1,
379
  color=df['cluster'],
380
+ colorscale='Jet',
381
+ opacity=0.75
382
  )
383
  )])
384
+ # Set grid opacity to 10%
385
+ fig.update_layout(scene = dict(xaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)'),
386
+ yaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)'),
387
+ zaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)')))
388
 
389
  fig.update_layout(
390
+ plot_bgcolor='rgba(0,0,0,0)',
391
+ paper_bgcolor='rgba(0,0,0,0)',
392
  height=800,
393
  margin=dict(l=0, r=0, b=0, t=0),
394
+ scene_camera=dict(eye=dict(x=0.1, y=0.1, z=0.1))
 
 
 
 
 
395
  )
396
  st.session_state.fig = fig
397
 
 
404
  if 'df' in st.session_state:
405
  # Sidebar for querying
406
  with st.sidebar:
407
+ st.sidebar.markdown("# Detailed View")
408
+ selected_index = st.sidebar.selectbox("Select Key", st.session_state.df.id)
409
+
410
+ # Display metadata for the selected article
411
+ selected_row = st.session_state.df[st.session_state.df['id'] == selected_index].iloc[0]
412
+ st.markdown(f"### Title\n{selected_row['title']}", unsafe_allow_html=True)
413
+ st.markdown(f"### Abstract\n{selected_row['abstract']}", unsafe_allow_html=True)
414
+ st.markdown(f"[Read the full paper](https://arxiv.org/abs/{selected_row['id']})", unsafe_allow_html=True)
415
+ st.markdown(f"[Download PDF](https://arxiv.org/pdf/{selected_row['id']})", unsafe_allow_html=True)
416
+
417
+ st.sidebar.markdown("### Find Similar in Latent Space")
418
+ query = st.text_input("", value=selected_row['title'])
419
+ top_k = st.slider("top k", 1, 100, 10)
420
  if st.button("Search"):
421
  # Define the model
422
  print("Initializing model...")
 
427
 
428
  query_embedding = model.encode([query])
429
  # Retrieve examples by title similarity (or abstract, depending on your preference)
430
+ scores_title, retrieved_examples_title = st.session_state.dataclysm_title_indexed.get_nearest_examples('title_embedding', query_embedding, k=top_k)
431
  df_query = pd.DataFrame(retrieved_examples_title)
432
  df_query['proximity'] = scores_title
433
  df_query = df_query.sort_values(by='proximity', ascending=True)
 
436
  # Fix the <a href link> to display properly
437
  df_query['URL'] = df_query['id'].apply(lambda x: f'<a href="https://arxiv.org/abs/{x}" target="_blank">Link</a>')
438
  st.sidebar.markdown(df_query[['title', 'proximity', 'id']].to_html(escape=False), unsafe_allow_html=True)
439
+ # Get the ID of the top search result
440
+ top_result_id = df_query.iloc[0]['id']
441
 
442
+ # Update the camera position and appearance of points
443
+ updated_fig = update_camera_position(st.session_state.fig, st.session_state.df, df_query, top_result_id,top_k)
 
 
 
 
 
 
 
 
 
444
 
445
+ # Update the figure in the session state and redraw the plot in the sidebar
446
+ st.session_state.fig = updated_fig
447
+ # Display the plot if data is loaded
448
+ if 'data_loaded' in st.session_state and st.session_state.data_loaded:
449
+ st.plotly_chart(st.session_state.fig, use_container_width=True)
450
 
451
+ if __name__ == "__main__":
452
+ main()