import inspect
import logging
import time

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from datasets import load_dataset, Dataset
from FlagEmbedding import FlagModel
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE


# Hugging Face dataset that backs the whole app.
dataset_name = 'somewheresystems/dataclysm-arxiv'

# Streamlit re-executes this script on every interaction, so the train split is
# kept in the session and only loaded when the session does not hold it yet.
if 'dataclysm_arxiv' not in st.session_state:
    st.session_state.dataclysm_arxiv = load_dataset(dataset_name, split="train")
total_samples = len(st.session_state.dataclysm_arxiv)

logging.basicConfig(filename='app.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s', level=logging.INFO)

# Same reasoning as for the dataset: build the embedding model once per session
# instead of once per rerun, and expose it under the module-level name `model`.
if 'embedding_model' not in st.session_state:
    st.session_state.embedding_model = FlagModel('BAAI/bge-small-en-v1.5', query_instruction_for_retrieval="Represent this sentence for searching relevant passages:", use_fp16=True)
model = st.session_state.embedding_model


def load_data(num_samples):
    """Load the arXiv dataset and draw a random sample of its rows.

    The train split is taken from ``st.session_state.dataclysm_arxiv`` when the
    session already holds it; otherwise it is fetched with ``load_dataset``.
    Sample size and elapsed time are written to the Streamlit sidebar.

    Args:
        num_samples: Requested number of rows; capped at the dataset size.

    Returns:
        A ``(df, embeddings)`` tuple: the sampled pandas DataFrame and an
        object-dtype NumPy array built from its ``title_embedding`` column.
    """
    start_time = time.time()
    dataset_name = 'somewheresystems/dataclysm-arxiv'

    logging.info('Loading dataset...')
    # Avoid a second load when the split is already cached in the session.
    if 'dataclysm_arxiv' in st.session_state:
        train_split = st.session_state.dataclysm_arxiv
    else:
        train_split = load_dataset(dataset_name, split="train")
    total_samples = len(train_split)

    logging.info('Converting to pandas dataframe...')
    df = train_split.to_pandas()

    num_samples = min(num_samples, total_samples)
    # Guard the percentage against an empty dataset.
    share = num_samples / total_samples if total_samples else 0.0
    st.sidebar.text(f'Number of samples: {num_samples} ({share:.2%} of total)')

    # Random sample; the original row labels are kept as the index.
    df = df.sample(n=num_samples)

    embeddings = df['title_embedding'].tolist()
    print("embeddings length: " + str(len(embeddings)))

    embeddings = np.array(embeddings, dtype=object)
    end_time = time.time()
    st.sidebar.text(f'Data loading completed in {end_time - start_time:.3f} seconds')
    return df, embeddings


def perform_tsne(embeddings):
    """Project the embeddings into three dimensions with t-SNE.

    Args:
        embeddings: Sized collection of equal-length embedding vectors that
            exposes ``tolist()`` (for example the object array returned by
            ``load_data``).

    Returns:
        The array returned by ``TSNE.fit_transform`` with ``n_components=3``.

    Raises:
        ValueError: If ``embeddings`` is empty or the vectors differ in length.
    """
    start_time = time.time()
    logging.info('Performing t-SNE...')

    # Validate before touching the UI or building the estimator.
    n_samples = len(embeddings)
    if n_samples == 0:
        raise ValueError("At least one embedding is required for t-SNE")
    if len({len(embed) for embed in embeddings}) > 1:
        raise ValueError("All embeddings should have the same length")

    perplexity = min(30, n_samples - 1) if n_samples > 1 else 1

    # Newer scikit-learn releases renamed TSNE's `n_iter` to `max_iter`; pick
    # whichever keyword the installed version accepts.
    tsne_params = inspect.signature(TSNE).parameters
    iter_keyword = 'max_iter' if 'max_iter' in tsne_params else 'n_iter'
    tsne = TSNE(n_components=3, perplexity=perplexity, **{iter_keyword: 300})

    progress_text = st.empty()
    progress_text.text("t-SNE in progress...")

    tsne_results = tsne.fit_transform(np.vstack(embeddings.tolist()))

    progress_text.text(f"t-SNE completed. Processed {n_samples} samples with perplexity {perplexity}.")
    end_time = time.time()
    st.sidebar.text(f't-SNE completed in {end_time - start_time:.3f} seconds')
    return tsne_results


