somewheresy committed on
Commit
a7193d8
1 Parent(s): 4fa8c7b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -225
app.py CHANGED
@@ -8,56 +8,14 @@ from sklearn.cluster import KMeans
8
  import plotly.graph_objects as go
9
  import time
10
  import logging
11
- from sklearn.cluster import HDBSCAN
12
-
13
-
14
- BACKGROUND_COLOR = 'black'
15
- COLOR = 'white'
16
-
17
- def set_page_container_style(
18
- max_width: int = 10000, max_width_100_percent: bool = False,
19
- padding_top: int = 1, padding_right: int = 10, padding_left: int = 1, padding_bottom: int = 10,
20
- color: str = COLOR, background_color: str = BACKGROUND_COLOR,
21
- ):
22
- if max_width_100_percent:
23
- max_width_str = f'max-width: 100%;'
24
- else:
25
- max_width_str = f'max-width: {max_width}px;'
26
- st.markdown(
27
- f'''
28
- <style>
29
- .reportview-container .css-1lcbmhc .css-1outpf7 {{
30
- padding-top: 35px;
31
- }}
32
- .reportview-container .main .block-container {{
33
- {max_width_str}
34
- padding-top: {padding_top}rem;
35
- padding-right: {padding_right}rem;
36
- padding-left: {padding_left}rem;
37
- padding-bottom: {padding_bottom}rem;
38
- }}
39
- .reportview-container .main {{
40
- color: {color};
41
- background-color: {background_color};
42
- }}
43
- </style>
44
- ''',
45
- unsafe_allow_html=True,
46
- )
47
 
48
  # Additional libraries for querying
49
  from FlagEmbedding import FlagModel
50
 
51
  # Global variables and dataset loading
52
  global dataset_name
53
- st.set_page_config(layout="wide")
54
-
55
- dataset_name = "somewheresystems/dataclysm-arxiv"
56
-
57
- set_page_container_style(
58
- max_width = 1600, max_width_100_percent = True,
59
- padding_top = 0, padding_right = 10, padding_left = 5, padding_bottom = 10
60
- )
61
  st.session_state.dataclysm_arxiv = load_dataset(dataset_name, split="train")
62
  total_samples = len(st.session_state.dataclysm_arxiv)
63
 
@@ -125,69 +83,20 @@ def perform_tsne(embeddings):
125
 
126
  def perform_clustering(df, tsne_results):
127
  start_time = time.time()
128
- # Perform DBSCAN clustering
129
- logging.info('Performing HDBSCAN clustering...')
130
  # Step 3: Visualization with Plotly
131
- # Normalize the t-SNE results between 0 and 1
132
- df['tsne-3d-one'] = (tsne_results[:,0] - tsne_results[:,0].min()) / (tsne_results[:,0].max() - tsne_results[:,0].min())
133
- df['tsne-3d-two'] = (tsne_results[:,1] - tsne_results[:,1].min()) / (tsne_results[:,1].max() - tsne_results[:,1].min())
134
- df['tsne-3d-three'] = (tsne_results[:,2] - tsne_results[:,2].min()) / (tsne_results[:,2].max() - tsne_results[:,2].min())
135
-
136
- # Perform DBSCAN clustering
137
- hdbscan = HDBSCAN(min_cluster_size=10, min_samples=50)
138
- cluster_labels = hdbscan.fit_predict(df[['tsne-3d-one', 'tsne-3d-two', 'tsne-3d-three']])
139
- df['cluster'] = cluster_labels
140
  end_time = time.time() # End timing
141
- st.sidebar.text(f'HDBSCAN clustering completed in {end_time - start_time:.3f} seconds')
142
  return df
143
 
