ICLR2024-papers / app.py
hysts's picture
hysts HF staff
Update
3629665
raw
history blame
3.78 kB
#!/usr/bin/env python
import gradio as gr
import pandas as pd
from papers import PaperList
# Markdown heading rendered at the top of the page.
DESCRIPTION = "# ICLR 2024 Papers"

# Table columns selected when the page first loads.
DEFAULT_COLUMNS = ["Title", "Type", "Paper page", "OpenReview", "GitHub", "Spaces"]

# Paper index shared by the search callback and the UI definition below.
paper_list = PaperList()
def update_num_papers(df: pd.DataFrame) -> str:
    """Summarize how many papers the table currently shows.

    Args:
        df: The table being displayed (the search result, or the full
            prettified table on first render).

    Returns:
        ``"<shown> / <total>"``, where total is ``len(paper_list.df_raw)``.
        When ``df`` has a ``claimed`` column, `` (<k> claimed)`` is appended,
        k being the number of rows whose ``claimed`` cell contains the
        U+2705 check-mark emoji.
    """
    summary = f"{len(df)} / {len(paper_list.df_raw)}"
    if "claimed" not in df.columns:
        return summary
    # Spell the check mark (U+2705) as an escape: the raw literal had been
    # garbled by a wrong text encoding and could never match any cell.
    # astype(str) keeps NaN / non-string cells from breaking `.str` or the
    # boolean mask; regex=False makes this a plain substring test.
    claimed_mask = df["claimed"].astype(str).str.contains("\u2705", regex=False)
    num_claimed = int(claimed_mask.sum())
    return f"{summary} ({num_claimed} claimed)"
def update_df(
    title_search_query: str,
    abstract_search_query: str,
    max_num_to_retrieve: int,
    filter_names: list[str],
    presentation_type: str,
    column_names: list[str],
) -> "gr.DataFrame":
    """Search ``paper_list`` and return an update for the results table.

    Parameters are passed positionally by Gradio in the order of the UI's
    ``inputs`` list.

    Args:
        title_search_query: Text from the "Search title" box.
        abstract_search_query: Text from the "Search abstract" box.
        max_num_to_retrieve: Slider value; per the UI it only affects the
            abstract search.
        filter_names: Checked entries of the "Filter" group (e.g. "GitHub").
        presentation_type: "(ALL)" or one specific presentation type.
        column_names: Columns the user chose to display.

    Returns:
        A ``gr.DataFrame`` whose value is the search result and whose
        ``datatype`` matches the selected columns (not a ``pd.DataFrame``).
    """
    result = paper_list.search(
        title_search_query,
        abstract_search_query,
        max_num_to_retrieve,
        filter_names,
        presentation_type,
        column_names,
    )
    return gr.DataFrame(
        value=result,
        datatype=paper_list.get_column_datatypes(column_names),
    )
# Page layout and event wiring. Components render in the order they are created.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    # Search / filter controls.
    # NOTE(review): indentation was lost in the copy this was restored from;
    # the controls down to `column_names` are assumed to sit inside this
    # gr.Group. That affects layout only; confirm against upstream.
    with gr.Group():
        search_title = gr.Textbox(label="Search title")
        with gr.Row():
            with gr.Column(scale=4):
                search_abstract = gr.Textbox(
                    label="Search abstract",
                    info="The result may not be accurate as the abstract does not contain all the information.",
                )
            with gr.Column(scale=1):
                # Upper bound is the total number of papers.
                # This slider is not an event trigger itself: a new value only
                # takes effect on the next search/filter event.
                max_num_to_retrieve = gr.Slider(
                    label="Max number to retrieve",
                    info="This is used only for search on abstracts.",
                    minimum=1,
                    maximum=len(paper_list.df_raw),
                    step=1,
                    value=100,
                )
        # Filter semantics are implemented in PaperList.search (not visible
        # here); presumably each checked item keeps only papers that have it.
        filter_names = gr.CheckboxGroup(
            label="Filter",
            choices=[
                "Paper page",
                "GitHub",
                "Space",
                "Model",
                "Dataset",
            ],
        )
        presentation_type = gr.Radio(
            label="Presentation Type",
            choices=["(ALL)", "Oral", "Spotlight Poster", "Poster"],
            value="(ALL)",
        )
        # Columns to display; starts at DEFAULT_COLUMNS.
        column_names = gr.CheckboxGroup(label="Columns", choices=paper_list.get_column_names(), value=DEFAULT_COLUMNS)
    # Read-only counter, seeded from the full prettified table and refreshed
    # after every search by the `.then(update_num_papers, ...)` steps below.
    num_papers = gr.Textbox(
        label="Number of papers", value=update_num_papers(paper_list.df_prettified), interactive=False
    )
    # Results table. It is seeded with the full prettified table; demo.load
    # below immediately re-runs the search with the initial control values.
    df = gr.Dataframe(
        value=paper_list.df_prettified,
        datatype=paper_list.get_column_datatypes(paper_list.get_column_names()),
        type="pandas",
        row_count=(0, "dynamic"),
        interactive=False,
        height=1000,
        elem_id="table",
        wrap=True,
    )

    # Order must match the positional parameters of update_df.
    inputs = [
        search_title,
        search_abstract,
        max_num_to_retrieve,
        filter_names,
        presentation_type,
        column_names,
    ]
    # Re-run the search when a search box is submitted or a filter / column
    # selection changes, then recompute the counter from the new table.
    gr.on(
        triggers=[
            search_title.submit,
            search_abstract.submit,
            filter_names.input,
            presentation_type.input,
            column_names.input,
        ],
        fn=update_df,
        inputs=inputs,
        outputs=df,
        api_name=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
        api_name=False,
    )
    # On page load, apply the initial control values (e.g. DEFAULT_COLUMNS)
    # to the table, then recompute the counter.
    demo.load(
        fn=update_df,
        inputs=inputs,
        outputs=df,
        api_name=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
        api_name=False,
    )
if __name__ == "__main__":
    # Turn on request queuing with its REST endpoint closed, then serve the
    # app with the API docs page hidden.
    app = demo.queue(api_open=False)
    app.launch(show_api=False)