File size: 10,446 Bytes
b165a4b
74e3b17
0811d37
74e3b17
de52ad3
74e3b17
400662c
74e3b17
c67a861
5174522
95c19d6
c67a861
 
 
 
 
 
 
2a5f9fb
c67a861
b165a4b
2a73469
69cf5b3
c67a861
0660028
69cf5b3
 
0ef2585
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69cf5b3
 
 
3922a8b
 
103ee13
 
3922a8b
 
 
103ee13
3922a8b
 
69cf5b3
 
3922a8b
69cf5b3
 
3922a8b
 
69cf5b3
 
103ee13
69cf5b3
3922a8b
 
 
103ee13
 
 
 
 
 
69cf5b3
74c08c9
 
 
 
 
 
 
 
 
 
 
69cf5b3
74e3b17
 
6d58c89
400662c
6d58c89
 
74e3b17
0660028
c67a861
 
 
 
 
74e3b17
 
 
69cf5b3
 
6d58c89
69cf5b3
3922a8b
74e3b17
3922a8b
 
69cf5b3
3922a8b
69cf5b3
 
 
 
 
 
041b899
6d58c89
69cf5b3
95c19d6
 
0660028
 
 
 
 
c67a861
0660028
a3eda6f
c67a861
a3eda6f
 
 
 
0660028
 
400662c
 
0660028
 
 
400662c
0660028
400662c
 
0660028
 
 
 
 
 
 
 
48ae20c
02fb3fc
0660028
 
48ae20c
54bed10
 
4f47d86
 
54bed10
c67a861
 
a3eda6f
 
48ae20c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0660028
 
a3eda6f
0660028
 
 
 
 
a3eda6f
0660028
 
 
 
 
 
a3eda6f
 
0660028
 
48ae20c
263af70
 
a3eda6f
74e3b17
69cf5b3
0660028
 
 
69cf5b3
 
 
0660028
0ef2585
400662c
c67a861
69cf5b3
48ae20c
 
6d58c89
48ae20c
 
 
 
 
 
0660028
48ae20c
 
 
 
0660028
48ae20c
0ef2585
0660028
 
 
 
 
74e3b17
0660028
 
400662c
263af70
 
 
 
8071283
263af70
 
95c19d6
a3eda6f
8c49cb6
74e3b17
0660028
74e3b17
95c19d6
 
0811d37
69cf5b3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import logging
import os

import gradio as gr
import numpy as np
import pandas as pd
import scipy.stats
from apscheduler.schedulers.background import BackgroundScheduler
from datasets import load_dataset
from huggingface_hub import HfApi

# Application logger: INFO and above, timestamped records sent to a stream handler
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
ch = logging.StreamHandler()
ch.setFormatter(formatter)

logger = logging.getLogger("app")
logger.setLevel(logging.INFO)
logger.addHandler(ch)

# Only let WARNING and above through from the (noisy) absl logger
logging.getLogger("absl").setLevel(logging.WARNING)