144
- def update_camera_position(fig, df, df_query, result_id, K=10):
145
- # Focus the camera on the closest result
146
- top_K_ids = df_query.sort_values(by='proximity', ascending=True).head(K)['id'].tolist()
147
- top_K_proximity = df_query['proximity'].tolist()
148
- top_results = df[df['id'].isin(top_K_ids)]
149
- camera_focus = dict(
150
- eye=dict(x=top_results.iloc[0]['tsne-3d-one']*0.1, y=top_results.iloc[0]['tsne-3d-two']*0.1, z=top_results.iloc[0]['tsne-3d-three']*0.1)
151
- )
152
- # Normalize the proximity values to range between 1 and 10
153
- normalized_proximity = [10 - (10 * (prox - min(top_K_proximity)) / (max(top_K_proximity) - min(top_K_proximity))) for prox in top_K_proximity]
154
- # Create a dictionary mapping id to normalized proximity
155
- id_to_proximity = dict(zip(top_K_ids, normalized_proximity))
156
- # Set marker sizes based on proximity for top K ids, all other points stay the same
157
- marker_sizes = [id_to_proximity[id] if id in top_K_ids else 1 for id in df['id']]
158
- # Store the original colors in a separate column
159
- df['color'] = df['cluster']
160
-
161
- fig = go.Figure(data=[go.Scatter3d(
162
- x=df['tsne-3d-one'],
163
- y=df['tsne-3d-two'],
164
- z=df['tsne-3d-three'],
165
- mode='markers',
166
- marker=dict(size=marker_sizes, color=df['color'], colorscale='Viridis', opacity=0.8, line_width=0),
167
- hovertext=df['hovertext'],
168
- hoverinfo='text',
169
- )])
170
- # Set grid opacity to 10%
171
- fig.update_layout(scene = dict(xaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)'),
172
- yaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)'),
173
- zaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)')))
174
-
175
- # Add lines stemming from the top result to all other points in the top K
176
- for i in range(0, K): # there are K-1 lines from the top result to the other K-1 points
177
- fig.add_trace(go.Scatter3d(
178
- x=[top_results.iloc[0]['tsne-3d-one'], top_results.iloc[i]['tsne-3d-one']],
179
- y=[top_results.iloc[0]['tsne-3d-two'], top_results.iloc[i]['tsne-3d-two']],
180
- z=[top_results.iloc[0]['tsne-3d-three'], top_results.iloc[i]['tsne-3d-three']],
181
- mode='lines',
182
- line=dict(color='white',width=0.4), # Set line opacity to 50%
183
- showlegend=False,
184
- hoverinfo='none',
185
- ))
186
- fig.update_layout(plot_bgcolor='rgba(0,0,0,0)',
187
- paper_bgcolor='rgba(0,0,0,0)',
188
- scene_camera=camera_focus)
189
- return fig
190
-
191
  def main():
192
  # Custom CSS
193
  custom_css = """
@@ -203,126 +112,47 @@ def main():
203
  color: #F8F8F8; /* Set the font color to F8F8F8 */
204
  }
205
  /* Add your CSS styles here */
206
- .stPlotlyChart {
207
- width: 100%;
208
- height: 100%;
209
- /* Other styles... */
210
- }
211
  h1 {
212
  text-align: center;
213
  }
214
  h2,h3,h4 {
215
  text-align: justify;
216
- font-size: 8px;
217
- }
218
- st-emotion-cache-1wmy9hl {
219
- font-size: 8px;
220
  }
221
  body {
222
- color: #fff;
223
- background-color: #202020;
224
  }
225
-
226
  .stSlider .css-1cpxqw2 {
227
  background: #202020;
228
- color: #fd5137;
229
- }
230
- .stSlider .text {
231
- background: #202020;
232
- color: #fd5137;
233
  }
234
  .stButton > button {
235
  background-color: #202020;
236
- width: 60%;
237
- margin-left: auto;
238
- margin-right: auto;
239
- display: block;
240
  padding: 10px 24px;
 
241
  font-size: 16px;
242
  font-weight: bold;
243
- border: 1px solid #f8f8f8;
244
- }
245
- .stButton > button:hover {
246
- color: #Fd5137
247
- border: 1px solid #fd5137;
248
- }
249
- .stButton > button:active {
250
- color: #F8F8F8;
251
- border: 1px solid #fd5137;
252
- background-color: #fd5137;
253
  }
254
  .reportview-container .main .block-container {
255
- padding: 0;
256
  background-color: #202020;
257
- width: 100%; /* Make the plotly graph take up full width */
258
- }
259
- .sidebar .sidebar-content {
260
- background-image: linear-gradient(#202020,#202020);
261
- color: white;
262
- size: 0.2em; /* Make the text in the sidebar smaller */
263
- padding: 0;
264
- }
265
- .reportview-container .main .block-container {
266
- background-color: #000000;
267
- }
268
- .stText {
269
- padding: 0;
270
- }
271
- /* Set the main background color to #202020 */
272
- .appview-container {
273
- background-color: #000000;
274
- padding: 0;
275
- }
276
- .stVerticalBlockBorderWrapper{
277
- padding: 0;
278
- margin-left: 0px;
279
- }
280
- .st-emotion-cache-1cypcdb {
281
- background-color: #202020;
282
- background-image: none;
283
- color: #000000;
284
- padding: 0;
285
- }
286
- .stPlotlyChart {
287
- background-color: #000000;
288
- background-image: none;
289
- color: #000000;
290
- padding: 0;
291
- }
292
- .reportview-container .css-1lcbmhc .css-1outpf7 {
293
- padding-top: 35px;
294
- }
295
- .reportview-container .main .block-container {
296
- max-width: 100%;
297
- padding-top: 0rem;
298
- padding-right: 0rem;
299
- padding-left: 0rem;
300
- padding-bottom: 10rem;
301
- }
302
- .reportview-container .main {
303
- color: white;
304
- background-color: black;
305
- }
306
- .stHeader {
307
- color: black;
308
- background-color: black;
309
  }
310
-
311
  </style>
312
  """
