leaderboard / app.py
Quentin Gallouédec
handle video not in repo and count the number of models
a3eda6f
raw
history blame
No virus
10.9 kB
import glob
import json
import logging
import os
import gradio as gr
import numpy as np
import pandas as pd
import scipy.stats
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import HfApi
from src.backend import backend_routine
from src.logging import configure_root_logger, setup_logger
# Logging setup: configure the root logger, then create this module's logger
# (both helpers come from src.logging).
configure_root_logger()
logger = setup_logger(__name__)
# Only let records of level WARNING and above through the "absl" logger.
logging.getLogger("absl").setLevel(logging.WARNING)
# Hub client; the token is read from the TOKEN environment variable (None when unset).
API = HfApi(token=os.environ.get("TOKEN"))
# Dataset repository from which the results_*.json evaluation reports are downloaded.
RESULTS_REPO = "open-rl-leaderboard/results"
# Interval, in seconds, between two runs of the scheduled background jobs.
REFRESH_RATE = 5 * 60 # 5 minutes
# Environment ids shown on the leaderboard, grouped by domain.
# The UI code in this file creates one tab per domain (the dict key) and, inside it,
# one tab per environment id; insertion order is the display order.
ALL_ENV_IDS = {
    # Atari ids all end with "NoFrameskip-v4"; that suffix is stripped from the tab label.
    "Atari": [
        "AdventureNoFrameskip-v4",
        "AirRaidNoFrameskip-v4",
        "AlienNoFrameskip-v4",
        "AmidarNoFrameskip-v4",
        "AssaultNoFrameskip-v4",
        "AsterixNoFrameskip-v4",
        "AsteroidsNoFrameskip-v4",
        "AtlantisNoFrameskip-v4",
        "BankHeistNoFrameskip-v4",
        "BattleZoneNoFrameskip-v4",
        "BeamRiderNoFrameskip-v4",
        "BerzerkNoFrameskip-v4",
        "BowlingNoFrameskip-v4",
        "BoxingNoFrameskip-v4",
        "BreakoutNoFrameskip-v4",
        "CarnivalNoFrameskip-v4",
        "CentipedeNoFrameskip-v4",
        "ChopperCommandNoFrameskip-v4",
        "CrazyClimberNoFrameskip-v4",
        "DefenderNoFrameskip-v4",
        "DemonAttackNoFrameskip-v4",
        "DoubleDunkNoFrameskip-v4",
        "ElevatorActionNoFrameskip-v4",
        "EnduroNoFrameskip-v4",
        "FishingDerbyNoFrameskip-v4",
        "FreewayNoFrameskip-v4",
        "FrostbiteNoFrameskip-v4",
        "GopherNoFrameskip-v4",
        "GravitarNoFrameskip-v4",
        "HeroNoFrameskip-v4",
        "IceHockeyNoFrameskip-v4",
        "JamesbondNoFrameskip-v4",
        "JourneyEscapeNoFrameskip-v4",
        "KangarooNoFrameskip-v4",
        "KrullNoFrameskip-v4",
        "KungFuMasterNoFrameskip-v4",
        "MontezumaRevengeNoFrameskip-v4",
        "MsPacmanNoFrameskip-v4",
        "NameThisGameNoFrameskip-v4",
        "PhoenixNoFrameskip-v4",
        "PitfallNoFrameskip-v4",
        "PongNoFrameskip-v4",
        "PooyanNoFrameskip-v4",
        "PrivateEyeNoFrameskip-v4",
        "QbertNoFrameskip-v4",
        "RiverraidNoFrameskip-v4",
        "RoadRunnerNoFrameskip-v4",
        "RobotankNoFrameskip-v4",
        "SeaquestNoFrameskip-v4",
        "SkiingNoFrameskip-v4",
        "SolarisNoFrameskip-v4",
        "SpaceInvadersNoFrameskip-v4",
        "StarGunnerNoFrameskip-v4",
        "TennisNoFrameskip-v4",
        "TimePilotNoFrameskip-v4",
        "TutankhamNoFrameskip-v4",
        "UpNDownNoFrameskip-v4",
        "VentureNoFrameskip-v4",
        "VideoPinballNoFrameskip-v4",
        "WizardOfWorNoFrameskip-v4",
        "YarsRevengeNoFrameskip-v4",
        "ZaxxonNoFrameskip-v4",
    ],
    "Box2D": [
        "BipedalWalker-v3",
        "BipedalWalkerHardcore-v3",
        "CarRacing-v2",
        "LunarLander-v2",
        "LunarLanderContinuous-v2",
    ],
    "Toy text": [
        "Blackjack-v1",
        "CliffWalking-v0",
        "FrozenLake-v1",
        "FrozenLake8x8-v1",
    ],
    "Classic control": [
        "Acrobot-v1",
        "CartPole-v1",
        "MountainCar-v0",
        "MountainCarContinuous-v0",
        "Pendulum-v1",
    ],
    "MuJoCo": [
        "Ant-v4",
        "HalfCheetah-v4",
        "Hopper-v4",
        "Humanoid-v4",
        "HumanoidStandup-v4",
        "InvertedDoublePendulum-v4",
        "InvertedPendulum-v4",
        "Pusher-v4",
        "Reacher-v4",
        "Swimmer-v4",
        "Walker2d-v4",
    ],
}
def iqm(x):
    """Return the interquartile mean of ``x``.

    The input is flattened (``axis=None``) and the lowest and highest 25% of the
    values are trimmed before the remaining ones are averaged.
    """
    quartile_fraction = 0.25
    return scipy.stats.trim_mean(x, quartile_fraction, axis=None)
