# ICLR2024-papers / papers.py
# Source: Hugging Face Space file view (author: hysts, HF staff; last commit b00b9f1 "Update";
# file size 5.57 kB; Hub scan: no virus).
import operator
import datasets
import pandas as pd
from huggingface_hub import HfApi
from ragatouille import RAGPretrainedModel
# Local directory the prebuilt abstract index is mirrored into; the same path is
# then handed to RAGPretrainedModel.from_index to load it.
INDEX_DIR_PATH = ".ragatouille/colbert/indexes/ICLR2024-papers-abstract-index/"
api = HfApi()

# Fetch the index files (stored as a Hub *dataset* repo) at import time.
api.snapshot_download(
    local_dir=INDEX_DIR_PATH,
    repo_type="dataset",
    repo_id="ICLR2024/ICLR2024-papers-abstract-index",
)
ABSTRACT_RETRIEVER = RAGPretrainedModel.from_index(INDEX_DIR_PATH)

# Issue one throwaway query at import time so the retriever is initialized
# before the first user search (the result is intentionally discarded).
ABSTRACT_RETRIEVER.search("LLM")
class PaperList:
    """ICLR 2024 paper table with title/type/link filters and abstract search.

    ``df_raw`` holds the merged raw metadata from :meth:`get_df`;
    ``df_prettified`` is the same data rendered for display by :meth:`prettify`.
    """

    # (display column name, datatype) pairs. The order here is the display
    # column order returned by get_column_names().
    COLUMN_INFO = [
        ["Title", "str"],
        ["Authors", "str"],
        ["Type", "str"],
        ["Paper page", "markdown"],
        ["OpenReview", "markdown"],
        ["GitHub", "markdown"],
        ["Spaces", "markdown"],
        ["Models", "markdown"],
        ["Datasets", "markdown"],
        ["claimed", "markdown"],
    ]

    def __init__(self):
        """Load the paper data (network I/O via get_df) and pre-render it."""
        self.df_raw = self.get_df()
        self.df_prettified = self.prettify(self.df_raw)

    @staticmethod
    def get_df() -> pd.DataFrame:
        """Load paper metadata joined with per-paper author-claim counts.

        ``ICLR2024/ICLR2024-num-claimed-papers`` is left-joined onto
        ``ICLR2024/ICLR2024-papers`` on ``id``. Missing ``n_authors`` /
        ``n_linked_authors`` values (e.g. papers absent from the claims table)
        become ``-1``. A ``paper_page`` column is added with the Hugging Face
        paper-page URL, or ``""`` when the paper has no arXiv id.
        """
        df = pd.merge(
            left=datasets.load_dataset("ICLR2024/ICLR2024-papers", split="train").to_pandas(),
            right=datasets.load_dataset("ICLR2024/ICLR2024-num-claimed-papers", split="train").to_pandas(),
            on="id",
            how="left",
        )
        # -1 is the sentinel for "no claim information"; prettify() renders it as blank.
        df[["n_authors", "n_linked_authors"]] = df[["n_authors", "n_linked_authors"]].fillna(-1).astype(int)
        # Require a real non-empty string: a NaN id is truthy and would otherwise
        # produce a ".../papers/nan" URL.
        df["paper_page"] = df["arxiv_id"].apply(
            lambda arxiv_id: f"https://huggingface.co/papers/{arxiv_id}"
            if isinstance(arxiv_id, str) and arxiv_id
            else ""
        )
        return df

    @staticmethod
    def create_link(text: str, url: str) -> str:
        """Return an HTML anchor labelled ``text`` that opens ``url`` in a new tab."""
        return f'<a href="{url}" target="_blank">{text}</a>'

    @staticmethod
    def _join_links(pairs) -> str:
        """Render (text, url) pairs as newline-separated HTML anchors ("" if none)."""
        return "\n".join(PaperList.create_link(text, url) for text, url in pairs)

    @staticmethod
    def _format_claimed(n_linked_authors: int, n_authors: int) -> str:
        """Format the "linked/total" author-claim summary; "" when claim info is missing (-1)."""
        if n_linked_authors == -1:
            return ""
        # U+2705 (white heavy check mark) flags papers with at least one linked author.
        check_mark = "\u2705" if n_linked_authors > 0 else ""
        total = "" if n_authors == -1 else n_authors
        return f"{n_linked_authors}/{total} {check_mark}"

    @staticmethod
    def prettify(df: pd.DataFrame) -> pd.DataFrame:
        """Render raw paper rows into the display columns listed in COLUMN_INFO.

        Link columns become HTML anchors (several links are joined with "\\n"),
        authors are comma-joined, and ``claimed`` summarizes linked/total authors.
        """
        rows = []
        for _, row in df.iterrows():
            paper_page = row["paper_page"]
            rows.append(
                {
                    "Title": row["title"],
                    "Authors": ", ".join(row["authors"]),
                    "Type": row["type"],
                    # No anchor without an arXiv id (previously an empty-href "None" link).
                    "Paper page": PaperList.create_link(row["arxiv_id"], paper_page)
                    if isinstance(paper_page, str) and paper_page
                    else "",
                    "OpenReview": PaperList.create_link("OpenReview", row["OpenReview"]),
                    "GitHub": PaperList._join_links(("GitHub", url) for url in row["GitHub"]),
                    "Spaces": PaperList._join_links(
                        (repo_id, f"https://huggingface.co/spaces/{repo_id}") for repo_id in row["Space"]
                    ),
                    "Models": PaperList._join_links(
                        (repo_id, f"https://huggingface.co/{repo_id}") for repo_id in row["Model"]
                    ),
                    "Datasets": PaperList._join_links(
                        (repo_id, f"https://huggingface.co/datasets/{repo_id}") for repo_id in row["Dataset"]
                    ),
                    "claimed": PaperList._format_claimed(row["n_linked_authors"], row["n_authors"]),
                }
            )
        return pd.DataFrame(rows, columns=PaperList.get_column_names())

    @staticmethod
    def get_column_names():
        """Return the display column names in COLUMN_INFO order."""
        return list(map(operator.itemgetter(0), PaperList.COLUMN_INFO))

    def get_column_datatypes(self, column_names: list[str]) -> list[str]:
        """Return the COLUMN_INFO datatype for each name (KeyError on unknown names)."""
        mapping = dict(self.COLUMN_INFO)
        return [mapping[name] for name in column_names]

    def search(
        self,
        title_search_query: str,
        abstract_search_query: str,
        max_num_to_retrieve: int,
        filter_names: list[str],
        presentation_type: str,
        columns_names: list[str],
    ) -> pd.DataFrame:
        """Filter the papers and return the requested display columns.

        Args:
            title_search_query: Case-insensitive literal substring matched against
                titles ("" matches every title).
            abstract_search_query: If non-empty, query sent to ABSTRACT_RETRIEVER;
                the result is restricted to its hits, in retriever order.
            max_num_to_retrieve: ``k`` passed to the abstract retriever.
            filter_names: Any of "Paper page", "GitHub", "Space", "Model",
                "Dataset"; keeps only papers that have that kind of link.
            presentation_type: Presentation type to keep, or "(ALL)" for all.
            columns_names: Display columns (names from COLUMN_INFO) to return.

        Returns:
            Prettified DataFrame restricted to ``columns_names``.
        """
        df = self.df_raw.copy()
        # As ragatouille uses str for document_id
        df["id"] = df["id"].astype(str)

        # regex=False: user text such as "C++" or "(" must be matched literally;
        # as a regex it raised re.error. na=False drops rows with a missing title.
        df = df[df["title"].str.contains(title_search_query, case=False, regex=False, na=False)]

        if presentation_type != "(ALL)":
            df = df[df["type"] == presentation_type]

        # paper_page holds "" (not null) when there is no arXiv id, so a
        # notnull() test would keep every row.
        if "Paper page" in filter_names:
            df = df[df["paper_page"].fillna("") != ""]
        # List-valued link columns: keep rows with at least one entry.
        for column in ("GitHub", "Space", "Model", "Dataset"):
            if column in filter_names:
                df = df[df[column].apply(len) > 0]

        if abstract_search_query:
            results = ABSTRACT_RETRIEVER.search(abstract_search_query, k=max_num_to_retrieve)
            remaining_ids = set(df["id"])
            # Preserve retriever ranking, drop duplicate hits, and drop hits that
            # the filters above already excluded.
            found_ids = list(
                dict.fromkeys(
                    str(hit["document_id"]) for hit in results if str(hit["document_id"]) in remaining_ids
                )
            )
            df = df[df["id"].isin(found_ids)].set_index("id").reindex(index=found_ids).reset_index()

        df_prettified = self.prettify(df)
        return df_prettified.loc[:, columns_names]