313
 
314
  # Inject custom CSS with markdown
315
  st.markdown(custom_css, unsafe_allow_html=True)
316
- st.sidebar.title('Spatial Search Engine')
317
  st.sidebar.markdown(
318
- '<a href="http://dataclysm.xyz" target="_blank" style="display: flex; justify-content: center; padding: 10px;">dataclysm.xyz <img src="https://www.somewhere.systems/S2-white-logo.png" style="width: 8px; height: 8px;"></a>',
319
  unsafe_allow_html=True
320
  )
 
321
 
322
  # Check if data needs to be loaded
323
  if 'data_loaded' not in st.session_state or not st.session_state.data_loaded:
324
  # User input for number of samples
325
- num_samples = st.sidebar.slider('Select number of samples', 1000, int(round(total_samples/10)), 1000)
326
 
327
  if st.sidebar.button('Initialize'):
328
  st.sidebar.text('Initializing data pipeline...')
@@ -341,6 +171,8 @@ def main():
341
  print(f"FAISS index for {column_name} added.")
342
 
343
  return dataset
 
 
344
 
345
  # Load data and perform t-SNE and clustering
346
  df, embeddings = load_data(num_samples)
@@ -377,21 +209,21 @@ def main():
377
  marker=dict(
378
  size=1,
379
  color=df['cluster'],
380
- colorscale='Jet',
381
- opacity=0.75
382
  )
383
  )])
384
- # Set grid opacity to 10%
385
- fig.update_layout(scene = dict(xaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)'),
386
- yaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)'),
387
- zaxis = dict(gridcolor='rgba(128, 128, 128, 0.1)', color='rgba(128, 128, 128, 0.1)')))
388
 
389
  fig.update_layout(
390
- plot_bgcolor='rgba(0,0,0,0)',
391
- paper_bgcolor='rgba(0,0,0,0)',
392
  height=800,
393
  margin=dict(l=0, r=0, b=0, t=0),
394
- scene_camera=dict(eye=dict(x=0.1, y=0.1, z=0.1))
 
 
 
 
 
395
  )
396
  st.session_state.fig = fig
397
 
@@ -404,19 +236,8 @@ def main():
404
  if 'df' in st.session_state:
405
  # Sidebar for querying
406
  with st.sidebar:
407
- st.sidebar.markdown("# Detailed View")
408
- selected_index = st.sidebar.selectbox("Select Key", st.session_state.df.id)
409
-
410
- # Display metadata for the selected article
411
- selected_row = st.session_state.df[st.session_state.df['id'] == selected_index].iloc[0]
412
- st.markdown(f"### Title\n{selected_row['title']}", unsafe_allow_html=True)
413
- st.markdown(f"### Abstract\n{selected_row['abstract']}", unsafe_allow_html=True)
414
- st.markdown(f"[Read the full paper](https://arxiv.org/abs/{selected_row['id']})", unsafe_allow_html=True)
415
- st.markdown(f"[Download PDF](https://arxiv.org/pdf/{selected_row['id']})", unsafe_allow_html=True)
416
-
417
- st.sidebar.markdown("### Find Similar in Latent Space")
418
- query = st.text_input("", value=selected_row['title'])
419
- top_k = st.slider("top k", 1, 100, 10)
420
  if st.button("Search"):