def get_leaderboard_df():
    """Download the results dataset and collect one row per evaluation report.

    Every ``results_*.json`` file found (recursively) in the snapshot of
    ``RESULTS_REPO`` is parsed. Each readable report yields a row with
    ``user_id``, ``model_id`` and ``model_sha``; reports whose status is
    ``"DONE"`` and whose results are non-empty additionally get ``env_id`` and
    ``iqm_episodic_return``. A report that cannot be processed is logged and
    skipped, it never aborts the whole refresh.

    Returns:
        pd.DataFrame: the collected rows, with missing values replaced by empty
        strings. The five columns listed above are always present, even when
        there is no report (or no finished report) yet, so callers can filter
        on ``env_id`` without hitting a ``KeyError``.
    """
    columns = ["user_id", "model_id", "model_sha", "env_id", "iqm_episodic_return"]

    logger.info("Downloading results")
    dir_path = API.snapshot_download(repo_id=RESULTS_REPO, repo_type="dataset")
    pattern = os.path.join(dir_path, "**", "results_*.json")
    filenames = glob.glob(pattern, recursive=True)

    data = []
    for filename in filenames:
        try:
            # JSON is UTF-8 by specification; do not depend on the platform default encoding.
            with open(filename, encoding="utf-8") as fp:
                report = json.load(fp)
            user_id, model_id = report["config"]["model_id"].split("/")
            row = {"user_id": user_id, "model_id": model_id, "model_sha": report["config"]["model_sha"]}
            if report["status"] == "DONE" and len(report["results"]) > 0:
                env_ids = list(report["results"].keys())
                # Explicit raise instead of `assert`, which is stripped when running with -O.
                if len(env_ids) != 1:
                    raise ValueError("Only one environment supported for the moment")
                row["env_id"] = env_ids[0]
                row["iqm_episodic_return"] = iqm(report["results"][env_ids[0]]["episodic_returns"])
            # Unfinished reports are kept as well (without env_id / score).
            data.append(row)
        except Exception as e:
            logger.error(f"Error while processing {filename}: {e}")

    # Passing `columns` guarantees the schema even when `data` is empty or when
    # no row carries an env_id yet.
    df = pd.DataFrame(data, columns=columns)  # create DataFrame
    df = df.fillna("")  # replace NaN values with empty strings
    return df
def select_env(df: pd.DataFrame, env_id: str):
    """Return the rows of ``df`` for ``env_id``, best score first, with a ``ranking`` column.

    Rows are ordered by decreasing ``iqm_episodic_return`` and ``ranking`` numbers
    them from 1. The input frame is left unchanged.
    """
    is_requested_env = df["env_id"] == env_id
    ranked = df[is_requested_env].sort_values("iqm_episodic_return", ascending=False)
    ranked["ranking"] = np.arange(1, len(ranked) + 1)
    return ranked