# Hub client; the access token is read from the TOKEN environment variable (None when unset)
API = HfApi(token=os.environ.get("TOKEN"))
# Dataset repository the evaluation results are loaded from
RESULTS_REPO = "open-rl-leaderboard/results_v2"
# Interval, in seconds, between two background refreshes of the results
REFRESH_RATE = 5 * 60  # 5 minutes
# Environment ids grouped by domain: each key becomes a top-level tab of the leaderboard
# and each id in its list gets its own sub-tab
ALL_ENV_IDS = {
    "Atari": [
        "AdventureNoFrameskip-v4",
        "AirRaidNoFrameskip-v4",
        "AlienNoFrameskip-v4",
        "AmidarNoFrameskip-v4",
        "AssaultNoFrameskip-v4",
        "AsterixNoFrameskip-v4",
        "AsteroidsNoFrameskip-v4",
        "AtlantisNoFrameskip-v4",
        "BankHeistNoFrameskip-v4",
        "BattleZoneNoFrameskip-v4",
        "BeamRiderNoFrameskip-v4",
        "BerzerkNoFrameskip-v4",
        "BowlingNoFrameskip-v4",
        "BoxingNoFrameskip-v4",
        "BreakoutNoFrameskip-v4",
        "CarnivalNoFrameskip-v4",
        "CentipedeNoFrameskip-v4",
        "ChopperCommandNoFrameskip-v4",
        "CrazyClimberNoFrameskip-v4",
        "DefenderNoFrameskip-v4",
        "DemonAttackNoFrameskip-v4",
        "DoubleDunkNoFrameskip-v4",
        "ElevatorActionNoFrameskip-v4",
        "EnduroNoFrameskip-v4",
        "FishingDerbyNoFrameskip-v4",
        "FreewayNoFrameskip-v4",
        "FrostbiteNoFrameskip-v4",
        "GopherNoFrameskip-v4",
        "GravitarNoFrameskip-v4",
        "HeroNoFrameskip-v4",
        "IceHockeyNoFrameskip-v4",
        "JamesbondNoFrameskip-v4",
        "JourneyEscapeNoFrameskip-v4",
        "KangarooNoFrameskip-v4",
        "KrullNoFrameskip-v4",
        "KungFuMasterNoFrameskip-v4",
        "MontezumaRevengeNoFrameskip-v4",
        "MsPacmanNoFrameskip-v4",
        "NameThisGameNoFrameskip-v4",
        "PhoenixNoFrameskip-v4",
        "PitfallNoFrameskip-v4",
        "PongNoFrameskip-v4",
        "PooyanNoFrameskip-v4",
        "PrivateEyeNoFrameskip-v4",
        "QbertNoFrameskip-v4",
        "RiverraidNoFrameskip-v4",
        "RoadRunnerNoFrameskip-v4",
        "RobotankNoFrameskip-v4",
        "SeaquestNoFrameskip-v4",
        "SkiingNoFrameskip-v4",
        "SolarisNoFrameskip-v4",
        "SpaceInvadersNoFrameskip-v4",
        "StarGunnerNoFrameskip-v4",
        "TennisNoFrameskip-v4",
        "TimePilotNoFrameskip-v4",
        "TutankhamNoFrameskip-v4",
        "UpNDownNoFrameskip-v4",
        "VentureNoFrameskip-v4",
        "VideoPinballNoFrameskip-v4",
        "WizardOfWorNoFrameskip-v4",
        "YarsRevengeNoFrameskip-v4",
        "ZaxxonNoFrameskip-v4",
    ],
    "Box2D": [
        "BipedalWalker-v3",
        "BipedalWalkerHardcore-v3",
        "CarRacing-v2",
        "LunarLander-v2",
        "LunarLanderContinuous-v2",
    ],
    "Toy text": [
        "Blackjack-v1",
        "CliffWalking-v0",
        "FrozenLake-v1",
        "FrozenLake8x8-v1",
    ],
    "Classic control": [
        "Acrobot-v1",
        "CartPole-v1",
        "MountainCar-v0",
        "MountainCarContinuous-v0",
        "Pendulum-v1",
    ],
    "MuJoCo": [
        "Ant-v4",
        "HalfCheetah-v4",
        "Hopper-v4",
        "Humanoid-v4",
        "HumanoidStandup-v4",
        "InvertedDoublePendulum-v4",
        "InvertedPendulum-v4",
        "Pusher-v4",
        "Reacher-v4",
        "Swimmer-v4",
        "Walker2d-v4",
    ],
    "PyBullet": [
        "AntBulletEnv-v0",
        "HalfCheetahBulletEnv-v0",
        "HopperBulletEnv-v0",
        "HumanoidBulletEnv-v0",
        "InvertedDoublePendulumBulletEnv-v0",
        "InvertedPendulumSwingupBulletEnv-v0",
        "MinitaurBulletEnv-v0",
        "ReacherBulletEnv-v0",
        "Walker2DBulletEnv-v0",
    ],
}


def iqm(x):
    """Return the mean of ``x`` after trimming 25% of the values from each end.

    The input is flattened first, so nested sequences are treated as one pool of values.
    """
    flat_values = np.asarray(x).ravel()
    return scipy.stats.trim_mean(flat_values, 0.25)


def get_leaderboard_df():
    """Download the results dataset and return the finished evaluations.

    Returns:
        A DataFrame restricted to the rows whose ``status`` is ``"DONE"``, with an extra
        ``iqm_episodic_return`` column obtained by applying ``iqm`` to each row's ``episodic_returns``.
    """
    logger.info("Downloading results")
    # The split itself is not important, but "train" has to be requested
    dataset = load_dataset(RESULTS_REPO, split="train")
    df = dataset.to_pandas()
    # Keep only the finished evaluations. The explicit copy makes the column added below a write
    # to an independent frame instead of to a filtered slice (which triggers SettingWithCopyWarning).
    df = df[df["status"] == "DONE"].copy()
    df["iqm_episodic_return"] = df["episodic_returns"].apply(iqm)
    logger.debug("Results downloaded")
    return df