def perform_clustering(df, tsne_results):
    """Attach the t-SNE coordinates to ``df`` and label each row with k-means.

    ``df`` is modified in place: columns ``tsne-3d-one``, ``tsne-3d-two``,
    ``tsne-3d-three`` and ``cluster`` are added. Elapsed time is written to the
    Streamlit sidebar.

    Args:
        df: DataFrame with one row per t-SNE result row.
        tsne_results: 2-D array whose first three columns are the coordinates.

    Returns:
        The same ``df`` object, with the new columns.
    """
    start_time = time.time()
    logging.info('Performing k-means clustering...')

    coordinate_columns = ['tsne-3d-one', 'tsne-3d-two', 'tsne-3d-three']
    for axis, column in enumerate(coordinate_columns):
        df[column] = tsne_results[:, axis]

    # k-means cannot form more clusters than there are rows, so cap the
    # default of 16 for small samples.
    n_clusters = min(16, len(df))
    kmeans = KMeans(n_clusters=n_clusters)
    df['cluster'] = kmeans.fit_predict(df[coordinate_columns])

    end_time = time.time()
    st.sidebar.text(f'k-means clustering completed in {end_time - start_time:.3f} seconds')
    return df


# Stylesheet injected through st.markdown; `<style>` starts at column 0 so the
# whole element is treated as one raw HTML block.
_CUSTOM_CSS = """
<style>
    /* Define the font */
    @font-face {
        font-family: 'F';
        src: url('https://fonts.googleapis.com/css2?family=Martian+Mono&display=swap') format('truetype');
    }
    /* Apply the font to all elements */
    * {
        font-family: 'F', sans-serif !important;
        color: #F8F8F8; /* Set the font color to F8F8F8 */
    }
    /* Add your CSS styles here */
    h1 {
        text-align: center;
    }
    h2,h3,h4 {
        text-align: justify;
        font-size: 8px
    }
    body {
        text-align: justify;
    }
    .stSlider .css-1cpxqw2 {
        background: #202020;
    }
    .stButton > button {
        background-color: #202020;
        width: 100%;
        border: none;
        padding: 10px 24px;
        border-radius: 5px;
        font-size: 16px;
        font-weight: bold;
    }
    .reportview-container .main .block-container {
        padding: 2rem;
        background-color: #202020;
    }
</style>
"""


def _reshape_and_add_faiss_index(dataset, column_name):
    """Flatten every vector in ``column_name`` and return a FAISS-indexed Dataset."""
    print(f"Flattening {column_name} and adding FAISS index...")
    dataset[column_name] = dataset[column_name].apply(lambda x: np.array(x).flatten())
    dataset = Dataset.from_pandas(dataset).add_faiss_index(column=column_name)
    print(f"FAISS index for {column_name} added.")
    return dataset


def _build_figure(df):
    """Build the 3-D scatter plot of the t-SNE coordinates, coloured by cluster."""
    fig = go.Figure(data=[go.Scatter3d(
        x=df['tsne-3d-one'],
        y=df['tsne-3d-two'],
        z=df['tsne-3d-three'],
        mode='markers',
        hovertext=df['hovertext'],
        hoverinfo='text',
        marker=dict(
            size=1,
            color=df['cluster'],
            colorscale='Viridis',
            opacity=0.8,
        ),
    )])
    fig.update_layout(
        plot_bgcolor='#202020',
        height=800,
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(
            xaxis=dict(showbackground=True, backgroundcolor="#000000"),
            yaxis=dict(showbackground=True, backgroundcolor="#000000"),
            zaxis=dict(showbackground=True, backgroundcolor="#000000"),
        ),
        scene_camera=dict(eye=dict(x=0.001, y=0.001, z=0.001)),
    )
    return fig


