"""Gradio dashboard showing a contest leaderboard and score graph built from W&B validator runs."""

import json
import os
import time
from dataclasses import dataclass
from datetime import datetime
from zoneinfo import ZoneInfo

import gradio as gr
import plotly.graph_objects as go
import wandb
from substrateinterface import Keypair
from wandb.apis.public import Run

WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
SOURCE_VALIDATOR_UID = int(os.environ["SOURCE_VALIDATOR_UID"])

REFRESH_RATE = 60 * 30  # 30 minutes
BASELINE = 0.0
GRAPH_HISTORY_DAYS = 30
MAX_GRAPH_ENTRIES = 10

wandb_api = wandb.Api()
demo = gr.Blocks(css=".typewriter {font-family: 'JMH Typewriter', sans-serif;}")

# Valid validator runs grouped by validator uid; rebuilt by fetch_wandb_data().
runs: dict[int, list[Run]] = {}


@dataclass
class LeaderboardEntry:
    """One leaderboard row, populated from a single entry of a run summary."""

    uid: int
    model: str
    score: float
    hotkey: str
    previous_day_winner: bool  # filled from the summary's "multiday_winner" field
    rank: int  # shown as rank + 1 in the leaderboard table


@dataclass
class GraphEntry:
    """Score history of one uid, used to draw a single graph trace."""

    dates: list[datetime]
    scores: list[float]  # running best score at each date
    models: list[str]
    max_score: float  # best score seen so far


def is_valid_run(run: Run) -> bool:
    """Return True when the run config is complete and its signature verifies.

    The signed message is ``"{uid}:{hotkey}:{contest}"`` and it is checked
    with a ``Keypair`` built from the run's own ``hotkey`` config value.
    Any exception raised during verification is treated as "invalid".
    """
    required_config_keys = ["hotkey", "uid", "contest", "signature"]

    for key in required_config_keys:
        if key not in run.config:
            return False

    uid = run.config["uid"]
    validator_hotkey = run.config["hotkey"]
    contest_name = run.config["contest"]

    signing_message = f"{uid}:{validator_hotkey}:{contest_name}"

    try:
        return Keypair(validator_hotkey).verify(signing_message, run.config["signature"])
    except Exception:
        # Run configs are external data; a malformed hotkey or signature
        # must only disqualify that run, not abort the refresh.
        return False


def get_graph_entries(runs: list[Run]) -> dict[int, GraphEntry]:
    """Build per-uid score histories from the first GRAPH_HISTORY_DAYS runs.

    The selected runs are walked in reverse list order. Summary keys starting
    with ``"_"`` are skipped; every other key is parsed as an integer uid
    whose value provides ``"score"`` and ``"model"``. Each appended score is
    the running maximum for that uid.

    Returns:
        At most MAX_GRAPH_ENTRIES entries, ordered by best score descending.
    """
    entries: dict[int, GraphEntry] = {}

    for run in reversed(runs[:GRAPH_HISTORY_DAYS]):
        date = datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%S")

        for key, value in run.summary.items():
            if key.startswith("_"):
                continue

            uid = int(key)
            score = value["score"]
            model = value["model"]

            if uid not in entries:
                entries[uid] = GraphEntry([date], [score], [model], score)
            else:
                if score > entries[uid].max_score:
                    entries[uid].max_score = score

                data = entries[uid]
                data.dates.append(date)
                # Plot the best score so far so the line never decreases.
                data.scores.append(data.max_score)
                data.models.append(model)

    return dict(
        sorted(entries.items(), key=lambda entry: entry[1].max_score, reverse=True)[:MAX_GRAPH_ENTRIES]
    )


def create_graph(runs: list[Run]) -> go.Figure:
    """Create the "Score Improvements" figure for one validator's runs.

    One trace is drawn per uid returned by get_graph_entries(), plus a
    horizontal BASELINE line spanning the longest date series. When there
    are no entries, the figure contains no traces but is still styled.
    """
    entries = get_graph_entries(runs)

    fig = go.Figure()

    for uid, data in entries.items():
        fig.add_trace(go.Scatter(
            x=data.dates,
            y=data.scores,
            customdata=data.models,
            mode="lines+markers",
            name=str(uid),
            # Plotly hover text uses <br> for line breaks.
            hovertemplate=(
                "Date: %{x|%Y-%m-%d}<br>" +
                "Score: %{y}<br>" +
                "Model: %{customdata}<br>"
            ),
        ))

    # max() raises ValueError on an empty sequence, and without any dates
    # there is no x-range to draw the baseline over.
    if entries:
        date_range = max(entries.values(), key=lambda entry: len(entry.dates)).dates

        fig.add_trace(go.Scatter(
            x=date_range,
            y=[BASELINE] * len(date_range),
            line=dict(color="#ff0000", width=3),
            mode="lines",
            name="Baseline",
        ))

    background_color = gr.themes.default.colors.slate.c800

    fig.update_layout(
        title="Score Improvements",
        yaxis_title="Score",
        plot_bgcolor=background_color,
        paper_bgcolor=background_color,
        template="plotly_dark"
    )

    return fig


