Added medRxiv as well
- combine_databases.py  +136 -0
- streamlit_app.py  +67 -28
- update_database_medarxiv.py  +289 -0
combine_databases.py
ADDED
@@ -0,0 +1,136 @@
import pandas as pd
import numpy as np
from pathlib import Path

def combine_databases():
    # Define paths
    aggregated_data_path = Path("aggregated_data")
    db_update_bio_path = Path("db_update")
    biorxiv_embeddings_path = Path("biorxiv_ubin_embaddings.npy")
    embed_update_bio_path = Path("embed_update")

    db_update_med_path = Path("db_update_med")
    embed_update_med_path = Path("embed_update_med")

    # Load existing database and embeddings for BioRxiv
    df_bio_existing = pd.read_parquet(aggregated_data_path)
    bio_embeddings_existing = np.load(biorxiv_embeddings_path, allow_pickle=True)
    print(f"Existing BioRxiv data shape: {df_bio_existing.shape}, Existing BioRxiv embeddings shape: {bio_embeddings_existing.shape}")

    # Determine the embedding size from existing embeddings
    embedding_size = bio_embeddings_existing.shape[1]

    # Prepare lists to collect new updates
    bio_dfs_list = []
    bio_embeddings_list = []

    # Helper function to process updates from a specified directory
    def process_updates(new_data_directory, updated_embeddings_directory, dfs_list, embeddings_list):
        new_data_files = sorted(Path(new_data_directory).glob("*.parquet"))
        for data_file in new_data_files:
            corresponding_embedding_file = Path(updated_embeddings_directory) / (data_file.stem + ".npy")

            if corresponding_embedding_file.exists():
                df = pd.read_parquet(data_file)
                new_embeddings = np.load(corresponding_embedding_file, allow_pickle=True)

                # Check if the number of rows in the DataFrame matches the number of rows in the embeddings
                if df.shape[0] != new_embeddings.shape[0]:
                    print(f"Shape mismatch for {data_file.name}: DataFrame has {df.shape[0]} rows, embeddings have {new_embeddings.shape[0]} rows. Skipping.")
                    continue

                # Check embedding size and adjust if necessary
                if new_embeddings.shape[1] != embedding_size:
                    print(f"Skipping {data_file.name} due to embedding size mismatch.")
                    continue

                dfs_list.append(df)
                embeddings_list.append(new_embeddings)
            else:
                print(f"No corresponding embedding file found for {data_file.name}")

    # Process updates from both BioRxiv and MedRxiv
    process_updates(db_update_bio_path, embed_update_bio_path, bio_dfs_list, bio_embeddings_list)

    # Concatenate all BioRxiv updates
    if bio_dfs_list:
        df_bio_updates = pd.concat(bio_dfs_list)
    else:
        df_bio_updates = pd.DataFrame()

    if bio_embeddings_list:
        bio_embeddings_updates = np.vstack(bio_embeddings_list)
    else:
        bio_embeddings_updates = np.array([])

    # Append new BioRxiv data to existing, handling duplicates as needed
    df_bio_combined = pd.concat([df_bio_existing, df_bio_updates])

    # Create a mask for filtering unique titles
    bio_mask = ~df_bio_combined.duplicated(subset=["title"], keep="last")
    df_bio_combined = df_bio_combined[bio_mask]

    # Combine BioRxiv embeddings, ensuring alignment with the DataFrame
    bio_embeddings_combined = (
        np.vstack([bio_embeddings_existing, bio_embeddings_updates])
        if bio_embeddings_updates.size
        else bio_embeddings_existing
    )

    # Filter the embeddings based on the DataFrame unique entries
    bio_embeddings_combined = bio_embeddings_combined[bio_mask]

    assert df_bio_combined.shape[0] == bio_embeddings_combined.shape[0], "Shape mismatch between BioRxiv DataFrame and embeddings"

    print(f"Filtered BioRxiv DataFrame shape: {df_bio_combined.shape}")
    print(f"Filtered BioRxiv embeddings shape: {bio_embeddings_combined.shape}")

    # Save combined BioRxiv DataFrame and embeddings
    combined_biorxiv_data_path = aggregated_data_path / "combined_biorxiv_data.parquet"
    df_bio_combined.to_parquet(combined_biorxiv_data_path)
    print(f"Saved combined BioRxiv DataFrame to {combined_biorxiv_data_path}")

    combined_biorxiv_embeddings_path = "biorxiv_ubin_embaddings.npy"
    np.save(combined_biorxiv_embeddings_path, bio_embeddings_combined)
    print(f"Saved combined BioRxiv embeddings to {combined_biorxiv_embeddings_path}")

    # Prepare lists to collect new MedRxiv updates
    med_dfs_list = []
    med_embeddings_list = []

    process_updates(db_update_med_path, embed_update_med_path, med_dfs_list, med_embeddings_list)

    # Concatenate all MedRxiv updates
    if med_dfs_list:
        df_med_combined = pd.concat(med_dfs_list)
    else:
        df_med_combined = pd.DataFrame()

    if med_embeddings_list:
        med_embeddings_combined = np.vstack(med_embeddings_list)
    else:
        med_embeddings_combined = np.array([])

    last_date_in_med_database = df_med_combined['date'].max() if not df_med_combined.empty else "unknown"

    # Create a mask for filtering unique titles
    med_mask = ~df_med_combined.duplicated(subset=["title"], keep="last")
    df_med_combined = df_med_combined[med_mask]
    med_embeddings_combined = med_embeddings_combined[med_mask]

    assert df_med_combined.shape[0] == med_embeddings_combined.shape[0], "Shape mismatch between MedRxiv DataFrame and embeddings"

    print(f"Filtered MedRxiv DataFrame shape: {df_med_combined.shape}")
    print(f"Filtered MedRxiv embeddings shape: {med_embeddings_combined.shape}")

    # Save combined MedRxiv DataFrame and embeddings
    combined_medrxiv_data_path = db_update_med_path / f"database_{last_date_in_med_database}.parquet"
    df_med_combined.to_parquet(combined_medrxiv_data_path)
    print(f"Saved combined MedRxiv DataFrame to {combined_medrxiv_data_path}")

    combined_medrxiv_embeddings_path = embed_update_med_path / f"database_{last_date_in_med_database}.npy"
    np.save(combined_medrxiv_embeddings_path, med_embeddings_combined)
    print(f"Saved combined MedRxiv embeddings to {combined_medrxiv_embeddings_path}")

if __name__ == "__main__":
    combine_databases()
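The invariant combine_databases() relies on is that the boolean deduplication mask computed on the DataFrame is reused, unchanged, to filter the embedding matrix, so every surviving row still owns its vector. A minimal sketch of that pattern, with made-up titles and random vectors standing in for the real parquet files and ubinary embeddings:

import numpy as np
import pandas as pd

# Toy stand-ins for the real data: three rows, two of which share a title.
df = pd.DataFrame({"title": ["A", "B", "A"], "server": ["biorxiv", "medrxiv", "biorxiv"]})
emb = np.random.rand(3, 4)  # one embedding row per DataFrame row

# Same pattern as combine_databases(): keep the last occurrence of each
# title and apply the identical boolean mask to the embedding matrix.
mask = ~df.duplicated(subset=["title"], keep="last")
df_unique = df[mask]
emb_unique = emb[mask.to_numpy()]  # explicit ndarray mask for the numpy side

assert df_unique.shape[0] == emb_unique.shape[0]
print(df_unique["title"].tolist(), emb_unique.shape)  # ['B', 'A'] (2, 4)

If the DataFrame and the stacked embeddings were ever built in different orders the row counts would diverge, which is exactly the condition the script's own asserts guard against.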
streamlit_app.py
CHANGED
@@ -263,32 +263,57 @@ def download_data_from_dropbox():
 @st.cache_resource(ttl="1d")
 def load_data_embeddings():
     existing_data_path = "aggregated_data"
+    new_data_directory_bio = "db_update"
     existing_embeddings_path = "biorxiv_ubin_embaddings.npy"
+    updated_embeddings_directory_bio = "embed_update"

+    new_data_directory_med = "db_update_med"
+    updated_embeddings_directory_med = "embed_update_med"
+
     # Load existing database and embeddings
     df_existing = pd.read_parquet(existing_data_path)
     embeddings_existing = np.load(existing_embeddings_path, allow_pickle=True)

+    print(f"Existing data shape: {df_existing.shape}, Existing embeddings shape: {embeddings_existing.shape}")
+
+    # Determine the embedding size from existing embeddings
+    embedding_size = embeddings_existing.shape[1]
+
     # Prepare lists to collect new updates
     df_updates_list = []
     embeddings_updates_list = []

-            embeddings_updates_list.append(np.load(corresponding_embedding_file))
-        else:
-            print(f"No corresponding embedding file found for {data_file.name}")
+    # Helper function to process updates from a specified directory
+    def process_updates(new_data_directory, updated_embeddings_directory):
+        new_data_files = sorted(Path(new_data_directory).glob("*.parquet"))
+        print(new_data_files)
+        for data_file in new_data_files:
+            corresponding_embedding_file = Path(updated_embeddings_directory) / (
+                data_file.stem + ".npy"
+            )
+
+            if corresponding_embedding_file.exists():
+                df = pd.read_parquet(data_file)
+                new_embeddings = np.load(corresponding_embedding_file, allow_pickle=True)
+
+                # Check if the number of rows in the DataFrame matches the number of rows in the embeddings
+                if df.shape[0] != new_embeddings.shape[0]:
+                    print(f"Shape mismatch for {data_file.name}: DataFrame has {df.shape[0]} rows, embeddings have {new_embeddings.shape[0]} rows. Skipping.")
+                    continue
+
+                # Check embedding size and adjust if necessary
+                if new_embeddings.shape[1] != embedding_size:
+                    print(f"Skipping {data_file.name} due to embedding size mismatch.")
+                    continue
+
+                df_updates_list.append(df)
+                embeddings_updates_list.append(new_embeddings)
+            else:
+                print(f"No corresponding embedding file found for {data_file.name}")

+    # Process updates from both BioRxiv and MedArXiv
+    process_updates(new_data_directory_bio, updated_embeddings_directory_bio)
+    process_updates(new_data_directory_med, updated_embeddings_directory_med)

     # Concatenate all updates
     if df_updates_list:

@@ -304,7 +329,7 @@ def load_data_embeddings():
     # Append new data to existing, handling duplicates as needed
     df_combined = pd.concat([df_existing, df_updates])

+    # Create a mask for filtering
     mask = ~df_combined.duplicated(subset=["title"], keep="last")
     df_combined = df_combined[mask]

@@ -315,20 +340,21 @@ def load_data_embeddings():
         else embeddings_existing
     )

+    # Filter the embeddings based on the dataframe unique entries
     embeddings_combined = embeddings_combined[mask]

     return df_combined, embeddings_combined

+
 LLM_prompt = "Review the abstracts listed below and create a list and summary that captures their main themes and findings. Identify any commonalities across the abstracts and highlight these in your summary. Ensure your response is concise, avoids external links, and is formatted in markdown.\n\n"

-def summarize_abstract(abstract, llm_model="
+def summarize_abstract(abstract, llm_model="llama-3.1-70b-versatile", instructions=LLM_prompt, api_key=st.secrets["groq_token"]):
     """
     Summarizes the provided abstract using a specified LLM model.

     Parameters:
     - abstract (str): The abstract text to be summarized.
-    - llm_model (str): The LLM model used for summarization. Defaults to "
+    - llm_model (str): The LLM model used for summarization. Defaults to "llama-3.1-70b-versatile".

     Returns:
     - str: A summary of the abstract, condensed into one to two sentences.

@@ -390,16 +416,25 @@ def define_style():
     )


-def logo(db_update_date,
+def logo(db_update_date, db_size_bio, db_size_med):
     # Initialize Streamlit app
+    biorxiv_logo = "https://www.biorxiv.org/sites/default/files/biorxiv_logo_homepage.png"
+    medarxiv_logo = "https://www.medrxiv.org/sites/default/files/medRxiv_homepage_logo.png"
     st.markdown(
         f"""
+        <div style='display: flex; justify-content: center; align-items: center;'>
+            <div style='margin-right: 20px;'>
+                <img src='{biorxiv_logo}' alt='BioRxiv logo' style='max-height: 100px;'>
+            </div>
+            <div style='margin-left: 20px;'>
+                <img src='{medarxiv_logo}' alt='medRxiv logo' style='max-height: 100px;'>
+            </div>
         </div>
+        <div style='text-align: center; margin-top: 10px;'>
+            <h3 style='color: black;'>Manuscript Semantic Search [bMSS]</h3>
+            Last database update: {db_update_date}; Database size: bioRxiv: {db_size_bio} / medRxiv: {db_size_med} entries
+        </div>
+        <br>
         """,
         unsafe_allow_html=True,
     )

@@ -413,7 +448,7 @@ download_data_from_dropbox()
 define_style()

 df, embeddings_unique = load_data_embeddings()
-logo(df["date"].max(), df.shape[0])
+logo(df["date"].max(), df[df['server']=='biorxiv'].shape[0], df[df['server']=='medrxiv'].shape[0])

 # model = model_to_device()

@@ -474,7 +509,7 @@ if query:
     )

     # Prepare the results for plotting
-    plot_data = {"Date": [], "Title": [], "Score": [], "DOI": [], "category": []}
+    plot_data = {"Date": [], "Title": [], "Score": [], "DOI": [], "category": [], "server": []}

     search_df = pd.DataFrame(results[0])

@@ -503,11 +538,14 @@ if query:
         plot_data["Score"].append(search_df["score"][index]) # type: ignore
         plot_data["DOI"].append(row["doi"])
         plot_data["category"].append(row["category"])
+        plot_data["server"].append(row["server"])

         #summary_text = summarize_abstract(row['abstract'])

         with st.expander(f"{index+1}\. {row['title']}"): # type: ignore
+            col1, col2 = st.columns(2)
+            col1.markdown(f"**Score:** {entry['score']:.1f}")
+            col2.markdown(f"**Server:** [{row['server']}]")
             st.markdown(f"**Authors:** {row['authors']}")
             col1, col2 = st.columns(2)
             col2.markdown(f"**Category:** {row['category']}")

@@ -556,6 +594,7 @@ if query:
         x="Date",
         y="Score",
         hover_data=["Title", "DOI"],
+        color='server',
         title="Publication Times and Scores",
     )
     fig.update_traces(marker=dict(size=10))
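The new server field carried through plot_data is what lets the results scatter separate bioRxiv and medRxiv points via color='server'. Below is a minimal, self-contained Plotly Express sketch of that call; the three rows are invented, and the app itself would hand the figure to Streamlit (for example via st.plotly_chart) rather than call fig.show().

import pandas as pd
import plotly.express as px

# Invented rows with the same columns the app collects in plot_data.
plot_df = pd.DataFrame(
    {
        "Date": ["2024-05-01", "2024-05-03", "2024-05-07"],
        "Title": ["Paper A", "Paper B", "Paper C"],
        "Score": [0.83, 0.79, 0.74],
        "DOI": ["10.1101/aaa", "10.1101/bbb", "10.1101/ccc"],
        "server": ["biorxiv", "medrxiv", "biorxiv"],
    }
)

# Same call shape as in streamlit_app.py: one colour per preprint server.
fig = px.scatter(
    plot_df,
    x="Date",
    y="Score",
    hover_data=["Title", "DOI"],
    color="server",
    title="Publication Times and Scores",
)
fig.update_traces(marker=dict(size=10))
fig.show()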
update_database_medarxiv.py
ADDED
@@ -0,0 +1,289 @@
import pandas as pd
import numpy as np
from pathlib import Path
import datetime
import requests
import json
import os
from datetime import datetime
from dateutil.relativedelta import relativedelta
from concurrent.futures import ThreadPoolExecutor, as_completed
from sentence_transformers import SentenceTransformer
import torch
import shutil
import dropbox
import streamlit as st
import time

def retry_on_exception(exception, retries=5, delay=2):
    def decorator(func):
        def wrapper(*args, **kwargs):
            last_exception = None
            for _ in range(retries):
                try:
                    return func(*args, **kwargs)
                except exception as e:
                    last_exception = e
                    print(f"Retrying due to: {str(e)}")
                    time.sleep(delay)
            raise last_exception
        return wrapper
    return decorator

@retry_on_exception(requests.exceptions.ConnectionError)
def fetch_and_save_data_block(endpoint, server, block_start, block_end, save_directory, format='json'):
    base_url = f"https://api.medrxiv.org/details/{server}/"
    block_interval = f"{block_start.strftime('%Y-%m-%d')}/{block_end.strftime('%Y-%m-%d')}"
    block_data = []
    cursor = 0
    continue_fetching = True

    while continue_fetching:
        url = f"{base_url}{block_interval}/{cursor}/{format}"
        response = requests.get(url)

        if response.status_code != 200:
            print(f"Failed to fetch data for block {block_interval} at cursor {cursor}. HTTP Status: {response.status_code}")
            break

        data = response.json()
        fetched_papers = len(data['collection'])

        if fetched_papers > 0:
            block_data.extend(data['collection'])
            cursor += fetched_papers
            print(f"Fetched {fetched_papers} papers for block {block_interval}. Total fetched: {cursor}.")
        else:
            continue_fetching = False

    if block_data:
        save_data_block(block_data, block_start, block_end, endpoint, save_directory)

def save_data_block(block_data, start_date, end_date, endpoint, save_directory):
    start_yymmdd = start_date.strftime("%y%m%d")
    end_yymmdd = end_date.strftime("%y%m%d")
    filename = f"{save_directory}/{endpoint}_data_{start_yymmdd}_{end_yymmdd}.json"

    with open(filename, 'w') as file:
        json.dump(block_data, file, indent=4)

    print(f"Saved data block to {filename}")

def fetch_data(endpoint, server, interval, save_directory, format='json'):
    os.makedirs(save_directory, exist_ok=True)
    start_date, end_date = [datetime.strptime(date, "%Y-%m-%d") for date in interval.split('/')]
    current_date = start_date
    tasks = []

    with ThreadPoolExecutor(max_workers=12) as executor:
        while current_date <= end_date:
            block_start = current_date
            block_end = min(current_date + relativedelta(months=1) - relativedelta(days=1), end_date)
            tasks.append(executor.submit(fetch_and_save_data_block, endpoint, server, block_start, block_end, save_directory, format))
            current_date += relativedelta(months=1)

        for future in as_completed(tasks):
            future.result()

def load_json_to_dataframe(json_file):
    with open(json_file, 'r') as file:
        data = json.load(file)
    return pd.DataFrame(data)

def save_dataframe(df, save_path):
    df.to_parquet(save_path)

def process_json_files(directory, save_directory):
    os.makedirs(save_directory, exist_ok=True)

    json_files = list(Path(directory).glob('*.json'))
    print(f'json_files {type(json_files)}: {json_files}')

    for json_file in json_files:
        df = load_json_to_dataframe(json_file)

        parquet_filename = f"{json_file.stem}.parquet"
        save_path = os.path.join(save_directory, parquet_filename)

        if os.path.exists(save_path):
            npy_file_path = save_path.replace('db_update', 'embed_update').replace('parquet', 'npy')
            if os.path.exists(npy_file_path):
                os.remove(npy_file_path)
                print(f'Removed embedding file {npy_file_path} due to the dataframe update')

        save_dataframe(df, save_path)
        print(f"Processed and saved {json_file.name} to {parquet_filename}")

def load_unprocessed_parquets(db_update_directory, embed_update_directory):
    db_update_directory = Path(db_update_directory)
    embed_update_directory = Path(embed_update_directory)

    parquet_files = list(db_update_directory.glob('*.parquet'))
    npy_files = {f.stem for f in embed_update_directory.glob('*.npy')}
    unprocessed_dataframes = []

    for parquet_file in parquet_files:
        if parquet_file.stem not in npy_files:
            unprocessed_dataframes.append(parquet_file)
            print(f"Loaded unprocessed Parquet file: {parquet_file.name}")
        else:
            print(f"Skipping processed Parquet file: {parquet_file.name}")

    return unprocessed_dataframes

def connect_to_dropbox():
    dropbox_APP_KEY = st.secrets["dropbox_APP_KEY"]
    dropbox_APP_SECRET = st.secrets["dropbox_APP_SECRET"]
    dropbox_REFRESH_TOKEN = st.secrets["dropbox_REFRESH_TOKEN"]

    dbx = dropbox.Dropbox(
        app_key=dropbox_APP_KEY,
        app_secret=dropbox_APP_SECRET,
        oauth2_refresh_token=dropbox_REFRESH_TOKEN
    )
    return dbx

def upload_path(local_path, dropbox_path):
    dbx = connect_to_dropbox()
    local_path = Path(local_path)

    if local_path.is_file():
        relative_path = local_path.name
        dropbox_file_path = os.path.join(dropbox_path, relative_path).replace('\\', '/').replace('//', '/')
        upload_file(local_path, dropbox_file_path, dbx)
    elif local_path.is_dir():
        for local_file in local_path.rglob('*'):
            if local_file.is_file():
                relative_path = local_file.relative_to(local_path.parent)
                dropbox_file_path = os.path.join(dropbox_path, relative_path).replace('\\', '/').replace('//', '/')
                upload_file(local_file, dropbox_file_path, dbx)
    else:
        print("The provided path does not exist.")

def upload_file(file_path, dropbox_file_path, dbx):
    try:
        dropbox_file_path = dropbox_file_path.replace('\\', '/')

        try:
            metadata = dbx.files_get_metadata(dropbox_file_path)
            dropbox_mod_time = metadata.server_modified
            local_mod_time = datetime.fromtimestamp(file_path.stat().st_mtime)

            if dropbox_mod_time >= local_mod_time:
                print(f"Skipped {dropbox_file_path}, Dropbox version is up-to-date.")
                return
        except dropbox.exceptions.ApiError as e:
            if not isinstance(e.error, dropbox.files.GetMetadataError) or e.error.is_path() and e.error.get_path().is_not_found():
                print(f"No existing file on Dropbox, proceeding with upload: {dropbox_file_path}")
            else:
                raise e

        with file_path.open('rb') as f:
            dbx.files_upload(f.read(), dropbox_file_path, mode=dropbox.files.WriteMode.overwrite)
            print(f"Uploaded {dropbox_file_path}")
    except Exception as e:
        print(f"Failed to upload {dropbox_file_path}: {str(e)}")

def load_data_embeddings():
    new_data_directory = "db_update_med"
    updated_embeddings_directory = "embed_update_med"
    new_data_files = sorted(Path(new_data_directory).glob("*.parquet"))

    df_updates_list = []
    embeddings_updates_list = []

    for data_file in new_data_files:
        # Assuming naming convention allows direct correlation
        corresponding_embedding_file = Path(updated_embeddings_directory) / (
            data_file.stem + ".npy"
        )

        if corresponding_embedding_file.exists():
            # Load and append DataFrame and embeddings
            df_updates_list.append(pd.read_parquet(data_file))
            embeddings_updates_list.append(np.load(corresponding_embedding_file))
        else:
            print(f"No corresponding embedding file found for {data_file.name}")

    new_data_files = sorted(Path(new_data_directory).glob("*.parquet"))
    for data_file in new_data_files:
        corresponding_embedding_file = Path(updated_embeddings_directory) / (
            data_file.stem + ".npy"
        )

        if corresponding_embedding_file.exists():
            df_updates_list.append(pd.read_parquet(data_file))
            embeddings_updates_list.append(np.load(corresponding_embedding_file))
        else:
            print(f"No corresponding embedding file found for {data_file.name}")

    if df_updates_list:
        df_updates = pd.concat(df_updates_list)
    else:
        df_updates = pd.DataFrame()

    if embeddings_updates_list:
        embeddings_updates = np.vstack(embeddings_updates_list)
    else:
        embeddings_updates = np.array([])

    df_combined = df_updates
    mask = ~df_combined.duplicated(subset=["title"], keep="last")
    df_combined = df_combined[mask]

    embeddings_combined = embeddings_updates

    embeddings_combined = embeddings_combined[mask]

    return df_combined, embeddings_combined

endpoint = "details"
server = "medrxiv"

df, embeddings = load_data_embeddings()

try:
    start_date = df['date'].max()
except:
    start_date = '1990-01-01'
last_date = datetime.today().strftime('%Y-%m-%d')

interval = f'{start_date}/{last_date}'
print(f'using interval: {interval}')

save_directory = "db_update_json_med"
fetch_data(endpoint, server, interval, save_directory)

directory = r'db_update_json_med'
save_directory = r'db_update_med'
process_json_files(directory, save_directory)

db_update_directory = 'db_update_med'
embed_update_directory = 'embed_update_med'
unprocessed_dataframes = load_unprocessed_parquets(db_update_directory, embed_update_directory)

if unprocessed_dataframes:
    for file in unprocessed_dataframes:
        df = pd.read_parquet(file)
        query = df['abstract'].tolist()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
        model.to(device)

        query_embedding = model.encode(query, normalize_embeddings=True, precision='ubinary', show_progress_bar=True)
        file_path = os.path.basename(file).split('.')[0]
        os.makedirs('embed_update_med', exist_ok=True)
        embeddings_path = f'embed_update_med/{file_path}'
        np.save(embeddings_path, query_embedding)
        print(f'Saved embeddings {embeddings_path}')

    db_update_json = 'db_update_json_med'
    shutil.rmtree(db_update_json)
    print(f"Directory '{db_update_json}' and its contents have been removed.")

    for path in ['db_update_med', 'embed_update_med']:
        upload_path(path, '/')

else:
    print('Nothing to do')
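For reference, fetch_and_save_data_block() pages through the medRxiv details endpoint at https://api.medrxiv.org/details/{server}/{start}/{end}/{cursor}/json, advancing the cursor by however many records each response's collection holds until a block comes back empty. A stripped-down, single-block version of that loop, without the script's retry decorator, threading, or JSON dumps (the dates are placeholders):

import requests

def fetch_block(server="medrxiv", start="2024-05-01", end="2024-05-31"):
    """Fetch one date block from the bioRxiv/medRxiv details API, page by page."""
    papers = []
    cursor = 0
    while True:
        url = f"https://api.medrxiv.org/details/{server}/{start}/{end}/{cursor}/json"
        response = requests.get(url)
        if response.status_code != 200:
            break  # give up on this block, mirroring the script's behaviour
        collection = response.json().get("collection", [])
        if not collection:
            break  # no more records in this interval
        papers.extend(collection)
        cursor += len(collection)  # the API pages by record offset
    return papers

if __name__ == "__main__":
    block = fetch_block()
    print(f"Fetched {len(block)} medRxiv records")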