File size: 3,079 Bytes
2b1ed69 52c06b0 2b1ed69 a2ef2b2 2b1ed69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import pandas as pd
import plotly.express as px
import gradio as gr
data_path = '0926-OCRBench-opensource.csv'
# Blank cells are filled with 0 so the integer casts below do not fail on NaN.
data = pd.read_csv(data_path).fillna(0)

# Columns holding integer benchmark scores.
_SCORE_COLUMNS = [
    "OCRBench",
    "Text Recognition",
    "Scene Text-centric VQA",
    "Document Oriented VQA",
    "KIE",
    "Handwritten Math Expression Recognition",
]
# Target dtype per column: the model name as text, the parameter count as a
# float, and every score column as an int.
dtype_dict = {"Model": str, "Param (B)": float, **{name: int for name in _SCORE_COLUMNS}}

# Keep only the first 26 rows -- presumably the model entries; TODO confirm
# what the CSV holds below them. Cast the columns, then drop "Unnamed: 11",
# the placeholder name pandas gives a column whose header cell is empty.
data_valid = (
    data.iloc[:26]
    .copy()
    .astype(dtype_dict)
    .drop(columns=['Unnamed: 11'])
)
# Map each model name to the category used to colour it in the plot: "H2OVLs", "traditional ocr models" (doctr pipelines), or "Other".
def categorize_model(model):
    """Return the plot category for a model name.

    Args:
        model: Model name as it appears in the "Model" column.

    Returns:
        "H2OVLs" for the two H2OVL-Mississippi checkpoints,
        "traditional ocr models" for names starting with "doctr",
        and "Other" for every other model.
    """
    h2ovl_names = ("H2OVL-Mississippi-2B", "H2OVL-Mississippi-1B")

    category = "Other"
    if model in h2ovl_names:
        category = "H2OVLs"
    elif model.startswith("doctr"):
        # doctr entries are classic OCR pipelines rather than vision-language models.
        category = "traditional ocr models"
    return category
# Derive a "Category" column from each model name; plot_metric colours points by it.
data_valid["Category"] = data_valid["Model"].map(categorize_model)
# Plotting
def plot_metric(selected_metric):
    """Build a scatter plot of one benchmark score against model size.

    Rows whose ``selected_metric`` value is 0 are left out: the CSV is loaded
    with ``fillna(0)``, so a 0 here stands for a missing score.

    Args:
        selected_metric: Name of a score column in ``data_valid``, e.g.
            ``"OCRBench"``.

    Returns:
        A Plotly figure with one labelled marker per model, coloured by the
        model's ``Category``.
    """
    filtered_data = data_valid[data_valid[selected_metric] != 0]

    # Colour each point by its Category ("H2OVLs", "traditional ocr models", "Other").
    fig = px.scatter(
        filtered_data,
        x="Param (B)",
        y=selected_metric,
        text="Model",
        color="Category",
        title=f"{selected_metric} vs Model Size",
    )
    # Draw every marker with its full model name printed to the right of it.
    fig.update_traces(
        marker=dict(size=10),
        mode="markers+text",
        textposition="middle right",
        textfont=dict(size=8),
        texttemplate="%{text}",
    )

    # Pad the x-axis by 5 beyond the largest model so right-hand labels fit.
    # max() over an empty selection is NaN, which would give an invalid axis
    # range; fall back to 0 in that case.
    max_x_value = filtered_data["Param (B)"].max()
    if pd.isna(max_x_value):
        max_x_value = 0
    fig.update_layout(
        xaxis_range=[0, max_x_value + 5],
        xaxis_title="Model Size (B)",
        yaxis_title=selected_metric,
        # The category legend is hidden, so the colours are not keyed on the chart.
        showlegend=False,
        height=800,
        # Wider right margin leaves room for the text labels.
        margin=dict(t=50, l=50, r=100, b=50),
    )
    return fig
# Gradio Blocks Interface
def create_interface():
    """Build the Gradio Blocks UI for exploring the benchmark scores.

    The layout is one row: a wide plot column (scale 4) next to a narrow
    control column (scale 1) holding a metric dropdown. Changing the dropdown
    re-renders the plot via ``plot_metric``.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    default_metric = "OCRBench"
    # Offer exactly the score columns, i.e. those cast to int in dtype_dict,
    # in CSV order. Selecting by dtype rather than by position keeps "Model",
    # "Param (B)", any other descriptive columns and the derived "Category"
    # column out of the list even if the CSV gains or reorders columns.
    metrics = [col for col in data_valid.columns if dtype_dict.get(col) is int]

    with gr.Blocks() as interface:
        with gr.Row():
            with gr.Column(scale=4):  # plot column: 4 parts of the row width
                plot = gr.Plot(value=plot_metric(default_metric), label="OCR Benchmark Metrics")
            with gr.Column(scale=1):  # control column: 1 part of the row width
                dropdown = gr.Dropdown(metrics, label="Select Metric", value=default_metric)
        # Re-plot whenever the selected metric changes; listeners must be
        # registered inside the Blocks context.
        dropdown.change(fn=plot_metric, inputs=dropdown, outputs=plot)
    return interface
# Launch the interface
if __name__ == "__main__":
    # Build the UI and start the Gradio server only when run as a script.
    app = create_interface()
    app.launch()
|