def _initialize_pipeline(num_samples):
    """Load, index, project and cluster the data, then store results in the session."""
    df, embeddings = load_data(num_samples)

    embeddings_list = [embedding.flatten().tolist() for embedding in embeddings]
    df['title_embedding'] = embeddings_list
    print(df.head())

    st.session_state.dataclysm_title_indexed = _reshape_and_add_faiss_index(df, 'title_embedding')
    tsne_results = perform_tsne(embeddings)
    df = perform_clustering(df, tsne_results)

    # `row.name` is the row's index label in the sampled DataFrame.
    df['hovertext'] = df.apply(
        lambda row: f"<b>Title:</b> {row['title']}<br><b>arXiv ID:</b> {row['id']}<br><b>Key:</b> {row.name}", axis=1
    )
    st.sidebar.text("Datasets loaded, titles indexed.")

    st.session_state.df = df
    st.session_state.tsne_results = tsne_results
    st.session_state.fig = _build_figure(df)
    # Set last so a failure above never leaves the session marked as loaded
    # while the figure is still missing.
    st.session_state.data_loaded = True


def _render_search_results(query):
    """Embed ``query`` and list its ten nearest titles in the sidebar."""
    # Reuses the module-level `model` rather than constructing a new
    # FlagModel on every search.
    query_embedding = model.encode([query])

    scores_title, retrieved_examples_title = st.session_state.dataclysm_title_indexed.get_nearest_examples('title_embedding', query_embedding, k=10)
    df_query = pd.DataFrame(retrieved_examples_title)
    df_query['proximity'] = scores_title
    df_query = df_query.sort_values(by='proximity', ascending=True)
    df_query['proximity'] = df_query['proximity'].round(3)

    df_query['URL'] = df_query['id'].apply(lambda x: f'<a href="https://arxiv.org/abs/{x}" target="_blank">Link</a>')
    # The URL column is included so the generated links are actually shown;
    # escape=False keeps the anchor tags as HTML.
    st.sidebar.markdown(df_query[['title', 'proximity', 'id', 'URL']].to_html(escape=False), unsafe_allow_html=True)


def _render_detail_view():
    """Show a selector over the loaded papers and the chosen paper's details."""
    st.sidebar.markdown("# Detailed View")
    selected_index = st.sidebar.selectbox("Select Key", st.session_state.df.id)

    selected_row = st.session_state.df[st.session_state.df['id'] == selected_index].iloc[0]
    st.markdown(f"### Title\n{selected_row['title']}", unsafe_allow_html=True)
    st.markdown(f"### Abstract\n{selected_row['abstract']}", unsafe_allow_html=True)
    st.markdown(f"[Read the full paper](https://arxiv.org/abs/{selected_row['id']})", unsafe_allow_html=True)
    st.markdown(f"[Download PDF](https://arxiv.org/pdf/{selected_row['id']})", unsafe_allow_html=True)


def main():
    """Render the Streamlit page: styling, data initialisation, plot and search.

    Until the session is marked as loaded, the sidebar offers a sample-size
    slider and an "Initialize" button that runs the data pipeline. Once loaded,
    the stored figure is drawn and the sidebar offers a query box plus a
    detailed view of a selected paper.
    """
    st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
    st.sidebar.markdown(
        '<img src="https://www.somewhere.systems/S2-white-logo.png" style="float: bottom-left; width: 32px; height: 32px; opacity: 1.0; animation: fadein 2s;">',
        unsafe_allow_html=True
    )
    st.sidebar.title('Spatial Search Engine')

    if 'data_loaded' not in st.session_state or not st.session_state.data_loaded:
        num_samples = st.sidebar.slider('Select number of samples', 1000, total_samples, 1000)
        if st.sidebar.button('Initialize'):
            st.sidebar.text('Initializing data pipeline...')
            _initialize_pipeline(num_samples)

    if 'data_loaded' in st.session_state and st.session_state.data_loaded:
        st.plotly_chart(st.session_state.fig, use_container_width=True)

    if 'df' in st.session_state:
        with st.sidebar:
            st.sidebar.markdown("### Query Embeddings")
            query = st.text_input("Enter your query:")
            if st.button("Search"):
                _render_search_results(query)
            # NOTE(review): the source lost its indentation, so the intended
            # nesting of the detailed view is not certain. It is kept outside
            # the Search branch so it stays visible on reruns triggered by the
            # selector itself; confirm whether it should render in the sidebar.
            _render_detail_view()


# Run the app when the file is executed as a script.
if __name__ == "__main__":
    main()