"""Load ICLR 2024 paper metadata from the Hugging Face Hub and expose it as a searchable table.

Importing this module downloads a prebuilt abstract index from the Hub, loads it
with RAGatouille, and runs one warm-up query against it.
"""

import html
import operator

import datasets
import pandas as pd
from huggingface_hub import HfApi
from ragatouille import RAGPretrainedModel

api = HfApi()

INDEX_DIR_PATH = ".ragatouille/colbert/indexes/ICLR2024-papers-abstract-index/"
api.snapshot_download(
    repo_id="ICLR2024/ICLR2024-papers-abstract-index",
    repo_type="dataset",
    local_dir=INDEX_DIR_PATH,
)
ABSTRACT_RETRIEVER = RAGPretrainedModel.from_index(INDEX_DIR_PATH)
# Run once to initialize the retriever
ABSTRACT_RETRIEVER.search("LLM")


class PaperList:
    """In-memory table of ICLR 2024 papers with display formatting and filtering.

    ``df_raw`` holds the merged source data; ``df_prettified`` holds the
    display-ready rows whose columns are listed in ``COLUMN_INFO``.
    """

    # (display column name, datatype) pairs in display order. The datatype strings
    # are presumably consumed by the UI table component — confirm against the caller.
    COLUMN_INFO = [
        ["Title", "str"],
        ["Authors", "str"],
        ["Type", "str"],
        ["Paper page", "markdown"],
        ["OpenReview", "markdown"],
        ["GitHub", "markdown"],
        ["Spaces", "markdown"],
        ["Models", "markdown"],
        ["Datasets", "markdown"],
        ["claimed", "markdown"],
    ]

    def __init__(self):
        self.df_raw = self.get_df()
        self.df_prettified = self.prettify(self.df_raw)

    @staticmethod
    def get_df() -> pd.DataFrame:
        """Load the paper list and left-join the per-paper claimed-author counts.

        Papers absent from the claimed-papers dataset get ``-1`` for
        ``n_authors`` and ``n_linked_authors``. A ``paper_page`` column is added
        holding the Hugging Face paper URL, or ``""`` when there is no arXiv ID.
        """
        df = pd.merge(
            left=datasets.load_dataset("ICLR2024/ICLR2024-papers", split="train").to_pandas(),
            right=datasets.load_dataset("ICLR2024/ICLR2024-num-claimed-papers", split="train").to_pandas(),
            on="id",
            how="left",
        )
        # -1 is a sentinel for "no data" so both columns can be stored as plain ints.
        df[["n_authors", "n_linked_authors"]] = df[["n_authors", "n_linked_authors"]].fillna(-1).astype(int)
        # pd.notna guards against NaN missing values, which are truthy and would
        # otherwise produce ".../papers/nan".
        df["paper_page"] = df["arxiv_id"].apply(
            lambda arxiv_id: f"https://huggingface.co/papers/{arxiv_id}" if pd.notna(arxiv_id) and arxiv_id else ""
        )
        return df

    @staticmethod
    def create_link(text: str, url: str) -> str:
        """Return an HTML anchor labelled ``text`` that opens ``url`` in a new tab.

        Returns ``""`` when ``url`` is missing or empty, so absent resources render
        as blank cells rather than dead links. Both values are HTML-escaped.
        """
        if not isinstance(url, str) or not url:
            return ""
        return f'<a href="{html.escape(url)}" target="_blank">{html.escape(str(text))}</a>'

    @staticmethod
    def prettify(df: pd.DataFrame) -> pd.DataFrame:
        """Convert raw rows into display rows with the columns named in ``COLUMN_INFO``.

        Resource-list columns (GitHub, Spaces, Models, Datasets) become one link
        per entry, newline-separated. ``claimed`` shows ``"<linked>/<total>"``
        followed by a check mark when at least one author is linked, and is blank
        when the counts are unknown (``-1``).
        """
        rows = []
        for _, row in df.iterrows():
            if row.n_linked_authors == -1:
                claimed_paper = ""
            else:
                author_linked = "✅" if row.n_linked_authors > 0 else ""
                n_authors = "" if row.n_authors == -1 else row.n_authors
                claimed_paper = f"{row.n_linked_authors}/{n_authors} {author_linked}"
            new_row = {
                "Title": row["title"],
                "Authors": ", ".join(row["authors"]),
                "Type": row["type"],
                "Paper page": PaperList.create_link(row["arxiv_id"], row["paper_page"]),
                "OpenReview": PaperList.create_link("OpenReview", row["OpenReview"]),
                "GitHub": "\n".join([PaperList.create_link("GitHub", url) for url in row["GitHub"]]),
                "Spaces": "\n".join(
                    [
                        PaperList.create_link(repo_id, f"https://huggingface.co/spaces/{repo_id}")
                        for repo_id in row["Space"]
                    ]
                ),
                "Models": "\n".join(
                    [PaperList.create_link(repo_id, f"https://huggingface.co/{repo_id}") for repo_id in row["Model"]]
                ),
                "Datasets": "\n".join(
                    [
                        PaperList.create_link(repo_id, f"https://huggingface.co/datasets/{repo_id}")
                        for repo_id in row["Dataset"]
                    ]
                ),
                "claimed": claimed_paper,
            }
            rows.append(new_row)
        return pd.DataFrame(rows, columns=PaperList.get_column_names())

    @staticmethod
    def get_column_names():
        """Return the display column names from ``COLUMN_INFO``, in order."""
        return list(map(operator.itemgetter(0), PaperList.COLUMN_INFO))

    def get_column_datatypes(self, column_names: list[str]) -> list[str]:
        """Return the ``COLUMN_INFO`` datatype for each name; raises KeyError on unknown names."""
        mapping = dict(self.COLUMN_INFO)
        return [mapping[name] for name in column_names]

    def search(
        self,
        title_search_query: str,
        abstract_search_query: str,
        max_num_to_retrieve: int,
        filter_names: list[str],
        presentation_type: str,
        columns_names: list[str],
    ) -> pd.DataFrame:
        """Filter the paper table and return the requested display columns.

        Args:
            title_search_query: Case-insensitive substring, matched literally
                (not as a regex) against paper titles.
            abstract_search_query: If non-empty, keep only papers returned by the
                abstract retriever for this query, ordered by retriever rank.
            max_num_to_retrieve: ``k`` passed to the abstract retriever.
            filter_names: Keep only papers that have each named resource:
                "Paper page", "GitHub", "Space", "Model", "Dataset".
            presentation_type: Keep only papers of this ``type``; "(ALL)"
                disables the filter.
            columns_names: Display columns to return, in order.

        Returns:
            The prettified matching rows, restricted to ``columns_names``.
        """
        df = self.df_raw.copy()
        # As ragatouille uses str for document_id
        df["id"] = df["id"].astype(str)

        # Filter by title. regex=False so queries such as "C++" or "(" match
        # literally instead of raising re.error; na=False drops untitled rows.
        df = df[df["title"].str.contains(title_search_query, case=False, regex=False, na=False)]
        # Filter by presentation type
        if presentation_type != "(ALL)":
            df = df[df["type"] == presentation_type]

        # paper_page is "" (never null) when there is no arXiv ID, so test for emptiness.
        if "Paper page" in filter_names:
            df = df[df["paper_page"] != ""]
        # These filter names coincide with list-valued columns; keep non-empty lists.
        for resource in ("GitHub", "Space", "Model", "Dataset"):
            if resource in filter_names:
                df = df[df[resource].apply(len) > 0]

        # Filter by abstract
        if abstract_search_query:
            results = ABSTRACT_RETRIEVER.search(abstract_search_query, k=max_num_to_retrieve)
            remaining_ids = set(df["id"])
            # Keep retriever order, skip papers already filtered out, and drop
            # duplicate hits (dict.fromkeys preserves first-seen order).
            found_ids = list(
                dict.fromkeys(x["document_id"] for x in results if x["document_id"] in remaining_ids)
            )
            df = df[df["id"].isin(found_ids)].set_index("id").reindex(index=found_ids).reset_index()

        df_prettified = self.prettify(df)
        return df_prettified.loc[:, columns_names]