import operator

import datasets
import pandas as pd
from huggingface_hub import HfApi
from ragatouille import RAGPretrainedModel

api = HfApi()
# Download the prebuilt ColBERT index of paper abstracts from the Hub
INDEX_DIR_PATH = ".ragatouille/colbert/indexes/ICLR2024-papers-abstract-index/"
api.snapshot_download(
    repo_id="ICLR2024/ICLR2024-papers-abstract-index",
    repo_type="dataset",
    local_dir=INDEX_DIR_PATH,
)
ABSTRACT_RETRIEVER = RAGPretrainedModel.from_index(INDEX_DIR_PATH)
# Run a dummy query once so the retriever is warm before the first real search
ABSTRACT_RETRIEVER.search("LLM")
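
# Each retriever hit is a dict that includes the "document_id" the index was
# built with (the paper "id" here); PaperList.search below relies only on that
# field, whatever else a given ragatouille version returns per hit.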

class PaperList:
    # (column name, datatype) pairs for the table shown in the app
    COLUMN_INFO = [
        ["Title", "str"],
        ["Authors", "str"],
        ["Type", "str"],
        ["Paper page", "markdown"],
        ["OpenReview", "markdown"],
        ["GitHub", "markdown"],
        ["Spaces", "markdown"],
        ["Models", "markdown"],
        ["Datasets", "markdown"],
        ["claimed", "markdown"],
    ]

    def __init__(self):
        self.df_raw = self.get_df()
        self.df_prettified = self.prettify(self.df_raw)

    @staticmethod
    def get_df() -> pd.DataFrame:
        # Join paper metadata with per-paper author claim counts
        df = pd.merge(
            left=datasets.load_dataset("ICLR2024/ICLR2024-papers", split="train").to_pandas(),
            right=datasets.load_dataset("ICLR2024/ICLR2024-num-claimed-papers", split="train").to_pandas(),
            on="id",
            how="left",
        )
        # -1 marks papers with no claim data (the left join produced NaN)
        df[["n_authors", "n_linked_authors"]] = df[["n_authors", "n_linked_authors"]].fillna(-1).astype(int)
        df["paper_page"] = df["arxiv_id"].apply(
            lambda arxiv_id: f"https://huggingface.co/papers/{arxiv_id}" if arxiv_id else ""
        )
        return df

    @staticmethod
    def create_link(text: str, url: str) -> str:
        return f'<a href="{url}" target="_blank">{text}</a>'

    @staticmethod
    def prettify(df: pd.DataFrame) -> pd.DataFrame:
        # Render the raw columns as markdown/HTML links for display
        rows = []
        for _, row in df.iterrows():
            # "claimed" reads e.g. "2/5 ✅": linked authors out of total authors
            author_linked = "✅" if row.n_linked_authors > 0 else ""
            n_linked_authors = "" if row.n_linked_authors == -1 else row.n_linked_authors
            n_authors = "" if row.n_authors == -1 else row.n_authors
            claimed_paper = "" if n_linked_authors == "" else f"{n_linked_authors}/{n_authors} {author_linked}"
            new_row = {
                "Title": row["title"],
                "Authors": ", ".join(row["authors"]),
                "Type": row["type"],
                "Paper page": PaperList.create_link(row["arxiv_id"], row["paper_page"]),
                "OpenReview": PaperList.create_link("OpenReview", row["OpenReview"]),
                "GitHub": "\n".join([PaperList.create_link("GitHub", url) for url in row["GitHub"]]),
                "Spaces": "\n".join(
                    [
                        PaperList.create_link(repo_id, f"https://huggingface.co/spaces/{repo_id}")
                        for repo_id in row["Space"]
                    ]
                ),
                "Models": "\n".join(
                    [PaperList.create_link(repo_id, f"https://huggingface.co/{repo_id}") for repo_id in row["Model"]]
                ),
                "Datasets": "\n".join(
                    [
                        PaperList.create_link(repo_id, f"https://huggingface.co/datasets/{repo_id}")
                        for repo_id in row["Dataset"]
                    ]
                ),
                "claimed": claimed_paper,
            }
            rows.append(new_row)
        return pd.DataFrame(rows, columns=PaperList.get_column_names())

    @staticmethod
    def get_column_names():
        return list(map(operator.itemgetter(0), PaperList.COLUMN_INFO))

    def get_column_datatypes(self, column_names: list[str]) -> list[str]:
        mapping = dict(self.COLUMN_INFO)
        return [mapping[name] for name in column_names]

    def search(
        self,
        title_search_query: str,
        abstract_search_query: str,
        max_num_to_retrieve: int,
        filter_names: list[str],
        presentation_type: str,
        columns_names: list[str],
    ) -> pd.DataFrame:
        df = self.df_raw.copy()
        # ragatouille stores document IDs as str
        df["id"] = df["id"].astype(str)

        # Filter by title; regex=False so user input is matched literally
        df = df[df["title"].str.contains(title_search_query, case=False, regex=False)]

        # Filter by presentation type
        if presentation_type != "(ALL)":
            df = df[df["type"] == presentation_type]
if "Paper page" in filter_names:
df = df[df["paper_page"].notnull()]
if "GitHub" in filter_names:
df = df[df["GitHub"].apply(len) > 0]
if "Space" in filter_names:
df = df[df["Space"].apply(len) > 0]
if "Model" in filter_names:
df = df[df["Model"].apply(len) > 0]
if "Dataset" in filter_names:
df = df[df["Dataset"].apply(len) > 0]

        # Filter by abstract: keep only retrieved papers, deduplicated and
        # reordered by retrieval rank
        if abstract_search_query:
            results = ABSTRACT_RETRIEVER.search(abstract_search_query, k=max_num_to_retrieve)
            remaining_ids = set(df["id"])
            found_id_set = set()
            found_ids = []
            for x in results:
                paper_id = x["document_id"]
                if paper_id not in remaining_ids:
                    continue
                if paper_id in found_id_set:
                    continue
                found_id_set.add(paper_id)
                found_ids.append(paper_id)
            df = df[df["id"].isin(found_ids)].set_index("id").reindex(index=found_ids).reset_index()

        df_prettified = self.prettify(df)
        return df_prettified.loc[:, columns_names]
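

# Minimal usage sketch, for illustration only: the query, filters, and column
# selection below are hypothetical values, not part of the app itself.
if __name__ == "__main__":
    paper_list = PaperList()
    results_df = paper_list.search(
        title_search_query="",
        abstract_search_query="efficient attention",
        max_num_to_retrieve=10,
        filter_names=["GitHub"],
        presentation_type="(ALL)",
        columns_names=PaperList.get_column_names(),
    )
    print(results_df.head())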