Spaces: Running on CPU Upgrade
#!/usr/bin/env python | |
import gradio as gr | |
import pandas as pd | |
from papers import PaperList | |
# Paper catalogue shared by the callbacks below and by the initial table render.
paper_list = PaperList()
# Markdown heading rendered at the top of the page.
DESCRIPTION = "# ICLR 2024 Papers"
# Columns pre-selected in the "Columns" checkbox group, and therefore the
# columns the table shows after the page's initial load.
DEFAULT_COLUMNS = ["Title", "Type", "Paper page", "OpenReview", "GitHub", "Spaces"]
def update_num_papers(df: pd.DataFrame) -> str:
    """Return the text shown in the "Number of papers" box.

    The result has the form ``"<shown> / <total>"``, where ``<shown>`` is the
    number of rows in ``df`` and ``<total>`` is the number of rows in
    ``paper_list.df_raw``. When ``df`` has a ``claimed`` column, the number of
    rows whose ``claimed`` cell contains "✅" is appended as
    ``" (<n> claimed)"``.

    Args:
        df: The table currently displayed (possibly filtered and/or limited
            to a subset of columns).

    Returns:
        The human-readable paper-count summary.
    """
    summary = f"{len(df)} / {len(paper_list.df_raw)}"
    if "claimed" not in df.columns:
        return summary
    # astype(str) turns missing / non-string cells into plain text, so they
    # count as "not claimed" instead of yielding NaN in the mask; regex=False
    # matches the emoji literally.
    is_claimed = df["claimed"].astype(str).str.contains("✅", regex=False)
    return f"{summary} ({int(is_claimed.sum())} claimed)"
def update_df(
    title_search_query: str,
    abstract_search_query: str,
    max_num_to_retrieve: int,
    filter_names: list[str],
    presentation_type: str,
    column_names: list[str],
) -> gr.Dataframe:
    """Run a paper search and return it as a Dataframe component update.

    All arguments are forwarded positionally, in this order, to
    ``paper_list.search``. The same ``column_names`` are used to look up the
    per-column datatypes, so the rendered table matches the selected columns.

    Args:
        title_search_query: Text from the "Search title" box.
        abstract_search_query: Text from the "Search abstract" box.
        max_num_to_retrieve: Value of the "Max number to retrieve" slider
            (its UI hint says it only applies to abstract search).
        filter_names: Options ticked in the "Filter" checkbox group.
        presentation_type: Selected presentation type, e.g. "(ALL)".
        column_names: Columns to display.

    Returns:
        A ``gr.Dataframe`` carrying the new table value and its datatypes.
    """
    results = paper_list.search(
        title_search_query,
        abstract_search_query,
        max_num_to_retrieve,
        filter_names,
        presentation_type,
        column_names,
    )
    return gr.Dataframe(
        value=results,
        datatype=paper_list.get_column_datatypes(column_names),
    )
# Upper bound for the retrieval slider. Clamped to at least 1 so the slider
# stays valid (minimum <= maximum) even if the paper list is empty.
_max_retrievable = max(len(paper_list.df_raw), 1)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        search_title = gr.Textbox(label="Search title")
        with gr.Row():
            with gr.Column(scale=4):
                search_abstract = gr.Textbox(
                    label="Search abstract",
                    info="The result may not be accurate as the abstract does not contain all the information.",
                )
            with gr.Column(scale=1):
                max_num_to_retrieve = gr.Slider(
                    label="Max number to retrieve",
                    info="This is used only for search on abstracts.",
                    minimum=1,
                    maximum=_max_retrievable,
                    step=1,
                    # Default of 100, but never above the slider's maximum.
                    value=min(100, _max_retrievable),
                )
    filter_names = gr.CheckboxGroup(
        label="Filter",
        choices=[
            "Paper page",
            "GitHub",
            "Space",
            "Model",
            "Dataset",
        ],
    )
    presentation_type = gr.Radio(
        label="Presentation Type",
        choices=["(ALL)", "Oral", "Spotlight Poster", "Poster"],
        value="(ALL)",
    )
    column_names = gr.CheckboxGroup(label="Columns", choices=paper_list.get_column_names(), value=DEFAULT_COLUMNS)
    num_papers = gr.Textbox(
        label="Number of papers", value=update_num_papers(paper_list.df_prettified), interactive=False
    )
    df = gr.Dataframe(
        value=paper_list.df_prettified,
        datatype=paper_list.get_column_datatypes(paper_list.get_column_names()),
        type="pandas",
        row_count=(0, "dynamic"),
        interactive=False,
        height=1000,
        elem_id="table",
        wrap=True,
    )

    # Argument order must match update_df's parameter order.
    inputs = [
        search_title,
        search_abstract,
        max_num_to_retrieve,
        filter_names,
        presentation_type,
        column_names,
    ]

    def _refresh_num_papers(event):
        """Chain a recount of the displayed rows after `event` updates the table."""
        return event.then(
            fn=update_num_papers,
            inputs=df,
            outputs=num_papers,
            queue=False,
            api_name=False,
        )

    # Re-run the search when a query is submitted or a filter/column changes.
    # The slider is an input but not a trigger: it takes effect on next search.
    _refresh_num_papers(
        gr.on(
            triggers=[
                search_title.submit,
                search_abstract.submit,
                filter_names.input,
                presentation_type.input,
                column_names.input,
            ],
            fn=update_df,
            inputs=inputs,
            outputs=df,
            api_name=False,
        )
    )
    # Apply the default inputs (e.g. DEFAULT_COLUMNS) once on page load.
    _refresh_num_papers(
        demo.load(
            fn=update_df,
            inputs=inputs,
            outputs=df,
            api_name=False,
        )
    )
if __name__ == "__main__":
    # Enable the request queue with api_open=False, then launch the app with
    # show_api=False.
    queued_demo = demo.queue(api_open=False)
    queued_demo.launch(show_api=False)