#!/usr/bin/env python
"""Gradio app for browsing and searching ICLR 2024 papers."""

import gradio as gr
import pandas as pd

from papers import PaperList

DESCRIPTION = "# ICLR 2024 Papers"

paper_list = PaperList()

# Columns selected in the "Columns" checkbox group when the page first loads.
DEFAULT_COLUMNS = [
    "Title",
    "Type",
    "Paper page",
    "OpenReview",
    "GitHub",
    "Spaces",
    "Models",
    "Datasets",
    "claimed",
]


def update_num_papers(df: pd.DataFrame) -> str:
    """Return a summary of how many papers are currently displayed.

    The result is ``"<shown> / <total>"``, where ``<total>`` is the number of
    rows in ``paper_list.df_raw``. When the ``claimed`` column is visible,
    ``" (<n> claimed)"`` is appended, counting rows whose ``claimed`` cell
    contains "✅".

    Args:
        df: The table currently shown in the UI.

    Returns:
        The summary string displayed in the "Number of papers" textbox.
    """
    total = len(paper_list.df_raw)
    if "claimed" not in df.columns:
        return f"{len(df)} / {total}"
    # Cast to str first so empty/missing cells (None/NaN) simply don't match,
    # instead of putting NaN into the mask (which makes boolean indexing
    # raise) or breaking the `.str` accessor on a non-string column.
    num_claimed = int(df["claimed"].astype(str).str.contains("✅", regex=False).sum())
    return f"{len(df)} / {total} ({num_claimed} claimed)"


def update_df(
    title_search_query: str,
    abstract_search_query: str,
    max_num_to_retrieve: int,
    filter_names: list[str],
    presentation_type: str,
    column_names: list[str],
) -> gr.DataFrame:
    """Run a search and return an updated results table component.

    All arguments are forwarded positionally to ``PaperList.search``. The
    returned component also carries the datatypes for the selected
    ``column_names`` so the rendered table matches the chosen columns.

    Returns:
        A ``gr.DataFrame`` update holding the search results.
    """
    results = paper_list.search(
        title_search_query,
        abstract_search_query,
        max_num_to_retrieve,
        filter_names,
        presentation_type,
        column_names,
    )
    return gr.DataFrame(
        value=results,
        datatype=paper_list.get_column_datatypes(column_names),
    )


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        search_title = gr.Textbox(label="Search title")
        with gr.Row():
            with gr.Column(scale=4):
                search_abstract = gr.Textbox(
                    label="Search abstract",
                    info="The result may not be accurate as the abstract does not contain all the information.",
                )
            with gr.Column(scale=1):
                max_num_to_retrieve = gr.Slider(
                    label="Max number to retrieve",
                    info="This is used only for search on abstracts.",
                    minimum=1,
                    maximum=len(paper_list.df_raw),
                    step=1,
                    # Clamp so the default never exceeds the slider's maximum.
                    value=min(100, len(paper_list.df_raw)),
                )
        filter_names = gr.CheckboxGroup(
            label="Filter",
            choices=[
                "Paper page",
                "GitHub",
                "Space",
                "Model",
                "Dataset",
            ],
        )
        presentation_type = gr.Radio(
            label="Presentation Type",
            choices=["(ALL)", "Oral", "Spotlight Poster", "Poster"],
            value="(ALL)",
        )
        column_names = gr.CheckboxGroup(
            label="Columns",
            choices=paper_list.get_column_names(),
            value=DEFAULT_COLUMNS,
        )

    num_papers = gr.Textbox(
        label="Number of papers",
        value=update_num_papers(paper_list.df_prettified),
        interactive=False,
    )
    df = gr.Dataframe(
        value=paper_list.df_prettified,
        datatype=paper_list.get_column_datatypes(paper_list.get_column_names()),
        type="pandas",
        row_count=(0, "dynamic"),
        interactive=False,
        height=1000,
        elem_id="table",
        wrap=True,
    )

    # Order must match the positional parameters of `update_df`.
    inputs = [
        search_title,
        search_abstract,
        max_num_to_retrieve,
        filter_names,
        presentation_type,
        column_names,
    ]

    # Re-run the search when any of these controls fire. The slider is not a
    # trigger; its current value is read on the next search.
    # After the table updates, recount the displayed papers.
    gr.on(
        triggers=[
            search_title.submit,
            search_abstract.submit,
            filter_names.input,
            presentation_type.input,
            column_names.input,
        ],
        fn=update_df,
        inputs=inputs,
        outputs=df,
        api_name=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
        api_name=False,
    )

    # Same search + recount chain on page load, so the table reflects the
    # default control values (e.g. DEFAULT_COLUMNS).
    demo.load(
        fn=update_df,
        inputs=inputs,
        outputs=df,
        api_name=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
        api_name=False,
    )

if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False)