daily-papers / papers.py
hysts's picture
hysts HF staff
Update
e332775
raw
history blame
4.92 kB
import datetime
import operator
import datasets
import pandas as pd
import tqdm.auto
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import HfApi
from ragatouille import RAGPretrainedModel
# Hugging Face Hub client used for the index downloads in this module.
api = HfApi()
# Hub dataset repo that holds the abstract index, and the local directory the
# snapshot is downloaded into (the retriever is then loaded from that directory).
INDEX_REPO_ID = "hysts-bot-data/daily-papers-abstract-index"
INDEX_DIR_PATH = ".ragatouille/colbert/indexes/daily-papers-abstract-index/"
# Download the index at import time so it is on disk before the retriever loads.
api.snapshot_download(
    repo_id=INDEX_REPO_ID,
    repo_type="dataset",
    local_dir=INDEX_DIR_PATH,
)
# Module-level retriever used for abstract search; `update_abstract_index`
# rebinds this name when it refreshes the index.
abstract_retriever = RAGPretrainedModel.from_index(INDEX_DIR_PATH)
# Run once to initialize the retriever
abstract_retriever.search("LLM")
def update_abstract_index() -> None:
    """Re-download the abstract index and swap in a freshly loaded retriever.

    The snapshot at ``INDEX_REPO_ID`` is downloaded into ``INDEX_DIR_PATH``, a
    new retriever is loaded from that directory, and one throwaway search is
    run on it (the same initialisation step the module performs at import).
    Only after that does the module-level ``abstract_retriever`` get rebound.

    If the download, the load, or the warm-up search raises, the exception
    propagates and the previously published retriever stays in place.
    """
    global abstract_retriever

    api.snapshot_download(
        repo_id=INDEX_REPO_ID,
        repo_type="dataset",
        local_dir=INDEX_DIR_PATH,
    )
    # Build and warm up under a local name first: this function runs on the
    # scheduler's background thread, so publishing the retriever before its
    # first search would let concurrent callers see it uninitialised.
    refreshed_retriever = RAGPretrainedModel.from_index(INDEX_DIR_PATH)
    refreshed_retriever.search("LLM")
    abstract_retriever = refreshed_retriever
# Refresh the abstract index in the background: `update_abstract_index` is
# registered on a cron trigger with hour="*" (every hour), evaluated in UTC.
# misfire_grace_time is 3 * 60; APScheduler takes this value in seconds, so a
# run that cannot start within three minutes of its slot is skipped.
scheduler = BackgroundScheduler()
scheduler.add_job(func=update_abstract_index, trigger="cron", hour="*", timezone="UTC", misfire_grace_time=3 * 60)
# The scheduler is started as a side effect of importing this module.
scheduler.start()
def get_df() -> pd.DataFrame:
    """Load the paper table and its stats, merged into one frame.

    The ``train`` splits of ``hysts-bot-data/daily-papers`` and
    ``hysts-bot-data/daily-papers-stats`` are joined on ``arxiv_id`` (pandas'
    default inner join). The row order is then reversed, ``date`` is rendered
    as a ``YYYY-MM-DD`` string, the ``authors`` and ``abstract`` columns are
    dropped, and a ``paper_page`` column with the Hub URL of each paper is
    appended.

    Returns:
        The merged frame with a fresh 0..n-1 index. The columns are present
        even when the merge yields no rows.
    """
    papers = datasets.load_dataset("hysts-bot-data/daily-papers", split="train").to_pandas()
    stats = datasets.load_dataset("hysts-bot-data/daily-papers-stats", split="train").to_pandas()

    df = pd.merge(left=papers, right=stats, on="arxiv_id")
    df = df[::-1].reset_index(drop=True)
    df["date"] = df["date"].dt.strftime("%Y-%m-%d")
    df = df.drop(columns=["authors", "abstract"])

    # One vectorised column assignment instead of copying every row in a Python
    # loop. str.format applies the same formatting an f-string would, and the
    # frame keeps its columns when there are no rows.
    df["paper_page"] = df["arxiv_id"].map("https://huggingface.co/papers/{}".format)
    return df
