File size: 6,049 Bytes
25c0a98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4307381
 
25c0a98
893a87f
25c0a98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4307381
25c0a98
 
 
4307381
 
25c0a98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4307381
 
25c0a98
 
 
 
 
 
b77561a
 
 
4307381
 
25c0a98
 
b00b9f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25c0a98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96c3197
25c0a98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import operator

import datasets
import pandas as pd
from huggingface_hub import HfApi
from ragatouille import RAGPretrainedModel

# Hugging Face Hub client, used below to fetch the prebuilt abstract search index.
api = HfApi()

# Local directory the index snapshot is downloaded into and then loaded from.
# NOTE(review): the ".ragatouille/colbert/indexes/" prefix presumably mirrors
# RAGatouille's default index location — confirm before changing it.
INDEX_DIR_PATH = ".ragatouille/colbert/indexes/ICLR2024-papers-abstract-index/"
# Download the index (stored on the Hub as a dataset repo) when this module is
# first imported; this is network I/O performed at import time.
api.snapshot_download(
    repo_id="ICLR2024/ICLR2024-papers-abstract-index",
    repo_type="dataset",
    local_dir=INDEX_DIR_PATH,
)
# Retriever over paper abstracts, queried by PaperList.search; each hit carries a
# string "document_id" that is matched against the papers' "id" column.
ABSTRACT_RETRIEVER = RAGPretrainedModel.from_index(INDEX_DIR_PATH)
# Run one throwaway query at import time to initialize the retriever (the result
# is discarded). Presumably this moves one-time warm-up cost out of the first
# user search — TODO confirm.
ABSTRACT_RETRIEVER.search("LLM")


