# leaderboard / app.py
# Hugging Face Space file; last commit "Update app.py" (caa6729, verified) by rwightman (HF staff)
import fnmatch
import gradio as gr
import pandas as pd
import plotly.express as px
from rapidfuzz import fuzz
import re
def load_leaderboard():
    """Fetch the timm result / benchmark CSVs and build one merged table per benchmark.

    Each validation-set CSV is stripped of its error/diff/param columns, gets its
    ``top1`` / ``top5`` columns prefixed with the dataset key, and is outer-joined
    into a single results table. That table is then left-joined with every
    benchmark CSV that has an ``infer_gmacs`` column.

    Returns:
        dict: benchmark name -> merged DataFrame, rounded to 2 decimals, carrying
        a ``highlighted`` column initialised to False.
    """
    # Validation / test result CSVs; the key becomes the top1/top5 column prefix
    results_csv_files = {
        'imagenet': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet.csv',
        'real': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-real.csv',
        'v2': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenetv2-matched-frequency.csv',
        'sketch': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-sketch.csv',
        'a': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-a.csv',
        'r': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-r.csv',
    }
    # Inference benchmark CSVs; the key is the name returned to the caller
    benchmark_csv_files = {
        'amp-nchw-pt240-cu124-rtx4090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nchw-pt240-cu124-rtx4090.csv',
        'amp-nhwc-pt240-cu124-rtx4090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nhwc-pt240-cu124-rtx4090.csv',
        'amp-nchw-pt240-cu124-rtx4090-dynamo': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nchw-pt240-cu124-rtx4090-dynamo.csv',
        'amp-nchw-pt240-cu124-rtx3090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nchw-pt240-cu124-rtx3090.csv',
        'amp-nhwc-pt240-cu124-rtx3090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nhwc-pt240-cu124-rtx3090.csv',
        'fp32-nchw-pt240-cpu-i9_10940x-dynamo': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-fp32-nchw-pt240-cpu-i9_10940x-dynamo.csv',
        'fp32-nchw-pt240-cpu-i7_12700h-dynamo': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-fp32-nchw-pt240-cpu-i7_12700h-dynamo.csv',
    }
    unwanted_columns = ["top1_err", "top5_err", "top1_diff", "top5_diff", "rank_diff", "param_count"]
    result_join_keys = ['arch_name', 'model', 'img_size', 'crop_pct', 'interpolation']
    leading_columns = ['model', 'img_size', 'avg_top1', 'avg_top5']
    noisy_columns = ['arch_name', 'crop_pct', 'interpolation', 'model_benchmark']

    # Download the result tables first, then the benchmark tables
    result_frames = {}
    for set_name, url in results_csv_files.items():
        result_frames[set_name] = pd.read_csv(url)

    benchmark_frames = {}
    for bench_name, url in benchmark_csv_files.items():
        frame = pd.read_csv(url)
        # Benchmark files without a GMACs column are ignored entirely
        if 'infer_gmacs' in frame.columns:
            benchmark_frames[bench_name] = frame

    # Clean up each result table and tag its accuracy columns with the dataset key
    for set_name, frame in result_frames.items():
        present = [col for col in unwanted_columns if col in frame.columns]
        frame.drop(columns=present, inplace=True)
        frame.rename(columns={"top1": f"{set_name}_top1", "top5": f"{set_name}_top5"}, inplace=True)
        # arch_name is the part of the model name before the first '.'
        frame['arch_name'] = frame['model'].apply(lambda full_name: full_name.split('.')[0])

    # Benchmark tables use the model column itself as the architecture name
    for frame in benchmark_frames.values():
        frame['arch_name'] = frame['model']
        frame.rename(columns={'infer_img_size': 'img_size'}, inplace=True)

    # Outer-join every other result table onto the ImageNet one
    result = result_frames['imagenet']
    for set_name, frame in result_frames.items():
        if set_name == 'imagenet':
            continue
        result = result.merge(frame, on=result_join_keys, how='outer')

    # Row-wise averages over the per-dataset scores; the 'a' dataset is left out
    for metric in ('top1', 'top5'):
        metric_columns = [
            col for col in result.columns
            if col.endswith(f'_{metric}') and col != f'a_{metric}'
        ]
        result[f'avg_{metric}'] = result[metric_columns].mean(axis=1)

    # One fully merged table per benchmark set
    merged_dataframes = {}
    for bench_name, bench_df in benchmark_frames.items():
        merged = result.merge(bench_df, on=['arch_name', 'img_size'], how='left', suffixes=('', '_benchmark'))

        # TFLOP/s from throughput and GMACs (presumably 2 FLOPs per MAC, giga -> tera)
        merged['infer_tflop_s'] = merged['infer_samples_per_sec'] * merged['infer_gmacs'] * 2 / 1000

        # Put the headline columns first, keep everything else in its existing order
        trailing_columns = [col for col in merged.columns if col not in leading_columns]
        merged = merged[leading_columns + trailing_columns].copy(deep=True)

        # Join helpers and noisy columns are not shown
        merged.drop(columns=noisy_columns, inplace=True)

        merged['infer_usec_per_sample'] = 1e6 / merged['infer_samples_per_sec']
        merged['highlighted'] = False
        merged_dataframes[bench_name] = merged.round(2)

    return merged_dataframes
