"""Gradio leaderboard for the survey "Structured Pruning for Deep Convolutional
Neural Networks: A Survey" (IEEE T-PAMI).

The app shows an About tab, a User Guide tab and one tab per dataset
(CIFAR-10, CIFAR-100, ImageNet-1K). Each dataset tab offers text/numeric
filters, two interactive Altair scatter plots and a results table whose rows
can be clicked to show paper details and BibTeX.
"""
import os
import pickle
import textwrap

import altair as alt
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import yaml

from src.assets.awesome_mapping import paper_mapping, section_mapping, bibtex_mapping, venue_mapping, citation_key_mapping
from src.assets.css_html_js import custom_css

TITLE = "🔥CNN Structured Pruning Leaderboard"

PAPER_LINK = 'https://arxiv.org/abs/2303.00566'
PAPER_LINK_IEEE = 'https://ieeexplore.ieee.org/document/10330640'
AWESOME_PRUNING_LINK = 'https://github.com/he-y/Awesome-Pruning'

BIBTEX = '''
@article{he2023structured,
  author={He, Yang and Xiao, Lingao},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  title={Structured Pruning for Deep Convolutional Neural Networks: A Survey},
  year={2023},
  volume={},
  number={},
  pages={1-20},
  doi={10.1109/TPAMI.2023.3334614}}
'''

INTRO = f"""
Welcome to our dedicated site for the survey paper: "[Structured Pruning for Deep Convolutional Neural Networks: A Survey]({PAPER_LINK})".

Our survey is accepted by IEEE T-PAMI. Links include [arXiv]({PAPER_LINK}) and [IEEE Xplore]({PAPER_LINK_IEEE}).

Github Repo: [Awesome Pruning: A curated list of neural network pruning resources]({AWESOME_PRUNING_LINK}).

This platform serves as a repository and visual representation of the benchmarks from studies covered in our survey. Here, you can explore the reported accuracy and FLOPs metrics from various papers, providing an at-a-glance view of the advancements and methodologies in the domain of structured pruning.

If you find this website helpful, please consider citing our paper 😊
"""

COLS_KEEP = ['sec', 'year', 'method', 'model', 'acc', 'acc-pruned', 'acc-change', 'flops-pruned', 'flops-drop',
             'param-pruned', 'param-drop', 'dataset']
COLS = ['sec', 'year', 'method', 'model', 'acc', 'acc-pruned', 'acc-change', 'flops', 'flops-pruned', 'flops-drop',
        'param', 'param-pruned', 'param-drop', 'dataset']

MISC_GROUP = ['model', 'dataset', 'method', 'year', 'sec']
ACC_GROUP = ['acc', 'acc-pruned', 'acc-change']
FLOPS_GROUP = ['flops', 'flops-pruned', 'flops-drop']
PARAM_GROUP = ['param', 'param-pruned', 'param-drop']

# Mapping from raw column names to the headers shown in the UI.
CUSTOM_HEADER_MAP = {
    'sec': 'Section',
    'year': 'Year',
    'method': 'Method',
    'model': 'Model',
    'acc': 'Acc',
    'acc-pruned': 'Acc Pruned',
    'acc-change': 'Acc ↓ (%)',
    'flops': 'FLOPs (M)',
    'flops-pruned': 'FLOPs Pruned (M)',
    'flops-drop': 'FLOPs ↓ (%)',
    'param': 'Params (M)',
    'param-pruned': 'Params Pruned (M)',
    'param-drop': 'Params ↓ (%)',
    'dataset': 'Dataset',
}
# Make the map bidirectional: display header -> raw column name as well.
CUSTOM_HEADER_MAP.update({v: k for k, v in CUSTOM_HEADER_MAP.items()})

# Shown wherever a paper attribute is missing from the awesome-pruning record.
NOT_RECORDED = "Not Recorded 😭"
SEARCH_PLACEHOLDER = "[press enter to search]"
DATASET_LABELS = {
    'cifar10': 'CIFAR-10',
    'cifar100': 'CIFAR-100',
    'imagenet': 'ImageNet-1K',
}