def select_env(df: pd.DataFrame, env_id: str):
    """Return the rows of ``df`` for ``env_id``, best first, with a 1-based ``ranking`` column.

    Rows are ordered by decreasing ``iqm_episodic_return``; the input frame is left untouched.
    """
    matches_env = df["env_id"] == env_id
    ranked = df.loc[matches_env].sort_values(by="iqm_episodic_return", ascending=False)
    return ranked.assign(ranking=np.arange(1, len(ranked) + 1))


def format_df(df: pd.DataFrame):
    """Turn a ranked results frame into the rows displayed by the leaderboard table.

    ``user_id`` and ``model_id`` are replaced by markdown links to the matching Hugging Face pages.

    Args:
        df: Frame with at least the ``ranking``, ``user_id``, ``model_id`` and
            ``iqm_episodic_return`` columns. It is not modified.

    Returns:
        A list of ``[ranking, user link, model link, iqm_episodic_return]`` rows, in the order of ``df``.
    """
    df = df.copy()
    user_ids = df["user_id"].tolist()
    model_ids = df["model_id"].tolist()

    # Build both link columns positionally, in one pass each. Unlike row-by-row writes through
    # ``df.loc[label, ...]``, this stays correct when index labels repeat and avoids the per-row cost.
    df["user_id"] = [f"[{user_id}](https://huggingface.co/{user_id})" for user_id in user_ids]
    df["model_id"] = [
        f"[{model_id}](https://huggingface.co/{user_id}/{model_id})"
        for user_id, model_id in zip(user_ids, model_ids)
    ]

    # Keep only the relevant columns
    df = df[["ranking", "user_id", "model_id", "iqm_episodic_return"]]
    return df.values.tolist()


def refresh_video(df, env_id):
    """Download the replay video of the top-ranked model of ``env_id``.

    Returns the local path of ``replay.mp4`` fetched at the model's recorded ``sha``, or ``None``
    when the environment has no entry or the download fails (the failure is logged).
    """
    ranked = select_env(df, env_id)
    if ranked.empty:
        return None

    best = ranked.iloc[0]
    user_id, model_id, sha = best["user_id"], best["model_id"], best["sha"]
    repo_id = f"{user_id}/{model_id}"
    try:
        return API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=sha, repo_type="model")
    except Exception as e:
        # Best effort: a missing or unreachable video must not break the page
        logger.error(f"Error while downloading video for {env_id}: {e}")
        return None


def refresh_one_video(df, env_id):
    """Return a zero-argument callback that calls ``refresh_video(df, env_id)``.

    Both arguments are captured when this function is called: the callback keeps using the
    ``df`` object it was given, even if a module-level name is later rebound to a newer frame.
    """

    def inner():
        return refresh_video(df, env_id)

    return inner


def refresh_winner(df, env_id):
    """Build the markdown displayed next to the table of ``env_id``.

    Args:
        df: Results frame, as accepted by ``select_env``.
        env_id: Environment whose winner is announced.

    Returns:
        A markdown string: the environment title followed by a link to the top-ranked model,
        or an invitation to submit when the environment has no entry yet.
    """
    env_df = select_env(df, env_id)
    if env_df.empty:
        return f"""## {env_id}

This leaderboard is quite empty... 😢

Be the first to submit your model!
Check the tab "🚀 Getting my agent evaluated"
"""

    # select_env sorts best first, so the first row is the winner
    user_id = env_df.iloc[0]["user_id"]
    model_id = env_df.iloc[0]["model_id"]
    url = f"https://huggingface.co/{user_id}/{model_id}"
    return f"""## {env_id}

### 🏆 [Best model]({url}) 🏆"""


def refresh_num_models(df):
    """Return the sentence stating how many rows ``df`` holds, with thousands separators."""
    count = len(df)
    return f"The leaderboard currently contains {count:,} models."


# Custom stylesheet handed to gr.Blocks: removes the border of elements carrying the
# "generating" class and centers the h2/h3 headings (used by the winner markdown)
css = """
.generating {
    border: none;
}
h2 {
    text-align: center;
}
h3 {
    text-align: center;
}

"""