REGEX_PREFIX = "re:"


def auto_match(pattern, text):
    """Decide whether *text* satisfies the user-supplied filter *pattern*.

    The pattern flavour is picked in this order:

    * ``re:<expr>`` - case-insensitive regular expression applied with
      ``re.match``, so it has to match from the start of *text*. An invalid
      expression matches nothing.
    * contains ``*`` or ``?`` - case-insensitive ``fnmatch`` wildcard.
    * anything else - ``fuzz.partial_ratio`` on the lower-cased strings with
      ``score_cutoff=90``; any score above 0 counts as a match.

    Returns:
        bool: True when *text* matches.
    """
    if pattern.startswith(REGEX_PREFIX):
        expression = pattern[len(REGEX_PREFIX):].strip()
        try:
            found = re.match(expression, text, re.IGNORECASE)
        except re.error:
            # A malformed expression simply matches nothing
            return False
        return found is not None

    lowered_pattern = pattern.lower()
    lowered_text = text.lower()

    if '*' in pattern or '?' in pattern:
        return fnmatch.fnmatch(lowered_text, lowered_pattern)

    score = fuzz.partial_ratio(lowered_pattern, lowered_text, score_cutoff=90)
    return score > 0
def filter_leaderboard(df, model_name, sort_by):
    """Return *df* sorted descending by *sort_by*, optionally filtered by model name.

    Args:
        df: table with a ``model`` column.
        model_name: filter pattern handed to ``auto_match``; when empty or None
            no filtering takes place.
        sort_by: column to sort on, highest value first.

    Returns:
        A new, sorted DataFrame; *df* itself is not modified.
    """
    if model_name:
        keep = df['model'].apply(lambda candidate: auto_match(model_name, candidate))
        df = df[keep]
    return df.sort_values(by=sort_by, ascending=False)
def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter, log_x, log_y):
    """Build the scatter figure of *y_axis* against *x_axis* for the given table.

    Points are coloured by the boolean ``highlighted`` column of *df* (orange for
    True, blue for False) and show the model name on hover. An OLS trend line is
    added; it is always requested with ``log_x=True, log_y=True`` options,
    independent of the *log_x* / *log_y* axis flags.

    Args:
        df: table holding the two plotted columns plus ``model`` and ``highlighted``.
        x_axis, y_axis: column names to plot.
        model_filter: filter text used as legend label for non-highlighted points
            ("all models" when empty).
        highlight_filter: filter text used as legend label for highlighted points.
        log_x, log_y: whether the respective axis uses a log scale.

    Returns:
        The plotly figure.
    """
    selected_color = 'orange'
    palette = {True: selected_color, False: 'blue'}

    fig = px.scatter(
        df,
        x=x_axis,
        y=y_axis,
        log_x=log_x,
        log_y=log_y,
        hover_data=['model'],
        trendline='ols',
        trendline_options=dict(log_x=True, log_y=True),
        color='highlighted',
        color_discrete_map=palette,
        title=f'{y_axis} vs {x_axis}',
    )

    # Legend text per highlight state; the highlighted entry only exists when
    # a highlight filter is active
    legend_labels = {False: f'{model_filter or "all models"}'}
    if highlight_filter:
        legend_labels[True] = f'{highlight_filter}'

    # Rename every trace whose marker colour is a plain string (per the original
    # author's note these are the scatter traces)
    for trace in fig.data:
        marker_color = trace.marker.color
        if not isinstance(marker_color, str):
            continue
        trace.name = legend_labels.get(marker_color == selected_color, '')

    fig.update_layout(showlegend=True, legend_title_text='Model Selection')
    return fig
# Build every merged benchmark table once, when the module is loaded
merged_dataframes = load_leaderboard()

# Columns offered in the "Sort by" dropdown
sort_columns = [
    'avg_top1',
    'avg_top5',
    'imagenet_top1',
    'imagenet_top5',
    'infer_samples_per_sec',
    'infer_usec_per_sample',
    'param_count',
    'infer_gmacs',
    'infer_macts',
    'infer_tflop_s',
]

# Columns offered for the X / Y axes of the scatter plot
plot_columns = [
    'infer_samples_per_sec',
    'infer_usec_per_sample',
    'infer_gmacs',
    'infer_macts',
    'infer_tflop_s',
    'param_count',
    'avg_top1',
    'avg_top5',
    'imagenet_top1',
    'imagenet_top5',
]

# Initial values of the UI controls
DEFAULT_SEARCH = ""
DEFAULT_SORT = "avg_top1"
DEFAULT_X = "infer_samples_per_sec"
DEFAULT_Y = "avg_top1"
DEFAULT_BM = 'amp-nchw-pt240-cu124-rtx4090'
def col_formatter(value, precision=None):
    """Render a single cell value as text.

    Args:
        value: the cell content.
        precision: number of decimals for floats; when None floats use the
            general ``g`` format.

    Returns:
        str: ``int`` instances in decimal form, ``float`` instances with the
        requested precision, anything else via ``str()``.
    """
    if isinstance(value, int):
        return format(value, 'd')
    if isinstance(value, float):
        spec = 'g' if precision is None else f'.{precision}f'
        return format(value, spec)
    return str(value)
