# Hugging Face Spaces status banner (page residue): Running on CPU Upgrade
import json | |
import os | |
import re | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
from apscheduler.schedulers.background import BackgroundScheduler | |
from huggingface_hub import HfApi, hf_hub_download | |
from src.backend import backend_routine | |
from src.css_html_js import dark_mode_gradio_js | |
from src.logging import configure_root_logger, setup_logger | |
# Set up logging first, then create the logger for this module.
configure_root_logger()
logger = setup_logger(__name__)

# Hub API client. The token is read from the TOKEN environment variable;
# os.environ.get returns None when the variable is not set.
API = HfApi(token=os.environ.get("TOKEN"))

# Dataset repository whose result files get_leaderboard_df lists and downloads.
# Plain string literal: there is nothing to interpolate, so no f-prefix.
RESULTS_REPO = "open-rl-leaderboard/results"

# Environment ids grouped by domain. The UI creates one tab per domain and,
# inside it, one tab per environment id.
ALL_ENV_IDS = {
    "Atari": [
        "BeamRiderNoFrameskip-v4",
        "BreakoutNoFrameskip-v4",
    ],
    "Box2D": [
        "LunarLander-v2",
        "BipedalWalker-v3",
    ],
    "Classic control": [
        "CartPole-v1",
        "MountainCar-v0",
    ],
    "MuJoCo": [
        "Hopper-v4",
        "HalfCheetah-v4",
    ],
}
def get_leaderboard_df():
    """Build the leaderboard table from the report files in RESULTS_REPO.

    Every file of the dataset repository whose path matches
    ``<dir>/<dir>/<prefix>results_<hex>.json`` is downloaded and parsed as
    JSON. Each report yields one row with ``user_id`` and ``model_id`` (the two
    halves of ``report["config"]["model_id"]``). Reports whose status is
    ``"DONE"`` and whose results are non-empty additionally get ``env_id`` and
    ``mean_episodic_return`` (the mean of that environment's
    ``episodic_returns``).

    Returns:
        A DataFrame with the columns ``user_id``, ``model_id``, ``env_id`` and
        ``mean_episodic_return``. The four columns are always present, even
        when no report (or no finished report) exists, so callers can filter
        on ``env_id`` unconditionally. Missing values are empty strings.

    Raises:
        AssertionError: If a finished report holds results for more than one
            environment.
        ValueError: If a report's model id does not split into exactly two
            parts on ``"/"``.
    """
    columns = ["user_id", "model_id", "env_id", "mean_episodic_return"]

    # Keep only the result files among everything stored in the repo.
    pattern = re.compile(r"^[^/]*/[^/]*/[^/]*results_[a-f0-9]+\.json$")
    filenames = [
        filename
        for filename in API.list_repo_files(RESULTS_REPO, repo_type="dataset")
        if pattern.match(filename)
    ]

    data = []
    for filename in filenames:
        path = hf_hub_download(repo_id=RESULTS_REPO, filename=filename, repo_type="dataset")
        # JSON is UTF-8; do not depend on the platform's default encoding.
        with open(path, encoding="utf-8") as fp:
            report = json.load(fp)

        user_id, model_id = report["config"]["model_id"].split("/")
        row = {"user_id": user_id, "model_id": model_id}

        if report["status"] == "DONE" and len(report["results"]) > 0:
            env_ids = list(report["results"])
            # NOTE: this check disappears under ``python -O``.
            assert len(env_ids) == 1, "Only one environment supported for the moment"
            env_id = env_ids[0]
            row["env_id"] = env_id
            row["mean_episodic_return"] = np.mean(report["results"][env_id]["episodic_returns"])

        # Unfinished reports are kept too, with the result fields left blank.
        data.append(row)

    # Passing the schema explicitly guarantees the result columns exist even
    # when ``data`` is empty or contains no finished report.
    df = pd.DataFrame(data, columns=columns)
    df = df.fillna("")  # replace NaN values with empty strings
    return df
# Page title, rendered through gr.HTML.
# NOTE(review): the leading glyph looks like a mis-decoded emoji, and any HTML
# markup may have been lost — confirm the exact literal against the upstream file.
TITLE = """
๐ Open RL Leaderboard
"""