def _load_pickle(path):
    """Return the object pickled in the file at *path*, always closing the file.

    NOTE: pickle can execute arbitrary code while loading; only use this on the
    bundled, trusted assets — never on user-supplied files.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


df = _load_pickle("src/assets/data.pkl")
baseline = _load_pickle("src/assets/baseline.pkl")


def filter_table_combined(leaderboard, search_box, search_box_method, search_box_year, search_box_section,
                          acc_base_box, acc_box, acc_change, flops_base_box, flops_box, flops_drop,
                          param_base_box, param_box, param_drop):
    """Return the rows of *leaderboard* that satisfy every non-empty search box.

    *leaderboard* uses the raw column names ('model', 'acc-pruned', ...).
    Empty or whitespace-only boxes are ignored.

    - Model / Method / Year: case-insensitive substring match; the query is
      matched literally (not as a regular expression).
    - Section: a query that does not start with "2" gets "2." prepended
      (e.g. "3" -> "2.3"); queries shorter than 5 characters are matched as a
      prefix, longer ones must match exactly (case-insensitive).
    - Numeric boxes: rows whose value is missing or non-numeric are dropped.
      Baseline accuracy, accuracy after pruning, FLOPs drop and params drop
      keep values > query; every other numeric box keeps values < query.

    Raises:
        ValueError: if a numeric box contains text that is not a number.
    """
    search_boxes = [search_box, search_box_method, search_box_year, search_box_section,
                    acc_base_box, acc_box, acc_change, flops_base_box, flops_box, flops_drop,
                    param_base_box, param_box, param_drop]
    column_names = ['model', 'method', 'year', 'sec',
                    'acc', 'acc-pruned', 'acc-change', 'flops', 'flops-pruned', 'flops-drop',
                    'param', 'param-pruned', 'param-drop']
    # Numeric columns filtered with ">"; all other numeric columns use "<".
    greater_than_columns = {'acc', 'acc-pruned', 'flops-drop', 'param-drop'}

    filtered_df = leaderboard.copy()
    for idx, (query, col_name) in enumerate(zip(search_boxes, column_names)):
        query = (query or '').strip()
        if not query:
            continue

        if idx == 3:  # Section
            # The code assumes section ids start with "2."; let users omit it.
            if not query.startswith('2'):
                query = "2." + query
            sections = filtered_df[col_name].astype(str).str.lower()
            if len(query) < 5:
                filtered_df = filtered_df[sections.str.startswith(query.lower())]
            else:
                filtered_df = filtered_df[sections == query.lower()]
        elif idx < 4:  # Model / Method / Year
            # regex=False: user text such as "(" or "+" must not be parsed as a regex.
            matches = filtered_df[col_name].astype(str).str.contains(query, case=False, regex=False)
            filtered_df = filtered_df[matches]
        else:  # Numeric thresholds
            threshold = float(query)
            # '' and other non-numeric cells become NaN; NaN never satisfies
            # the comparison, so those rows are dropped.
            values = pd.to_numeric(filtered_df[col_name], errors='coerce')
            if col_name in greater_than_columns:
                filtered_df = filtered_df[values > threshold]
            else:
                filtered_df = filtered_df[values < threshold]
    return filtered_df


def update_columns(leaderboard, columns: list):
    """Keep the raw *columns* present in *leaderboard* (in table order) and rename them to display headers."""
    return leaderboard[leaderboard.columns.intersection(columns)].rename(columns=CUSTOM_HEADER_MAP)


def update_table(leaderboard, search_box, search_box_method, search_box_year, search_box_section,
                 acc_base_box, acc_box, acc_change, flops_base_box, flops_box, flops_drop,
                 param_base_box, param_box, param_drop):
    """Filter the raw *leaderboard* with all search boxes and return it with display headers."""
    updated_df = filter_table_combined(leaderboard, search_box, search_box_method, search_box_year,
                                       search_box_section, acc_base_box, acc_box, acc_change,
                                       flops_base_box, flops_box, flops_drop,
                                       param_base_box, param_box, param_drop)
    return update_columns(updated_df, COLS)


def update_text(x):
    """Translate *x* between a raw column name and its display header (the map is bidirectional)."""
    return CUSTOM_HEADER_MAP[x]


def get_shown_columns(misc_checkbox_group, acc_checkbox_group, flops_checkbox_group, param_checkbox_group):
    """Return (and print) the display headers of every column ticked in the four checkbox groups."""
    selected = misc_checkbox_group + acc_checkbox_group + flops_checkbox_group + param_checkbox_group
    updated_columns = [CUSTOM_HEADER_MAP[col] for col in selected]
    print("Columns updated to", updated_columns, "\n")
    return updated_columns


def _as_display_column(name, columns):
    """Return *name* if it is one of *columns*; otherwise its display header from CUSTOM_HEADER_MAP."""
    return name if name in columns else CUSTOM_HEADER_MAP.get(name, name)


def _empty_plot():
    """Return an empty chart whose title tells the user that nothing matched."""
    return (alt.Chart(pd.DataFrame({'x': [], 'y': []}))
            .mark_point()
            .encode()
            .properties(title="No results found or bad query"))


def make_plot(data, y_axis='acc-change', x_axis='flops-drop', color_sorting='model'):
    """Build an interactive Altair scatter plot of *y_axis* against *x_axis*.

    Args:
        data: leaderboard table with the display headers shown in the Gradio
            table. It is not modified.
        y_axis: y-axis column; a raw name such as 'acc-change' is translated to
            its display header.
        x_axis: x-axis column (display header or raw name).
        color_sorting: column used to colour the points and to build the
            clickable legend (display header or raw name).

    Returns:
        An alt.Chart, or a titled empty chart when no row has both axis values.
    """
    if data is None or data.empty:
        return _empty_plot()

    y_axis = _as_display_column(y_axis, data.columns)
    x_axis = _as_display_column(x_axis, data.columns)
    color_sorting = _as_display_column(color_sorting, data.columns)

    # Work on a copy: blank cells become NaN and the axes must be numeric for min/max.
    data = data.replace('', np.nan)
    for axis in (x_axis, y_axis):
        data[axis] = pd.to_numeric(data[axis], errors='coerce')
    data = data.dropna(subset=[y_axis, x_axis], how='any').copy()
    if data.empty:
        return _empty_plot()

    # Treat years as categories rather than numbers.
    year_column = CUSTOM_HEADER_MAP['year']
    data[year_column] = data[year_column].astype(str)
    data = data.sort_values(by=[y_axis], ascending=[False])

    x_min, x_max = data[x_axis].min(), data[x_axis].max()
    y_min, y_max = data[y_axis].min(), data[y_axis].max()

    # Clicking a legend entry keeps only that group; hovering highlights a point.
    legend_selection = alt.selection_point(fields=[color_sorting], bind='legend')
    hover_selection = alt.selection_point(on='mouseover', nearest=False, empty=True)

    tooltip = [CUSTOM_HEADER_MAP[key] for key in
               ('method', 'model', 'acc-pruned', 'acc-change', 'flops-pruned', 'flops-drop', 'year', 'sec')]
    return alt.Chart(data).mark_point().encode(
        x=alt.X(x_axis, title=x_axis, scale=alt.Scale(domain=(x_min - 2, x_max + 2))),
        y=alt.Y(y_axis, title=y_axis, scale=alt.Scale(domain=(y_min - 2, y_max + 2))),
        color=color_sorting,
        tooltip=tooltip,
        opacity=alt.condition(hover_selection, alt.value(1), alt.value(0.2)),
    ).add_params(
        legend_selection,
        hover_selection,
    ).transform_filter(
        legend_selection
    ).interactive()


def item_selected(leaderboard: gr.Dataframe, evt: gr.SelectData):
    """Describe the paper whose Method cell was clicked in the leaderboard table.

    Args:
        leaderboard: current table contents (display headers).
        evt: Gradio select event; evt.value is the clicked cell's value.

    Returns:
        (details_markdown, bibtex). Both are an "invalid cell" hint when the
        clicked value is not a method name present in the table.
    """
    matches = leaderboard.loc[leaderboard[CUSTOM_HEADER_MAP['method']] == evt.value]
    if len(matches) == 0:
        invalid = "✖️ Invalid cell! Please click on **Method Name** to see details..."
        return invalid, invalid

    # Only the first matching row is used when a method appears several times.
    row = matches.iloc[0]
    section = row[CUSTOM_HEADER_MAP['sec']]
    method = row[CUSTOM_HEADER_MAP['method']]

    sec_record = section_mapping[section]  # (section, sub section)
    awesome_record = paper_mapping[method]  # " | paper | conf | type | code | " or None
    bibtex_record = bibtex_mapping[method]

    # Expand venue abbreviations in the BibTeX entry.
    for key, value in venue_mapping.items():
        bibtex_record = bibtex_record.replace(key, value)

    main_section, sub_section = sec_record[0], sec_record[1]

    paper = conf = code = NOT_RECORDED
    if awesome_record is not None:
        fields = awesome_record.split('|')
        paper = fields[1].strip()
        conf = fields[2].strip()
        code = fields[-2].strip()
        if code in ("", "-"):
            code = NOT_RECORDED

    text = f"""