class PaperList:
    """Table of ICLR 2024 papers joined with their Hugging Face paper-page stats.

    ``df_raw`` holds the merged source rows (see :meth:`get_df`);
    ``df_prettified`` holds the same rows rendered for display (HTML links,
    blanks for missing stats), with exactly the columns listed in ``COLUMN_INFO``.
    """

    # (display column name, datatype string) pairs, in display order.
    # NOTE(review): the two emoji headers appear mis-encoded (UTF-8 emoji bytes
    # decoded as cp1253), but callers select columns by these exact names, so
    # they are intentionally kept byte-for-byte.
    COLUMN_INFO = [
        ["Title", "str"],
        ["Authors", "str"],
        ["Type", "str"],
        ["Paper page", "markdown"],
        ["πŸ‘", "number"],
        ["πŸ’¬", "number"],
        ["OpenReview", "markdown"],
        ["Project page", "markdown"],
        ["GitHub", "markdown"],
        ["Spaces", "markdown"],
        ["Models", "markdown"],
        ["Datasets", "markdown"],
        ["claimed", "markdown"],
    ]

    def __init__(self):
        # get_df loads both datasets from the Hub, so construction does network I/O.
        self.df_raw = self.get_df()
        self.df_prettified = self.prettify(self.df_raw)

    @staticmethod
    def get_df() -> pd.DataFrame:
        """Load the paper list and left-join the per-paper stats on ``id``.

        Every paper from "ICLR2024/ICLR2024-papers" is kept. Missing values in the
        integer stat columns (e.g. papers with no stats row) become ``-1``, a
        "missing" sentinel that :meth:`prettify` renders as a blank cell. A
        ``paper_page`` column is added: the Hugging Face paper URL, or ``""``
        when the paper has no arXiv id.
        """
        df = pd.merge(
            left=datasets.load_dataset("ICLR2024/ICLR2024-papers", split="train").to_pandas(),
            right=datasets.load_dataset("ICLR2024/ICLR2024-paper-stats", split="train").to_pandas(),
            on="id",
            how="left",
        )
        keys = ["n_authors", "n_linked_authors", "upvotes", "num_comments"]
        df[keys] = df[keys].fillna(-1).astype(int)
        # pd.notna first: NaN is truthy (would yield ".../papers/nan") and pd.NA
        # raises in a boolean context.
        df["paper_page"] = df["arxiv_id"].apply(
            lambda arxiv_id: (
                f"https://huggingface.co/papers/{arxiv_id}" if pd.notna(arxiv_id) and arxiv_id else ""
            )
        )
        return df

    @staticmethod
    def create_link(text: str, url: str) -> str:
        """Return an HTML anchor for *url* that opens in a new tab.

        *text* and *url* are inserted verbatim (not HTML-escaped).
        """
        return f'<a href="{url}" target="_blank">{text}</a>'

    @staticmethod
    def prettify(df: pd.DataFrame) -> pd.DataFrame:
        """Render raw paper rows (shaped like :meth:`get_df` output) for display.

        ``-1`` stat sentinels become blank cells, URLs become HTML anchors (several
        links in one cell are newline-separated), and link cells are blank when
        there is nothing to link. The result has exactly the columns of
        ``COLUMN_INFO``, in that order, and a fresh 0..n-1 index.
        """
        rows = []
        for _, row in df.iterrows():
            author_linked = "✅" if row.n_linked_authors > 0 else ""
            n_linked_authors = "" if row.n_linked_authors == -1 else row.n_linked_authors
            n_authors = "" if row.n_authors == -1 else row.n_authors
            # "linked/total" author counts, with a check mark when at least one
            # author is linked; blank when the linked count is unknown.
            claimed_paper = "" if n_linked_authors == "" else f"{n_linked_authors}/{n_authors} {author_linked}"
            upvotes = "" if row.upvotes == -1 else row.upvotes
            num_comments = "" if row.num_comments == -1 else row.num_comments

            new_row = {
                "Title": row["title"],
                "Authors": ", ".join(row["authors"]),
                "Type": row["type"],
                # Blank (not an empty anchor) when the paper has no paper page,
                # matching how "Project page" is handled.
                "Paper page": (
                    PaperList.create_link(row["arxiv_id"], row["paper_page"]) if row["paper_page"] else ""
                ),
                "Project page": (
                    PaperList.create_link("Project page", row["project_page"]) if row["project_page"] else ""
                ),
                "πŸ‘": upvotes,
                "πŸ’¬": num_comments,
                "OpenReview": PaperList.create_link("OpenReview", row["OpenReview"]),
                "GitHub": "\n".join([PaperList.create_link("GitHub", url) for url in row["GitHub"]]),
                "Spaces": "\n".join(
                    [
                        PaperList.create_link(repo_id, f"https://huggingface.co/spaces/{repo_id}")
                        for repo_id in row["Space"]
                    ]
                ),
                "Models": "\n".join(
                    [PaperList.create_link(repo_id, f"https://huggingface.co/{repo_id}") for repo_id in row["Model"]]
                ),
                "Datasets": "\n".join(
                    [
                        PaperList.create_link(repo_id, f"https://huggingface.co/datasets/{repo_id}")
                        for repo_id in row["Dataset"]
                    ]
                ),
                "claimed": claimed_paper,
            }
            rows.append(new_row)
        return pd.DataFrame(rows, columns=PaperList.get_column_names())

    @staticmethod
    def get_column_names() -> list[str]:
        """Return the display column names from ``COLUMN_INFO``, in order."""
        return list(map(operator.itemgetter(0), PaperList.COLUMN_INFO))

    def get_column_datatypes(self, column_names: list[str]) -> list[str]:
        """Return the datatype string for each name in *column_names*, in order.

        Raises:
            KeyError: if a name is not listed in ``COLUMN_INFO``.
        """
        mapping = dict(self.COLUMN_INFO)
        return [mapping[name] for name in column_names]

    def search(
        self,
        title_search_query: str,
        abstract_search_query: str,
        max_num_to_retrieve: int,
        filter_names: list[str],
        presentation_type: str,
        columns_names: list[str],
    ) -> pd.DataFrame:
        """Filter the papers and return the requested display columns.

        Args:
            title_search_query: Case-insensitive substring the title must contain,
                matched literally (not as a regex); "" matches every title.
            abstract_search_query: If non-empty, keep only papers returned by the
                abstract retriever for this query, in the retriever's result order.
            max_num_to_retrieve: Number of retriever hits requested (``k``); hits
                are intersected with the other filters afterwards, so fewer rows
                may remain.
            filter_names: Any of "Paper page", "GitHub", "Space", "Model",
                "Dataset"; each keeps only papers that have that kind of link.
            presentation_type: Value the ``type`` column must equal, or "(ALL)".
            columns_names: Display column names (from ``COLUMN_INFO``) to return.

        Returns:
            The prettified matching rows, restricted to ``columns_names``.
        """
        df = self.df_raw.copy()
        # As ragatouille uses str for document_id
        df["id"] = df["id"].astype(str)

        # The query is free-form user input: match it literally so characters
        # like "+" or "(" neither raise re.error nor change the match.
        # na=False drops rows with a missing title instead of breaking the mask.
        df = df[df["title"].str.contains(title_search_query, case=False, regex=False, na=False)]

        # Filter by presentation type
        if presentation_type != "(ALL)":
            df = df[df["type"] == presentation_type]

        if "Paper page" in filter_names:
            df = df[df["paper_page"] != ""]
        # Each artifact filter keeps only papers with at least one such link.
        for column in ("GitHub", "Space", "Model", "Dataset"):
            if column in filter_names:
                df = df[df[column].apply(len) > 0]

        # Filter by abstract
        if abstract_search_query:
            results = ABSTRACT_RETRIEVER.search(abstract_search_query, k=max_num_to_retrieve)
            remaining_ids = set(df["id"])
            # Keep hits that survived the filters above, dropping duplicates while
            # preserving the retriever's result order (dicts keep insertion order).
            found_ids = list(
                dict.fromkeys(x["document_id"] for x in results if x["document_id"] in remaining_ids)
            )
            df = df[df["id"].isin(found_ids)].set_index("id").reindex(index=found_ids).reset_index()

        df_prettified = self.prettify(df)
        return df_prettified.loc[:, columns_names]