421
  # Define the model
422
  print("Initializing model...")
@@ -427,7 +248,7 @@ def main():
427
 
428
  query_embedding = model.encode([query])
429
  # Retrieve examples by title similarity (or abstract, depending on your preference)
430
- scores_title, retrieved_examples_title = st.session_state.dataclysm_title_indexed.get_nearest_examples('title_embedding', query_embedding, k=top_k)
431
  df_query = pd.DataFrame(retrieved_examples_title)
432
  df_query['proximity'] = scores_title
433
  df_query = df_query.sort_values(by='proximity', ascending=True)
@@ -436,17 +257,19 @@ def main():
436
  # Fix the <a href link> to display properly
437
  df_query['URL'] = df_query['id'].apply(lambda x: f'<a href="https://arxiv.org/abs/{x}" target="_blank">Link</a>')
438
  st.sidebar.markdown(df_query[['title', 'proximity', 'id']].to_html(escape=False), unsafe_allow_html=True)
439
- # Get the ID of the top search result
440
- top_result_id = df_query.iloc[0]['id']
441
-
442
- # Update the camera position and appearance of points
443
- updated_fig = update_camera_position(st.session_state.fig, st.session_state.df, df_query, top_result_id,top_k)
444
 
445
- # Update the figure in the session state and redraw the plot in the sidebar
446
- st.session_state.fig = updated_fig
447
- # Display the plot if data is loaded
448
- if 'data_loaded' in st.session_state and st.session_state.data_loaded:
449
- st.plotly_chart(st.session_state.fig, use_container_width=True)
 
 
 
450
 
451
  if __name__ == "__main__":
452
- main()
 
 
 
8
  import plotly.graph_objects as go
9
  import time
10
  import logging
11
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Additional libraries for querying
14
  from FlagEmbedding import FlagModel
15
 
16
  # Global variables and dataset loading
17
  global dataset_name
18
+ dataset_name = 'somewheresystems/dataclysm-arxiv'
 
 
 
 
 
 
 
19
  st.session_state.dataclysm_arxiv = load_dataset(dataset_name, split="train")
20
  total_samples = len(st.session_state.dataclysm_arxiv)
21
 
 
83
 
84
  def perform_clustering(df, tsne_results):
85
  start_time = time.time()
86
+ # Perform KMeans clustering
87
+ logging.info('Performing k-means clustering...')
88
  # Step 3: Visualization with Plotly
89
+ df['tsne-3d-one'] = tsne_results[:,0]
90
+ df['tsne-3d-two'] = tsne_results[:,1]
91
+ df['tsne-3d-three'] = tsne_results[:,2]
92
+
93
+ # Perform KMeans clustering
94
+ kmeans = KMeans(n_clusters=16) # Change the number of clusters as needed
95
+ df['cluster'] = kmeans.fit_predict(df[['tsne-3d-one', 'tsne-3d-two', 'tsne-3d-three']])
 
 
96
  end_time = time.time() # End timing
97
+ st.sidebar.text(f'k-means clustering completed in {end_time - start_time:.3f} seconds')
98
  return df
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def main():
101
  # Custom CSS
102
  custom_css = """
 
112
  color: #F8F8F8; /* Set the font color to F8F8F8 */
113
  }
114
  /* Add your CSS styles here */
 
 
 
 
 
115
  h1 {
116
  text-align: center;
117
  }
118
  h2,h3,h4 {
119
  text-align: justify;
120
+ font-size: 8px
 
 
 
121
  }
122
  body {
123
+ text-align: justify;
 
124
  }
 
125
  .stSlider .css-1cpxqw2 {
126
  background: #202020;
 
 
 
 
 
127
  }
128
  .stButton > button {
129
  background-color: #202020;
130
+ width: 100%;
131
+ border: none;
 
 
132
  padding: 10px 24px;
133
+ border-radius: 5px;
134
  font-size: 16px;
135
  font-weight: bold;
 
 
 
 
 
 
 
 
 
 
136
  }
137
  .reportview-container .main .block-container {
138
+ padding: 2rem;
139
  background-color: #202020;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  }
 
141
  </style>
142
  """