def create_leaderboard(runs: list[Run]) -> list[tuple]:
    """Build leaderboard rows from the first run in ``runs`` that has data.

    A run "has data" when its summary contains at least one key that does
    not start with ``"_"``. Summary entries that fail to parse are skipped.

    Returns:
        Tuples of (rank + 1, uid, model, score, hotkey, previous_day_winner),
        sorted by (score, rank) in descending order.
    """
    entries: dict[int, LeaderboardEntry] = {}

    for run in runs:
        has_data = False

        for key, value in run.summary.items():
            if key.startswith("_"):
                continue

            has_data = True

            try:
                uid = int(key)

                entries[uid] = LeaderboardEntry(
                    uid=uid,
                    rank=value["rank"],
                    model=value["model"],
                    score=value["score"],
                    hotkey=value["hotkey"],
                    previous_day_winner=value["multiday_winner"],
                )
            except Exception:
                # Best effort: a malformed summary entry is left out of the table.
                continue

        if has_data:
            break

    leaderboard: list[tuple] = [
        (entry.rank + 1, entry.uid, entry.model, entry.score, entry.hotkey, entry.previous_day_winner)
        for entry in sorted(entries.values(), key=lambda entry: (entry.score, entry.rank), reverse=True)
    ]

    return leaderboard


def get_run_validator_uid(run: Run) -> int:
    """Return the validator uid stored under ``uid.value`` in the run's JSON config."""
    json_config = json.loads(run.json_config)

    uid = int(json_config["uid"]["value"])

    return uid


def fetch_wandb_data():
    """Reload the global ``runs`` mapping from W&B.

    Queries WANDB_RUN_PATH for runs whose config type is "validator", newest
    first, keeps those accepted by is_valid_run(), and groups them by
    validator uid. The mapping is finally re-created sorted by uid.
    """
    wandb_runs = wandb_api.runs(
        WANDB_RUN_PATH,
        filters={"config.type": "validator"},
        order="-created_at",
    )

    global runs
    runs.clear()

    for run in wandb_runs:
        if not is_valid_run(run):
            continue

        uid = get_run_validator_uid(run)

        if uid not in runs:
            runs[uid] = []

        runs[uid].append(run)

    runs = dict(sorted(runs.items(), key=lambda item: item[0]))


def _runs_for(uid) -> list[Run]:
    """Return the runs of validator ``uid``, or an empty list if it has none."""
    # Resolved through the global at call time because fetch_wandb_data()
    # rebinds ``runs`` on every refresh.
    return runs.get(uid, [])


def refresh():
    """Fetch fresh W&B data and rebuild the whole Gradio UI inside ``demo``."""
    fetch_wandb_data()

    demo.clear()

    with demo:
        with gr.Accordion("Contest #1 Submission Leader: New Dream SDXL on NVIDIA RTX 4090s"):
            choices = list(runs.keys())

            dropdown = gr.Dropdown(
                choices,
                value=SOURCE_VALIDATOR_UID,
                interactive=True,
                label="Source Validator"
            )

            graph = gr.Plot()

            # The configured validator may have no valid runs yet; show an
            # empty table instead of failing with KeyError.
            leaderboard = gr.components.Dataframe(
                create_leaderboard(_runs_for(dropdown.value)),
                headers=["Rank", "Uid", "Model", "Score", "Hotkey", "Previous day winner"],
                datatype=["number", "number", "markdown", "number", "markdown", "bool"],
                elem_id="leaderboard-table",
                interactive=False,
                visible=True,
            )

            demo.load(lambda uid: create_graph(_runs_for(uid)), [dropdown], [graph])

            dropdown.change(lambda uid: create_graph(_runs_for(uid)), [dropdown], [graph])
            dropdown.change(lambda uid: create_leaderboard(_runs_for(uid)), [dropdown], [leaderboard])


if __name__ == "__main__":
    refresh()
    demo.launch(prevent_thread_lock=True)

    # launch() returns immediately, so this thread drives the periodic refresh.
    while True:
        time.sleep(REFRESH_RATE)

        now = datetime.now(tz=ZoneInfo("America/New_York"))
        print(f"Refreshing Leaderboard at {now.strftime('%Y-%m-%d %H:%M:%S')}")

        refresh()