"""Gradio app plotting proprietary vs open LLMs by LMSYS Arena ELO over time.

Importing this module downloads the latest ELO pickle and leaderboard CSV,
joins them with the local ``release_date_mapping.json``, builds the Gradio UI
and launches it.
"""

import pickle

import pandas as pd
import gradio as gr
import plotly.express as px

from utils import (
    KEY_TO_CATEGORY_NAME,
    PROPRIETARY_LICENSES,
    download_latest_data_from_space,
)

# Columns that define one "bucket" of models (release month x license group).
GROUP_KEYS = ["Month-Year", "License"]

# Preferred initial value of the "Minimum ELO Score" slider; clamped to the
# observed rating range when the slider is built.
DEFAULT_MIN_ELO = 800

###################
### Load Data
###################

# gather ELO data
latest_elo_file_local = download_latest_data_from_space(
    repo_id="lmsys/chatbot-arena-leaderboard", file_type="pkl"
)

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# This trusts whoever publishes the downloaded file; do not point it at an
# untrusted source.
with open(latest_elo_file_local, "rb") as fin:
    elo_results = pickle.load(fin)

# One leaderboard table per category, keeping only categories that are
# actually present in the downloaded results.
arena_dfs = {
    category_name: elo_results[key]["leaderboard_table_df"]
    for key, category_name in KEY_TO_CATEGORY_NAME.items()
    if key in elo_results
}

# gather open llm leaderboard data
latest_leaderboard_file_local = download_latest_data_from_space(
    repo_id="lmsys/chatbot-arena-leaderboard", file_type="csv"
)
leaderboard_df = pd.read_csv(latest_leaderboard_file_local)

###################
### Prepare Data
###################

# merge leaderboard data with ELO data (arena tables are indexed by model key)
merged_dfs = {
    category_name: (
        pd.merge(arena_df, leaderboard_df, left_index=True, right_on="key")
        .sort_values("rating", ascending=False)
        .reset_index(drop=True)
    )
    for category_name, arena_df in arena_dfs.items()
}

# add release dates into the merged data
release_date_mapping = pd.read_json("release_date_mapping.json", orient="records")
release_dates = release_date_mapping[["key", "Release Date"]]
merged_dfs = {
    category_name: pd.merge(merged_df, release_dates, on="key")
    for category_name, merged_df in merged_dfs.items()
}

df = merged_dfs["Overall"]
# Collapse the individual license names into two groups for colouring.
df["License"] = df["License"].apply(
    lambda x: "Proprietary LLM" if x in PROPRIETARY_LICENSES else "Open LLM"
)
df["Release Date"] = pd.to_datetime(df["Release Date"])
df["Month-Year"] = df["Release Date"].dt.to_period("M")
df["rating"] = df["rating"].round()

###################
### Plot Data
###################

date_updated = elo_results["full"]["last_updated_datetime"].split(" ")[0]
# Ratings were rounded above, so min/max need no further rounding.
min_elo_score = float(df["rating"].min())
max_elo_score = float(df["rating"].max())
# Largest number of rated models in any single (month, license) bucket.
upper_models_per_month = int(df.groupby(GROUP_KEYS)["rating"].count().max())


def build_plot(min_score, max_models_per_month, toggle_annotations):
    """Build the ELO-vs-release-date scatter plot for the current controls.

    Args:
        min_score: Models with a rating below this value are excluded.
        max_models_per_month: Maximum number of models kept per
            (release month, license group) bucket; the highest rated win.
        toggle_annotations: When truthy, label the highest rated model of
            every bucket with its name.

    Returns:
        The figure produced by ``plotly.express.scatter``.
    """
    # Slider values may arrive as floats; a row count must be an int.
    top_n = int(max_models_per_month)

    # Keep the top-N rated rows per bucket. Sorting first and then taking
    # ``head`` per group avoids ``GroupBy.apply``, whose handling of the
    # grouping columns is deprecated, while keeping the same rows in the same
    # order (buckets ascending, rating descending within a bucket).
    filtered_df = (
        df[df["rating"] >= min_score]
        .dropna(subset=GROUP_KEYS)
        .sort_values(GROUP_KEYS + ["rating"], ascending=[True, True, False])
        .groupby(GROUP_KEYS)
        .head(top_n)
        .reset_index(drop=True)
    )

    fig = px.scatter(
        filtered_df,
        x="Release Date",
        y="rating",
        color="License",
        hover_name="Model",
        hover_data=["Organization", "License"],
        trendline="ols",
        title=f"Proprietary vs Open LLMs (LMSYS Arena ELO as of {date_updated})",
        labels={"rating": "Arena ELO", "Release Date": "Release Date"},
        height=700,
        template="seaborn",
    )
    fig.update_traces(marker=dict(size=10, opacity=0.6))

    if toggle_annotations:
        # get the points to annotate (only the highest rated model per month per license)
        idx_to_annotate = filtered_df.groupby(GROUP_KEYS)["rating"].idxmax()
        for _, row in filtered_df.loc[idx_to_annotate].iterrows():
            fig.add_annotation(
                x=row["Release Date"],
                y=row["rating"],
                text=row["Model"],
                showarrow=True,
                arrowhead=0,
            )

    return fig


demo = gr.Blocks()

with demo:
    gr.Markdown("# Proprietary vs Open LLMs (LMSYS Arena ELO)")
    with gr.Row():
        min_score = gr.Slider(
            minimum=min_elo_score,
            maximum=max_elo_score,
            # Clamp the preferred default into the observed rating range so
            # the initial value is always valid for this slider.
            value=min(max(DEFAULT_MIN_ELO, min_elo_score), max_elo_score),
            step=50,
            label="Minimum ELO Score",
        )
        max_models_per_month = gr.Slider(
            value=upper_models_per_month,
            minimum=1,
            maximum=upper_models_per_month,
            step=1,
            label="Max Models per Month (per License)",
        )
        toggle_annotations = gr.Radio(
            choices=[True, False], label="Overlay Best Model Name", value=False
        )

    # Show plot
    plot = gr.Plot()

    # Render once on page load, then re-render whenever any control changes.
    plot_inputs = [min_score, max_models_per_month, toggle_annotations]
    demo.load(fn=build_plot, inputs=plot_inputs, outputs=plot)
    for control in plot_inputs:
        control.change(fn=build_plot, inputs=plot_inputs, outputs=plot)

# Launched unconditionally: there is no ``__main__`` guard, so importing this
# module starts the app.
demo.launch()