File size: 3,794 Bytes
15ee8f1
 
 
25c0a98
15ee8f1
25c0a98
 
 
 
 
 
 
 
 
 
 
 
 
 
9a4b257
25c0a98
 
 
 
207e179
 
 
 
25c0a98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15ee8f1
 
25c0a98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#!/usr/bin/env python

import gradio as gr
import pandas as pd

from papers import PaperList

# Markdown heading rendered at the top of the page.
DESCRIPTION = "# ICLR 2024 Papers"

# Columns pre-selected in the "Columns" checkbox group.
DEFAULT_COLUMNS = ["Title", "Type", "Paper page", "OpenReview", "GitHub", "Spaces", "Models"]

# Single shared paper table used by the callbacks and the UI below.
paper_list = PaperList()


def update_num_papers(df: pd.DataFrame) -> str:
    """Summarize how many papers the table currently shows.

    Args:
        df: The table currently displayed.

    Returns:
        ``"<shown> / <total>"`` where ``<total>`` is the number of rows in
        ``paper_list.df_raw``. When ``df`` has a ``claimed`` column, the text
        ``" (<n> claimed)"`` is appended, with ``<n>`` the number of rows whose
        ``claimed`` cell contains ``✅``.
    """
    summary = f"{len(df)} / {len(paper_list.df_raw)}"
    if "claimed" not in df.columns:
        return summary
    # Missing or non-string cells would make ``str.contains`` return NaN, which
    # cannot be used as a boolean mask; treat such cells as "not claimed".
    claimed_mask = df["claimed"].fillna("").astype(str).str.contains("✅", regex=False)
    return f"{summary} ({int(claimed_mask.sum())} claimed)"


def update_df(
    title_search_query: str,
    abstract_search_query: str,
    max_num_to_retrieve: int,
    filter_names: list[str],
    presentation_type: str,
    column_names: list[str],
) -> "gr.DataFrame":
    """Re-run the paper search and build the updated table component.

    All arguments are forwarded, in order, to ``paper_list.search``.

    Args:
        title_search_query: Query typed into the "Search title" box.
        abstract_search_query: Query typed into the "Search abstract" box.
        max_num_to_retrieve: Value of the "Max number to retrieve" slider.
        filter_names: Entries selected in the "Filter" checkbox group.
        presentation_type: Entry selected in the "Presentation Type" radio.
        column_names: Columns selected for display.

    Returns:
        A ``gr.DataFrame`` whose value is the search result and whose column
        datatypes are those reported for ``column_names``.
    """
    matches = paper_list.search(
        title_search_query,
        abstract_search_query,
        max_num_to_retrieve,
        filter_names,
        presentation_type,
        column_names,
    )
    # Returning a component (not a bare DataFrame) lets the column datatypes
    # follow the user's current column selection.
    return gr.DataFrame(
        value=matches,
        datatype=paper_list.get_column_datatypes(column_names),
    )


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    # ---- Search / filter controls ------------------------------------------
    with gr.Group():
        search_title = gr.Textbox(label="Search title")
        with gr.Row():
            with gr.Column(scale=4):
                search_abstract = gr.Textbox(
                    label="Search abstract",
                    info="The result may not be accurate as the abstract does not contain all the information.",
                )
            with gr.Column(scale=1):
                max_num_to_retrieve = gr.Slider(
                    label="Max number to retrieve",
                    info="This is used only for search on abstracts.",
                    minimum=1,
                    maximum=len(paper_list.df_raw),
                    step=1,
                    value=100,
                )

        filter_names = gr.CheckboxGroup(
            label="Filter",
            choices=["Paper page", "GitHub", "Space", "Model", "Dataset"],
        )
        presentation_type = gr.Radio(
            label="Presentation Type",
            choices=["(ALL)", "Oral", "Spotlight Poster", "Poster"],
            value="(ALL)",
        )
        column_names = gr.CheckboxGroup(
            label="Columns",
            choices=paper_list.get_column_names(),
            value=DEFAULT_COLUMNS,
        )

    # ---- Results -----------------------------------------------------------
    num_papers = gr.Textbox(
        label="Number of papers",
        value=update_num_papers(paper_list.df_prettified),
        interactive=False,
    )
    df = gr.Dataframe(
        value=paper_list.df_prettified,
        datatype=paper_list.get_column_datatypes(paper_list.get_column_names()),
        type="pandas",
        row_count=(0, "dynamic"),
        interactive=False,
        height=1000,
        elem_id="table",
        wrap=True,
    )

    # ---- Wiring ------------------------------------------------------------
    # Order must match update_df's positional parameters.
    inputs = [search_title, search_abstract, max_num_to_retrieve, filter_names, presentation_type, column_names]

    def _then_refresh_count(event):
        """Chain a refresh of the "Number of papers" box after ``event`` updates the table."""
        event.then(
            fn=update_num_papers,
            inputs=df,
            outputs=num_papers,
            queue=False,
            api_name=False,
        )

    _then_refresh_count(
        gr.on(
            triggers=[
                search_title.submit,
                search_abstract.submit,
                filter_names.input,
                presentation_type.input,
                column_names.input,
            ],
            fn=update_df,
            inputs=inputs,
            outputs=df,
            api_name=False,
        )
    )
    # Re-run the search when the page loads so the table reflects the default
    # control values (e.g. the DEFAULT_COLUMNS selection).
    _then_refresh_count(demo.load(fn=update_df, inputs=inputs, outputs=df, api_name=False))

if __name__ == "__main__":
    # Enable the request queue (api_open=False), then serve the app.
    demo.queue(api_open=False)
    demo.launch(show_api=False)