143
 
144
  # Inject custom CSS with markdown
145
  st.markdown(custom_css, unsafe_allow_html=True)
 
146
  st.sidebar.markdown(
147
+ f'<img src="https://www.somewhere.systems/S2-white-logo.png" style="float: bottom-left; width: 32px; height: 32px; opacity: 1.0; animation: fadein 2s;">',
148
  unsafe_allow_html=True
149
  )
150
+ st.sidebar.title('Spatial Search Engine')
151
 
152
  # Check if data needs to be loaded
153
  if 'data_loaded' not in st.session_state or not st.session_state.data_loaded:
154
  # User input for number of samples
155
+ num_samples = st.sidebar.slider('Select number of samples', 1000, total_samples, 1000)
156
 
157
  if st.sidebar.button('Initialize'):
158
  st.sidebar.text('Initializing data pipeline...')
 
171
  print(f"FAISS index for {column_name} added.")
172
 
173
  return dataset
174
+
175
+
176
 
177
  # Load data and perform t-SNE and clustering
178
  df, embeddings = load_data(num_samples)
 
209
  marker=dict(
210
  size=1,
211
  color=df['cluster'],
212
+ colorscale='Viridis',
213
+ opacity=0.8
214
  )
215
  )])
 
 
 
 
216
 
217
  fig.update_layout(
218
+ plot_bgcolor='#202020',
 
219
  height=800,
220
  margin=dict(l=0, r=0, b=0, t=0),
221
+ scene=dict(
222
+ xaxis=dict(showbackground=True, backgroundcolor="#000000"),
223
+ yaxis=dict(showbackground=True, backgroundcolor="#000000"),
224
+ zaxis=dict(showbackground=True, backgroundcolor="#000000"),
225
+ ),
226
+ scene_camera=dict(eye=dict(x=0.001, y=0.001, z=0.001))
227
  )
228
  st.session_state.fig = fig
229
 
 
236
  if 'df' in st.session_state:
237
  # Sidebar for querying
238
  with st.sidebar:
239
+ st.sidebar.markdown("### Query Embeddings")
240
+ query = st.text_input("Enter your query:")
 
 
 
 
 
 
 
 
 
 
 
241
  if st.button("Search"):
242
  # Define the model
243
  print("Initializing model...")
 
248
 
249
  query_embedding = model.encode([query])
250
  # Retrieve examples by title similarity (or abstract, depending on your preference)
251
+ scores_title, retrieved_examples_title = st.session_state.dataclysm_title_indexed.get_nearest_examples('title_embedding', query_embedding, k=10)
252
  df_query = pd.DataFrame(retrieved_examples_title)
253
  df_query['proximity'] = scores_title
254
  df_query = df_query.sort_values(by='proximity', ascending=True)
 
257
  # Fix the <a href link> to display properly
258
  df_query['URL'] = df_query['id'].apply(lambda x: f'<a href="https://arxiv.org/abs/{x}" target="_blank">Link</a>')
259
  st.sidebar.markdown(df_query[['title', 'proximity', 'id']].to_html(escape=False), unsafe_allow_html=True)
260
+ st.sidebar.markdown("# Detailed View")
261
+ selected_index = st.sidebar.selectbox("Select Key", st.session_state.df.id)
 
 
 
262
 
263
+ # Display metadata for the selected article
264
+ selected_row = st.session_state.df[st.session_state.df['id'] == selected_index].iloc[0]
265
+ st.markdown(f"### Title\n{selected_row['title']}", unsafe_allow_html=True)
266
+ st.markdown(f"### Abstract\n{selected_row['abstract']}", unsafe_allow_html=True)
267
+ st.markdown(f"[Read the full paper](https://arxiv.org/abs/{selected_row['id']})", unsafe_allow_html=True)
268
+ st.markdown(f"[Download PDF](https://arxiv.org/pdf/{selected_row['id']})", unsafe_allow_html=True)
269
+
270
+
271
 
272
  if __name__ == "__main__":
273
+ main()
274
+
275
+