class Prettifier:
@staticmethod
def get_github_link(link: str) -> str:
if not link:
return ""
return Prettifier.create_link("github", link)
@staticmethod
def create_link(text: str, url: str) -> str:
return f'<a href="{url}" target="_blank">{text}</a>'
@staticmethod
def to_div(text: str | None, category_name: str) -> str:
if text is None:
text = ""
class_name = f"{category_name}-{text.lower()}"
return f'<div class="{class_name}">{text}</div>'
def __call__(self, df: pd.DataFrame) -> pd.DataFrame:
new_rows = []
for _, row in df.iterrows():
new_row = {
"date": Prettifier.create_link(row.date, f"https://huggingface.co/papers?date={row.date}"),
"paper_page": Prettifier.create_link(row.arxiv_id, row.paper_page),
"title": row["title"],
"github": self.get_github_link(row.github),
"πŸ‘": row["upvotes"],
"πŸ’¬": row["num_comments"],
}
new_rows.append(new_row)
return pd.DataFrame(new_rows)
class PaperList:
    """Searchable view over the paper table.

    Holds the raw frame (used for filtering) and a prettified copy restricted
    to the columns listed in ``COLUMN_INFO``.
    """

    # [column name, datatype] pairs in display order. The names must match the
    # columns produced by the prettifier; the datatypes are exposed through
    # ``column_datatype`` (presumably for the UI table; confirm against caller).
    COLUMN_INFO = [
        ["date", "markdown"],
        ["paper_page", "markdown"],
        ["title", "str"],
        ["github", "markdown"],
        ["πŸ‘", "number"],
        ["πŸ’¬", "number"],
    ]

    def __init__(self, df: pd.DataFrame):
        """Store *df* as the raw table and precompute its prettified form."""
        self.df_raw = df
        self._prettifier = Prettifier()
        self.df_prettified = self._prettify(df)

    @property
    def column_names(self):
        """Column names of the displayed table, in order."""
        return [name for name, _ in self.COLUMN_INFO]

    @property
    def column_datatype(self):
        """Datatype label of each displayed column, in the same order."""
        return [datatype for _, datatype in self.COLUMN_INFO]

    def _prettify(self, df: pd.DataFrame) -> pd.DataFrame:
        """Run the prettifier and keep only the displayed columns.

        A frame built from zero rows may carry no columns at all, in which case
        selecting by name would raise KeyError; return an empty frame with the
        expected columns instead.
        """
        pretty = self._prettifier(df)
        if pretty.empty:
            return pd.DataFrame(columns=self.column_names)
        return pretty.loc[:, self.column_names]

    @staticmethod
    def _title_mask(titles: pd.Series, query: str) -> pd.Series:
        """Boolean mask of the titles that contain *query*, ignoring case.

        *query* is treated as a regular expression, as before; when it is not
        a valid pattern (e.g. an unbalanced parenthesis typed by a user) it is
        matched literally instead of raising. Missing titles never match.
        """
        import re

        try:
            re.compile(query)
        except re.error:
            use_regex = False
        else:
            use_regex = True
        return titles.str.contains(query, case=False, regex=use_regex, na=False)

    def search(
        self,
        start_date: datetime.datetime,
        end_date: datetime.datetime,
        title_search_query: str,
        abstract_search_query: str,
        max_num_to_retrieve: int,
    ) -> pd.DataFrame:
        """Filter the papers and return the matching rows, prettified.

        Args:
            start_date: Earliest paper date to keep (inclusive).
            end_date: Latest paper date to keep (inclusive).
            title_search_query: Case-insensitive pattern the title must
                contain; an empty string matches every title that is present.
            abstract_search_query: Query passed to the module-level abstract
                retriever. When empty, no abstract filtering is done.
            max_num_to_retrieve: Passed to the retriever as ``k``.

        Returns:
            The prettified rows that pass all filters. With an abstract query
            the rows follow the order of the retriever's results; otherwise
            they keep the order of the raw table. The frame has the displayed
            columns even when nothing matches.
        """
        df = self.df_raw.copy()
        df["date"] = pd.to_datetime(df["date"])

        # Filter by date. Copy the selection so the column assignment below
        # writes to an independent frame rather than to a slice.
        in_range = (df["date"] >= start_date) & (df["date"] <= end_date)
        df = df[in_range].copy()
        df["date"] = df["date"].dt.strftime("%Y-%m-%d")

        # Filter by title
        df = df[self._title_mask(df["title"], title_search_query)]

        # Filter by abstract
        if abstract_search_query:
            results = abstract_retriever.search(abstract_search_query, k=max_num_to_retrieve)
            remaining_ids = set(df["arxiv_id"])
            # dict.fromkeys drops repeated ids while keeping the first
            # occurrence, i.e. the retriever's ordering.
            found_ids = list(
                dict.fromkeys(
                    hit["document_id"] for hit in results if hit["document_id"] in remaining_ids
                )
            )
            df = df[df["arxiv_id"].isin(found_ids)].set_index("arxiv_id").reindex(index=found_ids).reset_index()

        return self._prettify(df)