import os
import re

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st

st.set_page_config(layout="wide")

DATA_FILE = "data/gwf_2017-2021_specter2_base.json"
# Colour scale handed to Plotly for each option of the "Color" selectbox.
THEMES = {"cluster": "fall", "year": "mint", "source": "phase"}
# Marker opacity for points outside / inside the current filter selection.
DIMMED_OPACITY = 0.04
HIGHLIGHT_OPACITY = 1.0

# NOTE(review): this call renders only whitespace -- presumably leftover custom
# styling; confirm whether it is still needed.
st.markdown(
    """ """,
    unsafe_allow_html=True,
)
st.sidebar.write(
    """

gwf-spectrum

""",
    unsafe_allow_html=True,
)
st.sidebar.write(
    """

An interactive t-SNE visualization of SPECTER2 embeddings featuring over 1K papers (titles and abstracts) from the Global Water Futures program. For more details, check out our README and our step-by-step guide here.

""",
    unsafe_allow_html=True,
)
st.sidebar.markdown(
    "Happy exploring! :rocket::rocket:"
)


def to_string_authors(list_of_authors):
    """Join author names into a short, human-readable byline.

    More than six names are cut to the first six followed by ", et al.";
    three to six names are comma-separated with a final ", and"; one or two
    names are joined with " and"; an empty list yields "".
    """
    n_authors = len(list_of_authors)
    if n_authors > 6:
        return ", ".join(list_of_authors[:6]) + ", et al."
    if n_authors > 2:
        return ", ".join(list_of_authors[:-1]) + ", and " + list_of_authors[-1]
    return " and ".join(list_of_authors)


def _to_first_last(name):
    """Turn "Last, First" into "First Last"; names without a comma are returned as-is."""
    if "," not in name:
        return name
    last, _, first = name.partition(",")
    return first.strip() + " " + last.strip()


def load_df(data_file: os.PathLike):
    """Load the paper records and derive the columns the app plots and filters on.

    Reads a JSON array of records, splits ``point2d`` into ``x``/``y``,
    normalises ``year`` to an int (0 when blank, missing or non-numeric),
    builds an ``authors_trimmed`` byline from ``authors``, and exposes
    ``publication_type`` (when present) under the name ``type``.
    """
    df = pd.read_json(data_file, orient="records")
    df["x"] = df["point2d"].apply(lambda point: point[0])
    df["y"] = df["point2d"].apply(lambda point: point[1])
    # Blank / missing / non-numeric years become 0 ("unknown") instead of
    # making astype(int) raise.
    df["year"] = pd.to_numeric(df["year"], errors="coerce").fillna(0).astype(int)
    df["authors_trimmed"] = df["authors"].apply(
        lambda row: to_string_authors([_to_first_last(name) for name in row])
    )
    dropped = ["point2d"]
    if "publication_type" in df.columns:
        df["type"] = df["publication_type"]
        dropped.append("publication_type")
    return df.drop(columns=dropped)


@st.cache_data
def load_dataframe():
    """Load ``DATA_FILE`` through Streamlit's data cache."""
    return load_df(DATA_FILE)


DF = load_dataframe()
DF["opacity"] = DIMMED_OPACITY

# Slider bounds come only from papers with a known year (0 marks "unknown").
known_years = DF.loc[DF["year"] > 0, "year"]
min_year, max_year = int(known_years.min()), int(known_years.max())

with st.sidebar:
    start_year, end_year = st.select_slider(
        "Publication year",
        options=[str(y) for y in range(min_year, max_year + 1)],
        value=(str(min_year), str(max_year)),
    )
    src = st.text_input("Source")
    author_names = st.text_input("Author names (separated by comma)")
    title = st.text_input("Title")

start_year = int(start_year)
end_year = int(end_year)

# Inclusive on both ends; papers with an unknown year (0) are never inside it.
df_mask = DF["year"].between(start_year, end_year)
if src:
    df_mask &= DF["source"].str.contains(src, case=False, regex=False, na=False)
if author_names:
    # Each comma-separated query must match (case-insensitively, as a
    # substring) at least one author of the paper. Queries are user input, so
    # they are escaped instead of being interpreted as regex syntax.
    author_patterns = [
        re.compile(re.escape(query.strip()), re.IGNORECASE)
        for query in author_names.split(",")
        if query.strip()
    ]
    author_mask = DF["authors"].apply(
        lambda row: all(
            any(pattern.search(name) for name in row) for pattern in author_patterns
        )
    )
    df_mask &= author_mask
if title:
    df_mask &= DF["title"].str.contains(title, case=False, regex=False, na=False)

DF.loc[df_mask, "opacity"] = HIGHLIGHT_OPACITY
st.write(f"Number of points: {int(df_mask.sum())}")
color = st.selectbox("Color", ("cluster", "year", "source"))

fig = px.scatter(
    DF,
    x="x",
    y="y",
    color=color,
    width=1000,
    height=800,
    # "opacity" rides along as customdata[5] so each trace can read the
    # opacities of its own points below; the hover template ignores it.
    custom_data=("title", "authors_trimmed", "year", "source", "keywords", "opacity"),
    color_continuous_scale=THEMES[color],
)
# With a discrete colour column (e.g. "source") px builds one trace per
# category, so a single DataFrame-length opacity array would not line up with
# each trace's points. Take each point's opacity from its own customdata row.
for trace in fig.data:
    trace.marker.opacity = [row[5] for row in trace.customdata]

# Plotly hover labels break lines with <br>, not newline characters.
fig.update_traces(
    hovertemplate=(
        "%{customdata[0]}<br>"
        "%{customdata[1]}<br>"
        "%{customdata[2]}<br>"
        "%{customdata[3]}<br>"
        "Keywords: %{customdata[4]}"
    )
)
fig.update_layout(
    showlegend=False,
    font=dict(
        family="Times New Roman",
        size=30,
    ),
    hoverlabel=dict(
        align="left",
        font_size=14,
        font_family="Rockwell",
        namelength=-1,
    ),
)
fig.update_xaxes(title="")
fig.update_yaxes(title="")
st.plotly_chart(fig, use_container_width=True)