# Introductory paragraph, rendered as Markdown directly below the title.
INTRODUCTION_TEXT = """
Welcome to the Open RL Leaderboard! This is a community-driven benchmark for reinforcement learning models.
"""

# Body of the "About" tab, rendered as Markdown.
ABOUT_TEXT = """
The Open RL Leaderboard is a community-driven benchmark for reinforcement learning models.
"""
def select_env(df: pd.DataFrame, env_id: str):
    """Return the ranked leaderboard rows of one environment.

    The rows of ``df`` whose ``env_id`` equals ``env_id`` are sorted by
    ``mean_episodic_return`` in descending order and numbered from 1. User and
    model ids are turned into Markdown links to ``https://huggingface.co``.
    ``df`` itself is not modified.

    Args:
        df: Leaderboard table with the columns ``user_id``, ``model_id``,
            ``env_id`` and ``mean_episodic_return``.
        env_id: Environment id to select.

    Returns:
        A list of ``[ranking, user_link, model_link, mean_episodic_return]``
        lists, best model first. The list is empty when ``df`` has no rows,
        has no ``env_id`` column (no finished evaluation yet), or holds no
        entry for ``env_id``.
    """
    # An empty leaderboard (or one without any finished report) has nothing to
    # rank; return early instead of failing on the missing column.
    if df.empty or "env_id" not in df.columns:
        return []

    selected = df[df["env_id"] == env_id]
    selected = selected.sort_values("mean_episodic_return", ascending=False)

    entries = zip(
        selected["user_id"].tolist(),
        selected["model_id"].tolist(),
        selected["mean_episodic_return"].tolist(),
    )
    # Build the output rows in one pass: rank, linked user, linked model, score.
    return [
        [
            rank,
            f"[{user_id}](https://huggingface.co/{user_id})",
            f"[{model_id}](https://huggingface.co/{user_id}/{model_id})",
            mean_return,
        ]
        for rank, (user_id, model_id, mean_return) in enumerate(entries, start=1)
    ]
# Settings shared by every per-environment leaderboard table.
# NOTE(review): the leading glyphs in the labels below look like mis-decoded
# emoji; they are kept exactly as found — confirm against the upstream file.
_TABLE_HEADERS = ("๐ Ranking", "๐ง User", "๐ค Model id", "๐ Mean episodic return")
_TABLE_DATATYPES = ("number", "markdown", "markdown", "number")
# The same replay clip is displayed next to every table.
_REPLAY_VIDEO_URL = (
    "https://huggingface.co/qgallouedec/MsPacmanNoFrameskip-v4-dqn_atari-seed1/resolve/main/replay.mp4"
)

with gr.Blocks(js=dark_mode_gradio_js) as demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("๐ Leaderboard"):
            # The results are fetched once, while the interface is being built.
            df = get_leaderboard_df()
            # One tab per domain, holding one tab per environment.
            for env_domain, env_ids in ALL_ENV_IDS.items():
                with gr.TabItem(env_domain):
                    for env_id in env_ids:
                        with gr.TabItem(env_id):
                            # Ranking table and replay video side by side.
                            with gr.Row(equal_height=False):
                                gr.components.Dataframe(
                                    value=select_env(df, env_id),
                                    headers=list(_TABLE_HEADERS),
                                    datatype=list(_TABLE_DATATYPES),
                                    row_count=(10, "fixed"),
                                    scale=3,
                                )
                                gr.Video(
                                    _REPLAY_VIDEO_URL,
                                    autoplay=True,
                                    scale=1,
                                    min_width=50,
                                )

        with gr.TabItem("๐ About", elem_id="llm-benchmark-tab-table", id=2):
            gr.Markdown(ABOUT_TEXT)
# Delay between two runs of the backend routine, in seconds (half a minute).
_BACKEND_INTERVAL_SECONDS = 0.5 * 60

# Schedule backend_routine at a fixed interval; max_instances=1 is passed so
# that at most one run of the job is active at a time.
scheduler = BackgroundScheduler()
scheduler.add_job(
    backend_routine,
    "interval",
    seconds=_BACKEND_INTERVAL_SECONDS,
    max_instances=1,
)
scheduler.start()

# Serve the interface only when the file is executed as a script.
if __name__ == "__main__":
    demo.queue().launch()