def format_df(df: pd.DataFrame):
    """Turn a ranked per-environment DataFrame into rows for the leaderboard table.

    The ``user_id`` and ``model_id`` cells are replaced by Markdown links to the
    corresponding https://huggingface.co pages. The input frame is not modified.

    Args:
        df: frame with at least the columns ``ranking``, ``user_id``,
            ``model_id`` and ``iqm_episodic_return``.

    Returns:
        list: one ``[ranking, user link, model link, iqm_episodic_return]`` list
        per row, in the row order of ``df`` (empty list for an empty frame).
    """
    user_ids = df["user_id"].tolist()
    model_ids = df["model_id"].tolist()

    # Add hyperlinks. Both columns are built in a single pass and assigned
    # positionally, instead of writing cell by cell through iterrows()/.loc,
    # which is slow and relies on unique index labels.
    user_links = [f"[{user_id}](https://huggingface.co/{user_id})" for user_id in user_ids]
    model_links = [
        f"[{model_id}](https://huggingface.co/{user_id}/{model_id})"
        for user_id, model_id in zip(user_ids, model_ids)
    ]
    # assign() returns a new frame, so the caller's DataFrame is left untouched.
    df = df.assign(user_id=user_links, model_id=model_links)

    # Keep only the relevant columns
    df = df[["ranking", "user_id", "model_id", "iqm_episodic_return"]]
    return df.values.tolist()
def refresh_video(df, env_id):
    """Return the local path of the replay video of the best model for ``env_id``.

    The best model is the first row of ``select_env(df, env_id)``. Its
    ``replay.mp4`` is downloaded from the model repository at the recorded
    revision (``model_sha``).

    Returns:
        The downloaded file path, or ``None`` when no model is ranked for this
        environment or when the download fails (the failure is logged).
    """
    env_df = select_env(df, env_id)
    if env_df.empty:
        return None

    best = env_df.iloc[0]
    user_id = best["user_id"]
    model_id = best["model_id"]
    model_sha = best["model_sha"]
    repo_id = f"{user_id}/{model_id}"
    try:
        return API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=model_sha, repo_type="model")
    except Exception as e:
        logger.error(f"Error while downloading video for {env_id}: {e}")
        return None
def refresh_one_video(df, env_id):
    """Build a no-argument callback that calls ``refresh_video(df, env_id)``.

    Both ``df`` and ``env_id`` are bound when this factory is called; the video
    lookup itself only happens when the returned callback runs.
    """

    def inner():
        video_path = refresh_video(df, env_id)
        return video_path

    return inner
def refresh_winner(df, env_id):
    """Return the Markdown shown next to the table of ``env_id``.

    When at least one model is ranked, the text links to the best one (first row
    of ``select_env(df, env_id)``); otherwise it invites the reader to submit a
    model.
    """
    env_df = select_env(df, env_id)
    if env_df.empty:
        return f"""## {env_id}
This leaderboard is quite empty... 😢
Be the first to submit your model!
Check the tab "🚀 Getting my agent evaluated"
"""

    best = env_df.iloc[0]
    user_id = best["user_id"]
    model_id = best["model_id"]
    url = f"https://huggingface.co/{user_id}/{model_id}"
    return f"""## {env_id}
### 🏆 [Best model]({url}) 🏆"""
def refresh_num_models(df):
    """Return the sentence stating how many entries ``df`` holds (thousands separated by commas)."""
    num_models = len(df)
    return f"The leaderboard currently contains {num_models:,} models."
# Custom stylesheet passed to gr.Blocks(css=...): removes the border of elements
# with the "generating" class and centers the h2/h3 headings.
css = """
.generating {
border: none;
}
h2 {
text-align: center;
}
h3 {
text-align: center;
}
"""
def update_globals():
    """Re-download the results and rebuild every module-level cache.

    Rebinds, in this order, the globals ``df`` (raw leaderboard frame),
    ``dataframes`` (table rows per environment), ``winner_texts`` (Markdown per
    environment), ``video_pathes`` (best replay video path per environment, or
    ``None``) and ``num_models_str`` (model count sentence).
    """
    global dataframes, winner_texts, video_pathes, num_models_str, df

    df = get_leaderboard_df()

    # Flatten the per-domain lists while keeping the declaration order.
    all_env_ids = []
    for domain_env_ids in ALL_ENV_IDS.values():
        all_env_ids.extend(domain_env_ids)

    # Each mapping is built locally and published with a single assignment.
    tables = {}
    for env_id in all_env_ids:
        tables[env_id] = format_df(select_env(df, env_id))
    dataframes = tables

    winners = {}
    for env_id in all_env_ids:
        winners[env_id] = refresh_winner(df, env_id)
    winner_texts = winners

    videos = {}
    for env_id in all_env_ids:
        videos[env_id] = refresh_video(df, env_id)
    video_pathes = videos

    num_models_str = refresh_num_models(df)