def update_globals():
    """Reload the results and rebuild every module-level cache derived from them.

    Rebinds, in this order, ``df``, ``dataframes``, ``winner_texts``, ``video_pathes`` and
    ``num_models_str``. The three dictionaries are keyed by environment id, following the
    order of ``ALL_ENV_IDS``.
    """
    global dataframes, winner_texts, video_pathes, num_models_str, df
    df = get_leaderboard_df()

    # Flatten the per-domain lists into a single list of environment ids
    all_env_ids = []
    for domain_env_ids in ALL_ENV_IDS.values():
        all_env_ids.extend(domain_env_ids)

    dataframes = {name: format_df(select_env(df, name)) for name in all_env_ids}
    winner_texts = {name: refresh_winner(df, name) for name in all_env_ids}
    video_pathes = {name: refresh_video(df, name) for name in all_env_ids}
    num_models_str = refresh_num_models(df)


# Fill the caches once at import time, before the interface below is built
update_globals()


def refresh():
    """Return the cached table rows, then the winner texts, then the model-count sentence.

    The values are read from the module-level caches filled by ``update_globals`` and come out
    in dictionary order, as one flat list.
    """
    outputs = [*dataframes.values(), *winner_texts.values()]
    outputs.append(num_models_str)
    return outputs


def _latest_video(env_id):
    """Return a zero-argument callback fetching the replay of the current best model of ``env_id``."""

    def inner():
        # The module-level `df` is looked up when the callback fires, not when it is created, so
        # the video follows the results rebound by update_globals instead of the start-up ones
        return refresh_video(df, env_id)

    return inner


# Build the interface: a heading, the leaderboard (one tab per domain, one sub-tab per environment)
# and two static documentation tabs
with gr.Blocks(css=css) as demo:
    with open("texts/heading.md", encoding="utf-8") as fp:
        gr.Markdown(fp.read())
    num_models_md = gr.Markdown()
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard"):
            # Components of every environment, keyed by env_id, filled on page load
            all_gr_dfs = {}
            all_gr_winners = {}
            all_gr_videos = {}
            for env_domain, env_ids in ALL_ENV_IDS.items():
                with gr.TabItem(env_domain):
                    for env_id in env_ids:
                        # If the env_id ends with "NoFrameskip-v4", we remove it to improve readability
                        tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
                        with gr.TabItem(tab_env_id) as tab:
                            logger.debug(f"Creating tab for {env_id}")
                            with gr.Row(equal_height=False):
                                with gr.Column(scale=3):
                                    gr_df = gr.components.Dataframe(
                                        headers=["🏆", "🧑 User", "🤖 Model id", "📊 IQM episodic return"],
                                        datatype=["number", "markdown", "markdown", "number"],
                                    )
                                with gr.Column(scale=1):
                                    with gr.Row():  # Display the env_id and the winner
                                        gr_winner = gr.Markdown()
                                    with gr.Row():  # Play the video of the best model
                                        gr_video = gr.PlayableVideo(  # Doesn't loop for the moment, see https://github.com/gradio-app/gradio/issues/7689,
                                            min_width=50,
                                            show_download_button=False,
                                            show_share_button=False,
                                            show_label=False,
                                            interactive=False,
                                        )

                            all_gr_dfs[env_id] = gr_df
                            all_gr_winners[env_id] = gr_winner
                            all_gr_videos[env_id] = gr_video

                            # Fetch the video only when its tab is opened
                            tab.select(_latest_video(env_id), outputs=[gr_video])

                # On page load, fetch the video of the first environment of this domain
                demo.load(_latest_video(env_ids[0]), outputs=[all_gr_videos[env_ids[0]]])

        with gr.TabItem("🚀 Getting my agent evaluated"):
            with open("texts/getting_my_agent_evaluated.md", encoding="utf-8") as fp:
                gr.Markdown(fp.read())

        with gr.TabItem("📝 About"):
            with open("texts/about.md", encoding="utf-8") as fp:
                gr.Markdown(fp.read())

    # On page load, fill every table, every winner text and the model counter from the cached values
    demo.load(refresh, outputs=list(all_gr_dfs.values()) + list(all_gr_winners.values()) + [num_models_md])

# Refresh the cached results every REFRESH_RATE seconds in the background;
# max_instances=1 keeps a slow refresh from overlapping with the next one
scheduler = BackgroundScheduler()
scheduler.add_job(update_globals, "interval", seconds=REFRESH_RATE, max_instances=1)
scheduler.start()


if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch()