dzyla committed on
Commit
2ed3459
1 Parent(s): 3e6f92e

added also medrxiv

Files changed (3)
  1. combine_databases.py +136 -0
  2. streamlit_app.py +67 -28
  3. update_database_medarxiv.py +289 -0
combine_databases.py ADDED
@@ -0,0 +1,136 @@
+import pandas as pd
+import numpy as np
+from pathlib import Path
+
+def combine_databases():
+    # Define paths
+    aggregated_data_path = Path("aggregated_data")
+    db_update_bio_path = Path("db_update")
+    biorxiv_embeddings_path = Path("biorxiv_ubin_embaddings.npy")
+    embed_update_bio_path = Path("embed_update")
+
+    db_update_med_path = Path("db_update_med")
+    embed_update_med_path = Path("embed_update_med")
+
+    # Load existing database and embeddings for BioRxiv
+    df_bio_existing = pd.read_parquet(aggregated_data_path)
+    bio_embeddings_existing = np.load(biorxiv_embeddings_path, allow_pickle=True)
+    print(f"Existing BioRxiv data shape: {df_bio_existing.shape}, Existing BioRxiv embeddings shape: {bio_embeddings_existing.shape}")
+
+    # Determine the embedding size from existing embeddings
+    embedding_size = bio_embeddings_existing.shape[1]
+
+    # Prepare lists to collect new updates
+    bio_dfs_list = []
+    bio_embeddings_list = []
+
+    # Helper function to process updates from a specified directory
+    def process_updates(new_data_directory, updated_embeddings_directory, dfs_list, embeddings_list):
+        new_data_files = sorted(Path(new_data_directory).glob("*.parquet"))
+        for data_file in new_data_files:
+            corresponding_embedding_file = Path(updated_embeddings_directory) / (data_file.stem + ".npy")
+
+            if corresponding_embedding_file.exists():
+                df = pd.read_parquet(data_file)
+                new_embeddings = np.load(corresponding_embedding_file, allow_pickle=True)
+
+                # Check if the number of rows in the DataFrame matches the number of rows in the embeddings
+                if df.shape[0] != new_embeddings.shape[0]:
+                    print(f"Shape mismatch for {data_file.name}: DataFrame has {df.shape[0]} rows, embeddings have {new_embeddings.shape[0]} rows. Skipping.")
+                    continue
+
+                # Check embedding size and adjust if necessary
+                if new_embeddings.shape[1] != embedding_size:
+                    print(f"Skipping {data_file.name} due to embedding size mismatch.")
+                    continue
+
+                dfs_list.append(df)
+                embeddings_list.append(new_embeddings)
+            else:
+                print(f"No corresponding embedding file found for {data_file.name}")
+
+    # Process updates from both BioRxiv and MedRxiv
+    process_updates(db_update_bio_path, embed_update_bio_path, bio_dfs_list, bio_embeddings_list)
+
+    # Concatenate all BioRxiv updates
+    if bio_dfs_list:
+        df_bio_updates = pd.concat(bio_dfs_list)
+    else:
+        df_bio_updates = pd.DataFrame()
+
+    if bio_embeddings_list:
+        bio_embeddings_updates = np.vstack(bio_embeddings_list)
+    else:
+        bio_embeddings_updates = np.array([])
+
+    # Append new BioRxiv data to existing, handling duplicates as needed
+    df_bio_combined = pd.concat([df_bio_existing, df_bio_updates])
+
+    # Create a mask for filtering unique titles
+    bio_mask = ~df_bio_combined.duplicated(subset=["title"], keep="last")
+    df_bio_combined = df_bio_combined[bio_mask]
+
+    # Combine BioRxiv embeddings, ensuring alignment with the DataFrame
+    bio_embeddings_combined = (
+        np.vstack([bio_embeddings_existing, bio_embeddings_updates])
+        if bio_embeddings_updates.size
+        else bio_embeddings_existing
+    )
+
+    # Filter the embeddings based on the DataFrame unique entries
+    bio_embeddings_combined = bio_embeddings_combined[bio_mask]
+
+    assert df_bio_combined.shape[0] == bio_embeddings_combined.shape[0], "Shape mismatch between BioRxiv DataFrame and embeddings"
+
+    print(f"Filtered BioRxiv DataFrame shape: {df_bio_combined.shape}")
+    print(f"Filtered BioRxiv embeddings shape: {bio_embeddings_combined.shape}")
+
+    # Save combined BioRxiv DataFrame and embeddings
+    combined_biorxiv_data_path = aggregated_data_path / "combined_biorxiv_data.parquet"
+    df_bio_combined.to_parquet(combined_biorxiv_data_path)
+    print(f"Saved combined BioRxiv DataFrame to {combined_biorxiv_data_path}")
+
+    combined_biorxiv_embeddings_path = "biorxiv_ubin_embaddings.npy"
+    np.save(combined_biorxiv_embeddings_path, bio_embeddings_combined)
+    print(f"Saved combined BioRxiv embeddings to {combined_biorxiv_embeddings_path}")
+
+    # Prepare lists to collect new MedRxiv updates
+    med_dfs_list = []
+    med_embeddings_list = []
+
+    process_updates(db_update_med_path, embed_update_med_path, med_dfs_list, med_embeddings_list)
+
+    # Concatenate all MedRxiv updates
+    if med_dfs_list:
+        df_med_combined = pd.concat(med_dfs_list)
+    else:
+        df_med_combined = pd.DataFrame()
+
+    if med_embeddings_list:
+        med_embeddings_combined = np.vstack(med_embeddings_list)
+    else:
+        med_embeddings_combined = np.array([])
+
+    last_date_in_med_database = df_med_combined['date'].max() if not df_med_combined.empty else "unknown"
+
+    # Create a mask for filtering unique titles
+    med_mask = ~df_med_combined.duplicated(subset=["title"], keep="last")
+    df_med_combined = df_med_combined[med_mask]
+    med_embeddings_combined = med_embeddings_combined[med_mask]
+
+    assert df_med_combined.shape[0] == med_embeddings_combined.shape[0], "Shape mismatch between MedRxiv DataFrame and embeddings"
+
+    print(f"Filtered MedRxiv DataFrame shape: {df_med_combined.shape}")
+    print(f"Filtered MedRxiv embeddings shape: {med_embeddings_combined.shape}")
+
+    # Save combined MedRxiv DataFrame and embeddings
+    combined_medrxiv_data_path = db_update_med_path / f"database_{last_date_in_med_database}.parquet"
+    df_med_combined.to_parquet(combined_medrxiv_data_path)
+    print(f"Saved combined MedRxiv DataFrame to {combined_medrxiv_data_path}")
+
+    combined_medrxiv_embeddings_path = embed_update_med_path / f"database_{last_date_in_med_database}.npy"
+    np.save(combined_medrxiv_embeddings_path, med_embeddings_combined)
+    print(f"Saved combined MedRxiv embeddings to {combined_medrxiv_embeddings_path}")
+
+if __name__ == "__main__":
+    combine_databases()
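combine_databases() keeps each DataFrame and its embedding matrix aligned by building a single boolean mask from duplicated titles and applying it to both. A minimal sketch of that pattern on toy data (illustrative only, not part of the commit):

    import numpy as np
    import pandas as pd

    # Toy rows: two entries share a title, so the earlier one should be dropped
    df = pd.DataFrame({"title": ["A", "B", "A"], "server": ["biorxiv", "biorxiv", "medrxiv"]})
    emb = np.arange(6).reshape(3, 2)  # one toy 2-d embedding per row

    # Same mask filters both structures, so row i of df still matches row i of emb
    mask = ~df.duplicated(subset=["title"], keep="last")
    df_unique = df[mask]
    emb_unique = emb[mask.to_numpy()]

    assert len(df_unique) == emb_unique.shape[0]
    print(df_unique)
    print(emb_unique)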
streamlit_app.py CHANGED
@@ -263,32 +263,57 @@ def download_data_from_dropbox():
 @st.cache_resource(ttl="1d")
 def load_data_embeddings():
     existing_data_path = "aggregated_data"
-    new_data_directory = "db_update"
+    new_data_directory_bio = "db_update"
     existing_embeddings_path = "biorxiv_ubin_embaddings.npy"
-    updated_embeddings_directory = "embed_update"
 
+    updated_embeddings_directory_bio = "embed_update"
+
+    new_data_directory_med = "db_update_med"
+    updated_embeddings_directory_med = "embed_update_med"
+
     # Load existing database and embeddings
     df_existing = pd.read_parquet(existing_data_path)
     embeddings_existing = np.load(existing_embeddings_path, allow_pickle=True)
 
+    print(f"Existing data shape: {df_existing.shape}, Existing embeddings shape: {embeddings_existing.shape}")
+
+    # Determine the embedding size from existing embeddings
+    embedding_size = embeddings_existing.shape[1]
+
     # Prepare lists to collect new updates
     df_updates_list = []
     embeddings_updates_list = []
 
-    # Ensure pairing of new data and embeddings by their matching filenames
-    new_data_files = sorted(Path(new_data_directory).glob("*.parquet"))
-    for data_file in new_data_files:
-        # Assuming naming convention allows direct correlation
-        corresponding_embedding_file = Path(updated_embeddings_directory) / (
-            data_file.stem + ".npy"
-        )
+    # Helper function to process updates from a specified directory
+    def process_updates(new_data_directory, updated_embeddings_directory):
+        new_data_files = sorted(Path(new_data_directory).glob("*.parquet"))
+        print(new_data_files)
+        for data_file in new_data_files:
+            corresponding_embedding_file = Path(updated_embeddings_directory) / (
+                data_file.stem + ".npy"
+            )
+
+            if corresponding_embedding_file.exists():
+                df = pd.read_parquet(data_file)
+                new_embeddings = np.load(corresponding_embedding_file, allow_pickle=True)
+
+                # Check if the number of rows in the DataFrame matches the number of rows in the embeddings
+                if df.shape[0] != new_embeddings.shape[0]:
+                    print(f"Shape mismatch for {data_file.name}: DataFrame has {df.shape[0]} rows, embeddings have {new_embeddings.shape[0]} rows. Skipping.")
+                    continue
+
+                # Check embedding size and adjust if necessary
+                if new_embeddings.shape[1] != embedding_size:
+                    print(f"Skipping {data_file.name} due to embedding size mismatch.")
+                    continue
+
+                df_updates_list.append(df)
+                embeddings_updates_list.append(new_embeddings)
+            else:
+                print(f"No corresponding embedding file found for {data_file.name}")
 
-        if corresponding_embedding_file.exists():
-            # Load and append DataFrame and embeddings
-            df_updates_list.append(pd.read_parquet(data_file))
-            embeddings_updates_list.append(np.load(corresponding_embedding_file))
-        else:
-            print(f"No corresponding embedding file found for {data_file.name}")
+    # Process updates from both BioRxiv and MedArXiv
+    process_updates(new_data_directory_bio, updated_embeddings_directory_bio)
+    process_updates(new_data_directory_med, updated_embeddings_directory_med)
 
     # Concatenate all updates
     if df_updates_list:
@@ -304,7 +329,7 @@ def load_data_embeddings():
     # Append new data to existing, handling duplicates as needed
     df_combined = pd.concat([df_existing, df_updates])
 
-    # create a mask for filtering
+    # Create a mask for filtering
     mask = ~df_combined.duplicated(subset=["title"], keep="last")
     df_combined = df_combined[mask]
 
@@ -315,20 +340,21 @@ def load_data_embeddings():
         else embeddings_existing
     )
 
-    # filter the embeddings based on dataframe unique entries
+    # Filter the embeddings based on the dataframe unique entries
     embeddings_combined = embeddings_combined[mask]
 
     return df_combined, embeddings_combined
 
+
 LLM_prompt = "Review the abstracts listed below and create a list and summary that captures their main themes and findings. Identify any commonalities across the abstracts and highlight these in your summary. Ensure your response is concise, avoids external links, and is formatted in markdown.\n\n"
 
-def summarize_abstract(abstract, llm_model="llama3-70b-8192", instructions=LLM_prompt, api_key=st.secrets["groq_token"]):
+def summarize_abstract(abstract, llm_model="llama-3.1-70b-versatile", instructions=LLM_prompt, api_key=st.secrets["groq_token"]):
     """
     Summarizes the provided abstract using a specified LLM model.
 
     Parameters:
     - abstract (str): The abstract text to be summarized.
-    - llm_model (str): The LLM model used for summarization. Defaults to "llama3-70b-8192".
+    - llm_model (str): The LLM model used for summarization. Defaults to "llama-3.1-70b-versatile".
 
     Returns:
     - str: A summary of the abstract, condensed into one to two sentences.
@@ -390,16 +416,25 @@ def define_style():
     )
 
 
-def logo(db_update_date, db_size):
+def logo(db_update_date, db_size_bio, db_size_med):
     # Initialize Streamlit app
-    image_path = "https://www.biorxiv.org/sites/default/files/biorxiv_logo_homepage.png"
+    biorxiv_logo = "https://www.biorxiv.org/sites/default/files/biorxiv_logo_homepage.png"
+    medarxiv_logo = "https://www.medrxiv.org/sites/default/files/medRxiv_homepage_logo.png"
     st.markdown(
         f"""
-        <div style='text-align: center;'>
-            <img src='{image_path}' alt='BioRxiv logo' style='max-height: 100px;'>
-            <h3 style='color: black;'>Manuscript Semantic Search [bMSS]</h1>
-            Last database update: {db_update_date}; Database size: {db_size} entries
+        <div style='display: flex; justify-content: center; align-items: center;'>
+            <div style='margin-right: 20px;'>
+                <img src='{biorxiv_logo}' alt='BioRxiv logo' style='max-height: 100px;'>
+            </div>
+            <div style='margin-left: 20px;'>
+                <img src='{medarxiv_logo}' alt='medRxiv logo' style='max-height: 100px;'>
+            </div>
         </div>
+        <div style='text-align: center; margin-top: 10px;'>
+            <h3 style='color: black;'>Manuscript Semantic Search [bMSS]</h3>
+            Last database update: {db_update_date}; Database size: bioRxiv: {db_size_bio} / medRxiv: {db_size_med} entries
+        </div>
+        <br>
         """,
         unsafe_allow_html=True,
     )
@@ -413,7 +448,7 @@ download_data_from_dropbox()
 define_style()
 
 df, embeddings_unique = load_data_embeddings()
-logo(df["date"].max(), df.shape[0])
+logo(df["date"].max(), df[df['server']=='biorxiv'].shape[0], df[df['server']=='medrxiv'].shape[0])
 
 # model = model_to_device()
 
@@ -474,7 +509,7 @@ if query:
     )
 
     # Prepare the results for plotting
-    plot_data = {"Date": [], "Title": [], "Score": [], "DOI": [], "category": []}
+    plot_data = {"Date": [], "Title": [], "Score": [], "DOI": [], "category": [], "server": []}
 
     search_df = pd.DataFrame(results[0])
 
@@ -503,11 +538,14 @@ if query:
         plot_data["Score"].append(search_df["score"][index])  # type: ignore
         plot_data["DOI"].append(row["doi"])
         plot_data["category"].append(row["category"])
+        plot_data["server"].append(row["server"])
 
         #summary_text = summarize_abstract(row['abstract'])
 
         with st.expander(f"{index+1}\. {row['title']}"):  # type: ignore
-            st.markdown(f"**Score:** {entry['score']:.1f}")
+            col1, col2 = st.columns(2)
+            col1.markdown(f"**Score:** {entry['score']:.1f}")
+            col2.markdown(f"**Server:** [{row['server']}]")
            st.markdown(f"**Authors:** {row['authors']}")
            col1, col2 = st.columns(2)
            col2.markdown(f"**Category:** {row['category']}")
@@ -556,6 +594,7 @@ if query:
         x="Date",
         y="Score",
         hover_data=["Title", "DOI"],
+        color='server',
         title="Publication Times and Scores",
     )
     fig.update_traces(marker=dict(size=10))
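The new `server` column feeds both the per-server counts passed to logo() and the color split in the Plotly scatter. A self-contained sketch of that behaviour with toy data (illustrative only, not part of the commit; assumes pandas and plotly are installed):

    import pandas as pd
    import plotly.express as px

    toy = pd.DataFrame({
        "Date": ["2024-01-01", "2024-02-01", "2024-03-01"],
        "Score": [0.81, 0.76, 0.90],
        "server": ["biorxiv", "medrxiv", "biorxiv"],
    })

    # Per-server counts, as used for the logo() call
    n_bio = toy[toy["server"] == "biorxiv"].shape[0]
    n_med = toy[toy["server"] == "medrxiv"].shape[0]
    print(n_bio, n_med)

    # color="server" gives each server its own trace and legend entry
    fig = px.scatter(toy, x="Date", y="Score", color="server")
    fig.update_traces(marker=dict(size=10))
    fig.show()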
update_database_medarxiv.py ADDED
@@ -0,0 +1,289 @@
+import pandas as pd
+import numpy as np
+from pathlib import Path
+import datetime
+import requests
+import json
+import os
+from datetime import datetime
+from dateutil.relativedelta import relativedelta
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from sentence_transformers import SentenceTransformer
+import torch
+import shutil
+import dropbox
+import streamlit as st
+import time
+
+def retry_on_exception(exception, retries=5, delay=2):
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            last_exception = None
+            for _ in range(retries):
+                try:
+                    return func(*args, **kwargs)
+                except exception as e:
+                    last_exception = e
+                    print(f"Retrying due to: {str(e)}")
+                    time.sleep(delay)
+            raise last_exception
+        return wrapper
+    return decorator
+
+@retry_on_exception(requests.exceptions.ConnectionError)
+def fetch_and_save_data_block(endpoint, server, block_start, block_end, save_directory, format='json'):
+    base_url = f"https://api.medrxiv.org/details/{server}/"
+    block_interval = f"{block_start.strftime('%Y-%m-%d')}/{block_end.strftime('%Y-%m-%d')}"
+    block_data = []
+    cursor = 0
+    continue_fetching = True
+
+    while continue_fetching:
+        url = f"{base_url}{block_interval}/{cursor}/{format}"
+        response = requests.get(url)
+
+        if response.status_code != 200:
+            print(f"Failed to fetch data for block {block_interval} at cursor {cursor}. HTTP Status: {response.status_code}")
+            break
+
+        data = response.json()
+        fetched_papers = len(data['collection'])
+
+        if fetched_papers > 0:
+            block_data.extend(data['collection'])
+            cursor += fetched_papers
+            print(f"Fetched {fetched_papers} papers for block {block_interval}. Total fetched: {cursor}.")
+        else:
+            continue_fetching = False
+
+    if block_data:
+        save_data_block(block_data, block_start, block_end, endpoint, save_directory)
+
+def save_data_block(block_data, start_date, end_date, endpoint, save_directory):
+    start_yymmdd = start_date.strftime("%y%m%d")
+    end_yymmdd = end_date.strftime("%y%m%d")
+    filename = f"{save_directory}/{endpoint}_data_{start_yymmdd}_{end_yymmdd}.json"
+
+    with open(filename, 'w') as file:
+        json.dump(block_data, file, indent=4)
+
+    print(f"Saved data block to {filename}")
+
+def fetch_data(endpoint, server, interval, save_directory, format='json'):
+    os.makedirs(save_directory, exist_ok=True)
+    start_date, end_date = [datetime.strptime(date, "%Y-%m-%d") for date in interval.split('/')]
+    current_date = start_date
+    tasks = []
+
+    with ThreadPoolExecutor(max_workers=12) as executor:
+        while current_date <= end_date:
+            block_start = current_date
+            block_end = min(current_date + relativedelta(months=1) - relativedelta(days=1), end_date)
+            tasks.append(executor.submit(fetch_and_save_data_block, endpoint, server, block_start, block_end, save_directory, format))
+            current_date += relativedelta(months=1)
+
+        for future in as_completed(tasks):
+            future.result()
+
+def load_json_to_dataframe(json_file):
+    with open(json_file, 'r') as file:
+        data = json.load(file)
+    return pd.DataFrame(data)
+
+def save_dataframe(df, save_path):
+    df.to_parquet(save_path)
+
+def process_json_files(directory, save_directory):
+    os.makedirs(save_directory, exist_ok=True)
+
+    json_files = list(Path(directory).glob('*.json'))
+    print(f'json_files {type(json_files)}: {json_files}')
+
+    for json_file in json_files:
+        df = load_json_to_dataframe(json_file)
+
+        parquet_filename = f"{json_file.stem}.parquet"
+        save_path = os.path.join(save_directory, parquet_filename)
+
+        if os.path.exists(save_path):
+            npy_file_path = save_path.replace('db_update', 'embed_update').replace('parquet', 'npy')
+            if os.path.exists(npy_file_path):
+                os.remove(npy_file_path)
+                print(f'Removed embedding file {npy_file_path} due to the dataframe update')
+
+        save_dataframe(df, save_path)
+        print(f"Processed and saved {json_file.name} to {parquet_filename}")
+
+def load_unprocessed_parquets(db_update_directory, embed_update_directory):
+    db_update_directory = Path(db_update_directory)
+    embed_update_directory = Path(embed_update_directory)
+
+    parquet_files = list(db_update_directory.glob('*.parquet'))
+    npy_files = {f.stem for f in embed_update_directory.glob('*.npy')}
+    unprocessed_dataframes = []
+
+    for parquet_file in parquet_files:
+        if parquet_file.stem not in npy_files:
+            unprocessed_dataframes.append(parquet_file)
+            print(f"Loaded unprocessed Parquet file: {parquet_file.name}")
+        else:
+            print(f"Skipping processed Parquet file: {parquet_file.name}")
+
+    return unprocessed_dataframes
+
+def connect_to_dropbox():
+    dropbox_APP_KEY = st.secrets["dropbox_APP_KEY"]
+    dropbox_APP_SECRET = st.secrets["dropbox_APP_SECRET"]
+    dropbox_REFRESH_TOKEN = st.secrets["dropbox_REFRESH_TOKEN"]
+
+    dbx = dropbox.Dropbox(
+        app_key=dropbox_APP_KEY,
+        app_secret=dropbox_APP_SECRET,
+        oauth2_refresh_token=dropbox_REFRESH_TOKEN
+    )
+    return dbx
+
+def upload_path(local_path, dropbox_path):
+    dbx = connect_to_dropbox()
+    local_path = Path(local_path)
+
+    if local_path.is_file():
+        relative_path = local_path.name
+        dropbox_file_path = os.path.join(dropbox_path, relative_path).replace('\\', '/').replace('//', '/')
+        upload_file(local_path, dropbox_file_path, dbx)
+    elif local_path.is_dir():
+        for local_file in local_path.rglob('*'):
+            if local_file.is_file():
+                relative_path = local_file.relative_to(local_path.parent)
+                dropbox_file_path = os.path.join(dropbox_path, relative_path).replace('\\', '/').replace('//', '/')
+                upload_file(local_file, dropbox_file_path, dbx)
+    else:
+        print("The provided path does not exist.")
+
+def upload_file(file_path, dropbox_file_path, dbx):
+    try:
+        dropbox_file_path = dropbox_file_path.replace('\\', '/')
+
+        try:
+            metadata = dbx.files_get_metadata(dropbox_file_path)
+            dropbox_mod_time = metadata.server_modified
+            local_mod_time = datetime.fromtimestamp(file_path.stat().st_mtime)
+
+            if dropbox_mod_time >= local_mod_time:
+                print(f"Skipped {dropbox_file_path}, Dropbox version is up-to-date.")
+                return
+        except dropbox.exceptions.ApiError as e:
+            if not isinstance(e.error, dropbox.files.GetMetadataError) or e.error.is_path() and e.error.get_path().is_not_found():
+                print(f"No existing file on Dropbox, proceeding with upload: {dropbox_file_path}")
+            else:
+                raise e
+
+        with file_path.open('rb') as f:
+            dbx.files_upload(f.read(), dropbox_file_path, mode=dropbox.files.WriteMode.overwrite)
+            print(f"Uploaded {dropbox_file_path}")
+    except Exception as e:
+        print(f"Failed to upload {dropbox_file_path}: {str(e)}")
+
+def load_data_embeddings():
+    new_data_directory = "db_update_med"
+    updated_embeddings_directory = "embed_update_med"
+    new_data_files = sorted(Path(new_data_directory).glob("*.parquet"))
+
+    df_updates_list = []
+    embeddings_updates_list = []
+
+    for data_file in new_data_files:
+        # Assuming naming convention allows direct correlation
+        corresponding_embedding_file = Path(updated_embeddings_directory) / (
+            data_file.stem + ".npy"
+        )
+
+        if corresponding_embedding_file.exists():
+            # Load and append DataFrame and embeddings
+            df_updates_list.append(pd.read_parquet(data_file))
+            embeddings_updates_list.append(np.load(corresponding_embedding_file))
+        else:
+            print(f"No corresponding embedding file found for {data_file.name}")
+
+    new_data_files = sorted(Path(new_data_directory).glob("*.parquet"))
+    for data_file in new_data_files:
+        corresponding_embedding_file = Path(updated_embeddings_directory) / (
+            data_file.stem + ".npy"
+        )
+
+        if corresponding_embedding_file.exists():
+            df_updates_list.append(pd.read_parquet(data_file))
+            embeddings_updates_list.append(np.load(corresponding_embedding_file))
+        else:
+            print(f"No corresponding embedding file found for {data_file.name}")
+
+    if df_updates_list:
+        df_updates = pd.concat(df_updates_list)
+    else:
+        df_updates = pd.DataFrame()
+
+    if embeddings_updates_list:
+        embeddings_updates = np.vstack(embeddings_updates_list)
+    else:
+        embeddings_updates = np.array([])
+
+    df_combined = df_updates
+    mask = ~df_combined.duplicated(subset=["title"], keep="last")
+    df_combined = df_combined[mask]
+
+    embeddings_combined = embeddings_updates
+
+    embeddings_combined = embeddings_combined[mask]
+
+    return df_combined, embeddings_combined
+
+endpoint = "details"
+server = "medrxiv"
+
+df, embeddings = load_data_embeddings()
+
+try:
+    start_date = df['date'].max()
+except:
+    start_date = '1990-01-01'
+last_date = datetime.today().strftime('%Y-%m-%d')
+
+interval = f'{start_date}/{last_date}'
+print(f'using interval: {interval}')
+
+save_directory = "db_update_json_med"
+fetch_data(endpoint, server, interval, save_directory)
+
+directory = r'db_update_json_med'
+save_directory = r'db_update_med'
+process_json_files(directory, save_directory)
+
+db_update_directory = 'db_update_med'
+embed_update_directory = 'embed_update_med'
+unprocessed_dataframes = load_unprocessed_parquets(db_update_directory, embed_update_directory)
+
+if unprocessed_dataframes:
+    for file in unprocessed_dataframes:
+        df = pd.read_parquet(file)
+        query = df['abstract'].tolist()
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+        model.to(device)
+
+        query_embedding = model.encode(query, normalize_embeddings=True, precision='ubinary', show_progress_bar=True)
+        file_path = os.path.basename(file).split('.')[0]
+        os.makedirs('embed_update_med', exist_ok=True)
+        embeddings_path = f'embed_update_med/{file_path}'
+        np.save(embeddings_path, query_embedding)
+        print(f'Saved embeddings {embeddings_path}')
+
+    db_update_json = 'db_update_json_med'
+    shutil.rmtree(db_update_json)
+    print(f"Directory '{db_update_json}' and its contents have been removed.")
+
+    for path in ['db_update_med', 'embed_update_med']:
+        upload_path(path, '/')
+
+else:
+    print('Nothing to do')
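update_database_medarxiv.py pages through the medRxiv details endpoint by advancing a cursor by the number of records each response returns until an empty page comes back. A stripped-down sketch of that loop for a single date interval (no threading or retries; illustrative only, using the same URL layout as the script above):

    import requests

    def fetch_interval(server="medrxiv", interval="2024-01-01/2024-01-31"):
        papers = []
        cursor = 0
        while True:
            url = f"https://api.medrxiv.org/details/{server}/{interval}/{cursor}/json"
            response = requests.get(url)
            if response.status_code != 200:
                break
            batch = response.json().get("collection", [])
            if not batch:
                break                 # empty page: no more records for this interval
            papers.extend(batch)
            cursor += len(batch)      # the cursor is the offset of the next page
        return papers

    if __name__ == "__main__":
        print(len(fetch_interval()))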