Section: {main_section} → {sub_section} ({section})

Paper: {paper}

Venue: {conf}

Code: {code}
"""
    return text, bibtex_record


def _search_box(label, info=None):
    """Create one labelled search Textbox (submitted with Enter)."""
    return gr.Textbox(placeholder=SEARCH_PLACEHOLDER, label=label, info=info, show_label=True)


def create_tab(app, dataset_name, dataset_id, df):
    """Add a tab for one dataset: search boxes, baselines, two plots and the results table.

    Args:
        app: the enclosing gr.Blocks, used to register on-load plot rendering.
        dataset_name: 'cifar10', 'cifar100' or 'imagenet' (case-insensitive).
        dataset_id: id of the created gr.TabItem.
        df: full results table with raw column names and a 'dataset' column.

    Raises:
        ValueError: if *dataset_name* is not a supported dataset.
    """
    dataset = dataset_name.lower()
    if dataset not in DATASET_LABELS:
        raise ValueError(f"Unknown dataset: {dataset}")
    dataset_label = DATASET_LABELS[dataset]

    df_dataset = df[df['dataset'] == dataset]
    original_df_pd = df_dataset.copy()

    with gr.TabItem(dataset_label, id=dataset_id):
        with gr.Row(equal_height=True):
            with gr.Column():
                with gr.Group():
                    with gr.Row():
                        gr.Markdown("**Search by below options:**", elem_classes="markdown-subtitle")
                    with gr.Row():
                        search_box = _search_box("Model")
                        search_box_method = _search_box("Method")
                        search_box_year = _search_box("Year")
                        search_box_section = _search_box("Section")
                    with gr.Row():
                        acc_base_box = _search_box(
                            "Baseline Accuracy",
                            "E.g., `90` means search for baseline accuracy > 90%.")
                        acc_box = _search_box(
                            "Accuracy After Pruning",
                            "E.g., `90` means search for accuracy after pruning > 90%.")
                        acc_change = _search_box(
                            "Accuracy Drop",
                            "E.g., `2` means search for accuracy drop < 2%.")
                    with gr.Row():
                        flops_base_box = _search_box(
                            "Baseline FLOPs",
                            "E.g., `100` means search for baseline FLOPs < 100M.")
                        flops_box = _search_box(
                            "FLOPs After Pruning",
                            "E.g., `100` means search for FLOPs after pruning < 100M.")
                        flops_drop = _search_box(
                            "FLOPs Drop",
                            "E.g., `50` means search for FLOPs drop > 50%.")
                    with gr.Row():
                        param_base_box = _search_box(
                            "Baseline Parameters",
                            "E.g., `10` means search for baseline parameters < 10M.")
                        param_box = _search_box(
                            "Parameters after Pruning",
                            "E.g., `10` means search for parameters after pruning < 10M.")
                        param_drop = _search_box(
                            "Parameters Drop",
                            "E.g., `50` means search for parameters drop by > 50%.")

                with gr.Accordion(label="See Model Baselines", open=False):
                    baseline_dataset = baseline[baseline['dataset'] == dataset]
                    baseline_no_dataset = baseline_dataset.drop(columns=['dataset']).rename(columns=CUSTOM_HEADER_MAP)
                    gr.Dataframe(
                        value=baseline_no_dataset,
                        headers=list(baseline_no_dataset.columns),
                        interactive=False,
                        visible=True,
                        wrap=True,
                    )

            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=1):
                        sort_choice_box = gr.Radio(
                            choices=[CUSTOM_HEADER_MAP["model"], CUSTOM_HEADER_MAP["sec"], CUSTOM_HEADER_MAP["year"]],
                            value=CUSTOM_HEADER_MAP["model"],
                            label="Draw with",
                            info="Draw with [model, section, year]")
                    with gr.Column(scale=1):
                        x_axis_box = gr.Radio(
                            [CUSTOM_HEADER_MAP["flops-drop"], CUSTOM_HEADER_MAP["flops-pruned"]],
                            value=CUSTOM_HEADER_MAP["flops-drop"],
                            label="Set x-axis",
                            info="Set x-axis to [FLOPs after pruning, FLOPs drop (%)]")
                with gr.Column():
                    plot_acc_change = gr.Plot(label="Plot of Accuracy Change (%)")
                    # Hidden holders for the raw y-axis column names passed to make_plot.
                    y_axis_acc_change = gr.Text(value="acc-change", visible=False)
                    plot_acc = gr.Plot(label="Plot of Accuracy After Pruning")
                    y_axis_acc = gr.Text(value="acc-pruned", visible=False)

        # Hidden, unfiltered copy (raw column names) that every search starts from.
        original_df = gr.Dataframe(
            value=original_df_pd,
            headers=list(original_df_pd.columns),
            max_rows=None,
            interactive=False,
            visible=False,
        )

        with gr.Row():
            leaderboard_pd = df_dataset.rename(columns=CUSTOM_HEADER_MAP)
            leaderboard_table = gr.Dataframe(
                value=leaderboard_pd,
                headers=list(leaderboard_pd.columns),
                max_rows=None,
                interactive=False,
                visible=True,
            )

        with gr.Row():
            details = gr.Markdown(value="*Click any **Method Name** in above table to see details...*",
                                  elem_classes='markdown-text')
            bibtex_code = gr.Code("Click any Method Name in above table to see details...", label="BibTeX")

        search_boxes = [search_box, search_box_method, search_box_year, search_box_section,
                        acc_base_box, acc_box, acc_change, flops_base_box, flops_box, flops_drop,
                        param_base_box, param_box, param_drop]
        # Pressing Enter in any box re-filters the unfiltered table with all boxes' values.
        for box in search_boxes:
            box.submit(update_table, [original_df] + search_boxes, outputs=[leaderboard_table])

        leaderboard_table.select(item_selected, inputs=[leaderboard_table], outputs=[details, bibtex_code])

        # Each plot is drawn on page load and redrawn when the table or a plot option changes.
        for y_axis_text, plot in ((y_axis_acc_change, plot_acc_change), (y_axis_acc, plot_acc)):
            plot_inputs = [leaderboard_table, y_axis_text, x_axis_box, sort_choice_box]
            app.load(make_plot, inputs=plot_inputs, outputs=[plot])
            leaderboard_table.change(make_plot, inputs=plot_inputs, outputs=[plot])
            sort_choice_box.change(make_plot, plot_inputs, outputs=[plot])
            x_axis_box.change(make_plot, plot_inputs, outputs=[plot])


def _guide_image(path):
    """Show a static, non-downloadable illustration in the user guide."""
    gr.Image(path, elem_classes="markdown-image", show_label=False, interactive=False,
             show_download_button=False)


def _guide_text(text):
    """Show Markdown *text* in the user guide, dedented so it is not rendered as a code block."""
    gr.Markdown(textwrap.dedent(text).strip(), elem_classes="markdown-text")


def _guide_figure(path, caption):
    """Show an illustration with a caption underneath, in its own column."""
    with gr.Column():
        _guide_image(path)
        _guide_text(caption)


def _build_user_guide():
    """Populate the "User Guide" tab with one accordion per page section."""
    gr.Markdown("Guide to use this leaderboard", elem_classes="markdown-title")

    with gr.Accordion(label="0. Sections", open=True):
        _guide_text("## Sections")
        with gr.Row():
            with gr.Column():
                _guide_image("src/images/overview.png")
            with gr.Column():
                _guide_text("""
                    We divide the webpage into below sections:

                    1. Dataset Tabs
                    2. Query Section
                    3. Data Plotting
                    4. Data Table

                    More detailed functions are explained in the following sections.
                    """)

    with gr.Accordion(label="1. Dataset Tabs", open=False):
        _guide_text("# Dataset Tabs")
        with gr.Row():
            _guide_image("src/images/cifar10-tab.png")
            _guide_image("src/images/cifar100-tab.png")
            _guide_image("src/images/imagenet-tab.png")
        with gr.Row():
            _guide_text("""
                - Click the corresponding tabs to view the results of different datasets.
                - We currently support three datasets: CIFAR-10, CIFAR-100, and ImageNet-1K.
                - Results are 'isolated' for each dataset, i.e., the results of different datasets are not mixed together.
                """)

    with gr.Accordion(label="2. Query Section", open=False):
        _guide_text("## Query Section")
        with gr.Row():
            with gr.Column():
                _guide_image("src/images/query-overview.png")
            with gr.Column():
                _guide_text("""
                    The query box includes two parts

                    - red box: query by paper attributes
                    - blue box: query by experimental results

                    Press [Enter] key to update.

                    - update both plotting and table.
                    """)
        with gr.Row():
            with gr.Column():
                _guide_image("src/images/use-case.png")
            with gr.Column():
                _guide_text("""
                    Example:

                    Here, we provide a use case and show how query works.

                    If a user wants to find methods that satisfy the following:

                    1. Select Dataset: ImageNet-1K
                    2. Select Model: ResNet-50
                    3. Select Pruning Method: Regularization-based Pruning
                    4. Target 1: Accuracy after pruning > 75%
                    5. Target 2: Pruned FLOPs > 40%
                    6. Target 3: Model size after pruning < 30M

                    By entering the requirements to the corresponding query box, we can narrow down the results and compare the remaining ones.
                    """)

    with gr.Accordion(label="3. Data Plotting", open=False):
        _guide_text("## Data Plotting")
        with gr.Row():
            with gr.Column():
                _guide_image("src/images/plotting-overview.png")
            with gr.Column():
                _guide_text("""
                    The data plotting section can be split into two parts:

                    - red box: contains two radio buttons to select:
                        - (left) Group colors by ‘model’, ‘section’, or ‘year’.
                        - (right) Change x-axis of the plots to ‘FLOPs drop (%)’ or ‘FLOPs after pruning (M)’.
                    - blue box: interactive plots
                    """)
        with gr.Row():
            _guide_figure("src/images/group-model.png",
                          "Group by Model (default)\n\nX-axis: FLOPs drop (%) (default)")
            _guide_figure("src/images/group-section.png",
                          "Group by Section\n\nX-axis: FLOPs drop (%) (default)")
            _guide_figure("src/images/group-year.png",
                          "Group by Year\n\nX-axis: FLOPs drop (%) (default)")
            _guide_figure("src/images/flops-pruned.png",
                          "Group by Model (default)\n\nX-axis: FLOPs after pruning (M)")
        with gr.Row():
            _guide_figure("src/images/default.png", "Default Figure")
            _guide_figure("src/images/drag.png", "1. Shift the graph by dragging.")
            _guide_figure("src/images/zoom-out.png", "2. Zoom-in/out by scrolling.")
        with gr.Row():
            _guide_figure("src/images/hover.png", "3. Hover over the data point to see the details.")
            _guide_figure("src/images/legend-before.png", "4. Click any legend to filter out others.")
            _guide_figure("src/images/legend-after.png",
                          "4. Click any legend to filter out others.\n"
                          "5. Click white spaces/Double Click to restore to default scaling and legends.")

    with gr.Accordion(label="4. Data Table", open=False):
        _guide_text("## Data Table")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    _guide_image("src/images/drop-down-crop.png")
                    _guide_image("src/images/expand.png")
            with gr.Column():
                _guide_text("""
                    Click to expand the table

                    - The expanded table contains the baseline FLOPs and Parameters for each model.
                    """)
        with gr.Row():
            with gr.Column():
                _guide_image("src/images/sort_btn.png")
            with gr.Column():
                _guide_text("""
                    Click the sort button:

                    - Sort in ascending order.
                    - click more than once to toggle ascending/descending.
                    """)
        with gr.Row():
            with gr.Column():
                _guide_image("src/images/detail.png")
            with gr.Column():
                _guide_text("""
                    Click any method name (highlighted in the red box) to show details of the paper (blue box).

                    The details include:

                    - detailed section
                    - link of paper
                    - venue of publication
                    - released code (if any)
                    - the BibTex used in our paper
                    """)


def main():
    """Build the app (About, User Guide and one tab per dataset) and launch it."""
    app = gr.Blocks(css=custom_css)
    with app:
        gr.Markdown(TITLE, elem_classes="markdown-title")
        with gr.Tabs(elem_classes="tab-buttons"):
            with gr.TabItem("👋 About", id=0):
                gr.Markdown(INTRO, elem_classes="markdown-text")
                gr.Code(BIBTEX, elem_classes="bibtex", label="BibTeX")
            with gr.TabItem("📑 User Guide", id=1):
                _build_user_guide()
        # The dataset tabs use their own Tabs container, so their ids restart at 0.
        with gr.Tabs(elem_classes="tab-buttons"):
            create_tab(app, "cifar10", 0, df)
            create_tab(app, "cifar100", 1, df)
            create_tab(app, "imagenet", 2, df)
    app.launch()


if __name__ == "__main__":
    main()