def update_leaderboard_and_plot(
        model_name=DEFAULT_SEARCH,
        highlight_name=None,
        sort_by=DEFAULT_SORT,
        x_axis=DEFAULT_X,
        y_axis=DEFAULT_Y,
        benchmark_selection=DEFAULT_BM,
        log_x=True,
        log_y=True,
):
    """Recompute the leaderboard table and scatter plot for the current UI state.

    Args:
        model_name: main model filter pattern.
        highlight_name: optional second filter; matching rows are flagged as
            highlighted and added to the output even if the main filter
            excluded them.
        sort_by: column used to order the table, highest first.
        x_axis, y_axis: columns plotted on the scatter figure.
        benchmark_selection: key into ``merged_dataframes``.
        log_x, log_y: log-scale flags forwarded to the plot.

    Returns:
        tuple: (styled table with highlighted rows on an orange background, figure).
    """
    table = merged_dataframes[benchmark_selection].copy()
    combined_df = filter_leaderboard(table, model_name, sort_by)

    if highlight_name:
        # The highlight filter runs over the whole table, so disjoint filters
        # yield the union of both selections (a comparison view)
        highlight_df = filter_leaderboard(table, highlight_name, sort_by)
        combined_df = pd.concat([combined_df, highlight_df])
        combined_df = combined_df.drop_duplicates()
        combined_df = combined_df.reset_index(drop=True)
        combined_df = combined_df.sort_values(by=sort_by, ascending=False)
        combined_df['highlighted'] = combined_df['model'].isin(highlight_df['model'])
    else:
        combined_df['highlighted'] = False

    fig = create_scatter_plot(combined_df, x_axis, y_axis, model_name, highlight_name, log_x, log_y)

    def _row_style(row):
        # Same CSS for every cell of the row, looked up through the row's index label
        css = 'background-color: #FFA500' if combined_df.loc[row.name, 'highlighted'] else ''
        return [css for _ in row]

    # The flag column drives the styling but is not displayed itself
    visible_df = combined_df.drop(columns=['highlighted'])
    display_df = visible_df.style.apply(_row_style, axis=1).format(precision=2)
    return display_df, fig
# Page layout and event wiring for the leaderboard UI
with gr.Blocks(title="The timm Leaderboard") as app:
    gr.HTML("<center><h1>The timm (PyTorch Image Models) Leaderboard</h1></center>")
    gr.HTML("<p>This leaderboard is based on the results of the models from <a href='https://github.com/huggingface/pytorch-image-models'>timm</a>.</p>")
    gr.HTML("<p>Search tips:<br>- Use wildcards (* or ?) for pattern matching<br>- Use 're:' prefix for regex search<br>- Otherwise, fuzzy matching will be used</p>")

    # Main filter plus sort column
    with gr.Row():
        search_bar = gr.Textbox(lines=1, label="Model Filter", placeholder="e.g. resnet*, re:^vit, efficientnet", scale=3)
        sort_dropdown = gr.Dropdown(choices=sort_columns, label="Sort by", value=DEFAULT_SORT, scale=1)

    # Secondary filter used to highlight / compare models
    with gr.Row():
        highlight_bar = gr.Textbox(lines=1, label="Model Highlight/Compare Filter", placeholder="e.g. convnext*, re:^efficient")

    # Plot axis selection
    with gr.Row():
        x_axis = gr.Dropdown(choices=plot_columns, label="X-axis", value=DEFAULT_X)
        y_axis = gr.Dropdown(choices=plot_columns, label="Y-axis", value=DEFAULT_Y)

    # Which benchmark table to show
    with gr.Row():
        benchmark_dropdown = gr.Dropdown(
            choices=list(merged_dataframes.keys()),
            label="Benchmark Selection",
            value=DEFAULT_BM,
        )

    # Axis scale toggles
    with gr.Row():
        log_x = gr.Checkbox(label="Log scale X-axis", value=True)
        log_y = gr.Checkbox(label="Log scale Y-axis", value=True)

    update_btn = gr.Button(value="Update", variant="primary")

    leaderboard = gr.Dataframe()
    plot = gr.Plot()

    # Order matches the positional parameters of update_leaderboard_and_plot
    inputs = [search_bar, highlight_bar, sort_dropdown, x_axis, y_axis, benchmark_dropdown, log_x, log_y]
    outputs = [leaderboard, plot]

    # Initial render uses the callback's default arguments
    app.load(update_leaderboard_and_plot, outputs=outputs)

    # Text boxes refresh on submit, every other control on change
    for textbox in (search_bar, highlight_bar):
        textbox.submit(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
    for control in (sort_dropdown, x_axis, y_axis, benchmark_dropdown, log_x, log_y):
        control.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
    update_btn.click(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)

app.launch()