# Fill the module-level caches once at import time: the UI construction below
# reads `df`, and `refresh` reads the other globals.
update_globals()
def refresh():
    """Return the cached values for all leaderboard components as one flat list.

    The list holds every per-environment table, then every per-environment
    winner text, then the model count sentence — the same order as the
    ``outputs`` list this function is registered with in ``demo.load``.
    """
    # The globals are only read here, so no `global` declaration is required.
    return [*dataframes.values(), *winner_texts.values(), num_models_str]
def _current_video_refresher(env_id):
    """Build a no-argument callback returning the best replay video for ``env_id``.

    The callback looks up the module-level ``df`` every time it runs, so it
    follows the frame re-published by ``update_globals``. Passing ``df`` to a
    factory at UI-build time would instead freeze the startup snapshot and the
    video would never track leaderboard changes.
    """

    def inner():
        return refresh_video(df, env_id)

    return inner


# Build the Gradio UI: a heading, the model count, then three top-level tabs
# (leaderboard, submission instructions, about).
with gr.Blocks(css=css) as demo:
    with open("texts/heading.md") as fp:
        gr.Markdown(fp.read())
    num_models_md = gr.Markdown()
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard"):
            # Components indexed by env_id, so that they can be wired to `refresh` below.
            all_gr_dfs = {}
            all_gr_winners = {}
            all_gr_videos = {}
            for env_domain, env_ids in ALL_ENV_IDS.items():
                with gr.TabItem(env_domain):
                    for env_id in env_ids:
                        # If the env_id ends with "NoFrameskip-v4", we remove it to improve readability
                        tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
                        with gr.TabItem(tab_env_id) as tab:
                            logger.info(f"Creating tab for {env_id}")
                            with gr.Row(equal_height=False):
                                with gr.Column(scale=3):
                                    gr_df = gr.components.Dataframe(
                                        headers=["🏆", "🧑 User", "🤖 Model id", "📊 IQM episodic return"],
                                        datatype=["number", "markdown", "markdown", "number"],
                                    )
                                with gr.Column(scale=1):
                                    with gr.Row():  # Display the env_id and the winner
                                        gr_winner = gr.Markdown()
                                    with gr.Row():  # Play the video of the best model
                                        gr_video = gr.PlayableVideo(  # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689,
                                            min_width=50,
                                            show_download_button=False,
                                            show_share_button=False,
                                            show_label=False,
                                            interactive=False,
                                        )
                            all_gr_dfs[env_id] = gr_df
                            all_gr_winners[env_id] = gr_winner
                            all_gr_videos[env_id] = gr_video
                            # The video is only fetched when the tab gets selected.
                            tab.select(_current_video_refresher(env_id), outputs=[gr_video])
                    # Load the first video of the first environment
                    # NOTE(review): presumably needed because the tab shown by default
                    # never receives a select event — confirm against Gradio's behavior.
                    demo.load(_current_video_refresher(env_ids[0]), outputs=[all_gr_videos[env_ids[0]]])
        with gr.TabItem("🚀 Getting my agent evaluated"):
            with open("texts/getting_my_agent_evaluated.md") as fp:
                gr.Markdown(fp.read())
        with gr.TabItem("📝 About"):
            with open("texts/about.md") as fp:
                gr.Markdown(fp.read())
    # On page load, fill every table, every winner text and the model count from the caches.
    # The outputs order must match the order of the list returned by `refresh`.
    demo.load(refresh, outputs=list(all_gr_dfs.values()) + list(all_gr_winners.values()) + [num_models_md])
# Background jobs, started as soon as this module is imported: every REFRESH_RATE
# seconds run the backend routine (from src.backend) and rebuild the cached
# leaderboard data.
scheduler = BackgroundScheduler()
# max_instances=1: at most one run of a given job at a time (per APScheduler's
# `max_instances` option), so a slow run is not overlapped by the next one.
scheduler.add_job(func=backend_routine, trigger="interval", seconds=REFRESH_RATE, max_instances=1)
scheduler.add_job(func=update_globals, trigger="interval", seconds=REFRESH_RATE, max_instances=1)
scheduler.start()
# Script entry point: enable Gradio's request queue, then start the app.
if __name__ == "__main__":
    demo.queue().launch()