|
import pandas as pd |
|
import plotly.express as px |
|
import gradio as gr |
|
|
|
|
|
# Benchmark results are read from a CSV file. Missing cells are replaced with
# 0 so the integer casts below cannot fail on NaN.
data_path = '0926-OCRBench-opensource.csv'
data = pd.read_csv(data_path).fillna(0)

# Target dtype for each column the app works with.
dtype_dict = {
    "Model": str,
    "Param (B)": float,
    "OCRBench": int,
    "Text Recognition": int,
    "Scene Text-centric VQA": int,
    "Document Oriented VQA": int,
    "KIE": int,
    "Handwritten Math Expression Recognition": int,
}

# Only the leading rows of the sheet are kept.
# NOTE(review): presumably rows past this point are not model results --
# confirm against the CSV before changing the count.
_VALID_ROW_COUNT = 26

# `astype` returns a new frame, so the result is independent of `data`.
# 'Unnamed: 11' is the name pandas gives a column whose header cell is empty;
# `errors="ignore"` keeps the script working if the CSV no longer has it.
data_valid = (
    data.iloc[:_VALID_ROW_COUNT]
    .astype(dtype_dict)
    .drop(columns=['Unnamed: 11'], errors="ignore")
)
|
|
|
|
|
# Model names that belong to the "H2OVLs" category.
_H2OVL_MODELS = frozenset({"H2OVL-Mississippi-2B", "H2OVL-Mississippi-1B"})


def categorize_model(model: object) -> str:
    """Return the plot category for a model name.

    Args:
        model: Model name. Non-string values (e.g. ``None`` or NaN from a
            missing cell) are accepted and fall into ``"Other"``.

    Returns:
        ``"H2OVLs"`` for the two H2OVL-Mississippi models,
        ``"traditional ocr models"`` for names starting with ``"doctr"``
        (case-sensitive), and ``"Other"`` for everything else.
    """
    # Guard first: `.startswith` below only exists on strings.
    if not isinstance(model, str):
        return "Other"
    if model in _H2OVL_MODELS:
        return "H2OVLs"
    if model.startswith("doctr"):
        return "traditional ocr models"
    return "Other"
|
|
|
|
|
|
|
# Tag every row with its plot category, derived element-wise from the model name.
data_valid["Category"] = data_valid["Model"].map(categorize_model)
|
|
|
|
|
|
|
|
|
def plot_metric(selected_metric):
    """Build a scatter plot of one metric against model size.

    Args:
        selected_metric: Name of the ``data_valid`` column to put on the
            y-axis.

    Returns:
        The Plotly figure: one labelled point per model, coloured by
        ``Category``, with ``Param (B)`` on the x-axis.
    """
    # Rows whose value for this metric is exactly 0 are left out of the plot.
    nonzero_rows = data_valid.loc[data_valid[selected_metric] != 0]

    figure = px.scatter(
        nonzero_rows,
        x="Param (B)",
        y=selected_metric,
        text="Model",
        color="Category",
        title=f"{selected_metric} vs Model Size",
    )

    # Draw each point with its model name to the right of the marker.
    figure.update_traces(
        marker={"size": 10},
        mode="markers+text",
        textposition="middle right",
        textfont={"size": 8},
        texttemplate="%{text}",
    )

    # Extend the x-axis 5 units past the largest model so the right-most
    # labels have room.
    x_upper = nonzero_rows["Param (B)"].max() + 5
    figure.update_layout(
        xaxis_range=[0, x_upper],
        xaxis_title="Model Size (B)",
        yaxis_title=selected_metric,
        showlegend=False,
        height=800,
        margin={"t": 50, "l": 50, "r": 100, "b": 50},
    )

    return figure
|
|
|
|
|
def create_interface():
    """Assemble the Gradio app: a plot next to a metric selector.

    Returns:
        The ``gr.Blocks`` app. It shows the ``"OCRBench"`` plot initially
        and redraws it through ``plot_metric`` whenever the dropdown
        selection changes.
    """
    initial_metric = "OCRBench"
    # Choices are the columns from position 5 up to, but not including, the
    # last one (presumably the appended "Category" column -- TODO confirm
    # against the CSV layout).
    metric_choices = list(data_valid.columns[5:-1])

    with gr.Blocks() as app:
        with gr.Row():
            # Wide column: the chart.
            with gr.Column(scale=4):
                chart = gr.Plot(
                    value=plot_metric(initial_metric),
                    label="OCR Benchmark Metrics",
                )
            # Narrow column: the metric selector.
            with gr.Column(scale=1):
                metric_picker = gr.Dropdown(
                    metric_choices,
                    label="Select Metric",
                    value=initial_metric,
                )

        # Event listeners are registered inside the Blocks context.
        metric_picker.change(fn=plot_metric, inputs=metric_picker, outputs=chart)

    return app
|
|
|
|
|
# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    app = create_interface()
    app.launch()
|
|
|
|