import json
import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo
import bittensor as bt
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import wandb
from substrateinterface import Keypair
from wandb.apis.public import Run
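
# Configuration: the wandb run path and default source validator come from the
# environment; the rest are constants for the SN39 contest leaderboard.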
WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
SOURCE_VALIDATOR_UID = int(os.environ["SOURCE_VALIDATOR_UID"])
START_DATE = datetime(2024, 9, 17)
NET_UID = 39
REFRESH_RATE = 60 * 30 # 30 minutes
GRAPH_HISTORY_DAYS = 30
MAX_GRAPH_ENTRIES = 10
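
# Shared clients: the wandb API, the Gradio app, and the Bittensor subtensor /
# metagraph used to verify validator permits and resolve identities.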
wandb_api = wandb.Api()
demo = gr.Blocks(css=".typewriter {font-family: 'JMH Typewriter', sans-serif;}", fill_height=True, fill_width=True)
subtensor = bt.subtensor()
metagraph = bt.metagraph(netuid=NET_UID)
bt.logging.disable_logging()
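
# Module-level caches: validator runs keyed by validator UID, on-chain identity
# names, and the time of the last data refresh.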
runs: dict[int, list[Run]] = {}
validator_identities: dict[int, str] = {}
last_refresh: datetime = datetime.now(tz=ZoneInfo("America/New_York"))
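
# One row of the leaderboard table: a single miner submission's benchmark results.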
@dataclass
class LeaderboardEntry:
    uid: int
    winner: bool
    repository: str
    score: float
    similarity: float
    hotkey: str
    baseline_generation_time: float
    generation_time: float
    size: int
    vram_used: float
    watts_used: float
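
# Per-miner time series used to plot generation-time improvements over the contest.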
@dataclass
class GraphEntry:
    dates: list[datetime]
    baseline_generation_times: list[float]
    generation_times: list[float]
    similarities: list[float]
    scores: list[float]
    models: list[str]
    best_time: float
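
# A run is only trusted if its config is signed by the validator hotkey it claims.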
def is_valid_run(run: Run):
    required_config_keys = ["hotkey", "uid", "contest", "signature"]
    for key in required_config_keys:
        if key not in run.config:
            return False

    uid = run.config["uid"]
    validator_hotkey = run.config["hotkey"]
    contest_name = run.config["contest"]
    signing_message = f"{uid}:{validator_hotkey}:{contest_name}"

    try:
        return Keypair(validator_hotkey).verify(signing_message, run.config["signature"])
    except Exception:
        return False
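
# Score rewards generating faster than the baseline, weighted by output similarity.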
def calculate_score(baseline_generation_time: float, generation_time: float, similarity_score: float) -> float:
    return (baseline_generation_time - generation_time) * similarity_score
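
# Parse the run's created_at timestamp and express it in US Eastern time.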
def date_from_run(run: Run) -> datetime:
    return datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%SZ").astimezone(ZoneInfo("America/New_York"))
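
# Build per-miner graph data from the most recent validator runs, keeping the
# best generation time seen so far for each miner.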
def get_graph_entries(runs: list[Run]) -> dict[int, GraphEntry]:
    entries: dict[int, GraphEntry] = {}
    for run in reversed(runs[:GRAPH_HISTORY_DAYS]):
        date = date_from_run(run)
        for summary_key, summary_value in run.summary.items():
            if not summary_key.startswith("benchmarks"):
                continue
            for key, value in summary_value.items():
                if "score" in value:
                    continue
                uid = int(key)
                baseline_generation_time = value["baseline_generation_time"]
                generation_time = value["generation_time"]
                similarity = min(1, value["similarity"])
                score = calculate_score(baseline_generation_time, generation_time, similarity)
                model = run.summary["submissions"][str(uid)]["repository"]
                if uid not in entries:
                    entries[uid] = GraphEntry([date], [baseline_generation_time], [generation_time], [similarity], [score], [model], generation_time)
                else:
                    # Track the best (lowest) generation time seen so far and plot that,
                    # so each miner's line only ever improves.
                    if generation_time < entries[uid].best_time:
                        entries[uid].best_time = generation_time
                    data = entries[uid]
                    data.dates.append(date)
                    data.baseline_generation_times.append(baseline_generation_time)
                    data.generation_times.append(data.best_time)
                    data.similarities.append(similarity)
                    data.scores.append(score)
                    data.models.append(model)

    # Keep only the highest-scoring miners, then order them by best generation time.
    entries = dict(sorted(entries.items(), key=lambda entry: entry[1].scores, reverse=True)[:MAX_GRAPH_ENTRIES])
    return dict(sorted(entries.items(), key=lambda entry: entry[1].best_time))
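
# Plot each miner's generation time over time, plus a red baseline reference line.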
def create_graph(runs: list[Run]) -> go.Figure:
    entries = get_graph_entries(runs)

    fig = go.Figure()
    for uid, data in entries.items():
        fig.add_trace(go.Scatter(
            x=data.dates,
            y=data.generation_times,
            customdata=np.stack((data.similarities, data.scores, data.models), axis=-1),
            mode="lines+markers",
            name=uid,
            hovertemplate=(
                "<b>Date:</b> %{x|%Y-%m-%d}<br>" +
                "<b>Generation Time:</b> %{y}s<br>" +
                "<b>Similarity:</b> %{customdata[0]}<br>" +
                "<b>Score:</b> %{customdata[1]}<br>" +
                "<b>Model:</b> %{customdata[2]}<br>"
            ),
        ))

    # Draw the average baseline generation time as a flat reference line.
    date_range = max(entries.values(), key=lambda entry: len(entry.dates)).dates
    average_baseline_generation_times = sum(entry.baseline_generation_times[0] for entry in entries.values()) / len(entries)
    fig.add_trace(go.Scatter(
        x=date_range,
        y=[average_baseline_generation_times] * len(date_range),
        line=dict(color="#ff0000", width=3),
        mode="lines",
        name="Baseline",
    ))

    background_color = gr.themes.default.colors.slate.c800
    fig.update_layout(
        title="Generation Time Improvements",
        yaxis_title="Generation Time (s)",
        plot_bgcolor=background_color,
        paper_bgcolor=background_color,
        template="plotly_dark",
    )
    return fig

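
# The validator UID is stored in the run's JSON config.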
def get_run_validator_uid(run: Run) -> int:
    json_config = json.loads(run.json_config)
    uid = int(json_config["uid"]["value"])
    return uid
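
# Fetch validator runs from wandb, keeping only signed runs from hotkeys that
# hold a validator permit, grouped by validator UID.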
def fetch_wandb_data():
    wandb_runs = wandb_api.runs(
        WANDB_RUN_PATH,
        filters={"config.type": "validator", "created_at": {"$gt": str(START_DATE)}},
        order="-created_at",
    )
    wandb_runs = [run for run in wandb_runs if "benchmarks" in run.summary]

    global runs
    runs.clear()

    for run in wandb_runs:
        # Skip runs with invalid signatures or from hotkeys without a validator permit.
        if not is_valid_run(run):
            continue
        uid = get_run_validator_uid(run)
        if not metagraph.validator_permit[uid]:
            continue
        if uid not in runs:
            runs[uid] = []
        runs[uid].append(run)

    runs = dict(sorted(runs.items(), key=lambda item: item[0]))
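
# Look up on-chain identity names for the validators that have runs.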
def fetch_identities():
    validator_identities.clear()
    for uid in runs.keys():
        identity = subtensor.substrate.query('SubtensorModule', 'Identities', [metagraph.coldkeys[uid]])
        if identity != None:
            validator_identities[uid] = identity.value["name"]
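
# Prefer the validator's on-chain identity name, falling back to its hotkey.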
def get_validator_name(validator_uid: int) -> str:
    if validator_uid in validator_identities:
        return validator_identities[validator_uid]
    else:
        return metagraph.hotkeys[validator_uid]
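
# Dropdown choices: validators with a run since the most recent noon cutoff,
# labelled with their name and whether benchmarking is finished or in progress.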
def get_choices() -> list[tuple[str, int]]:
    # Only offer validators whose latest run started after the most recent noon
    # (America/New_York).
    now = datetime.now(tz=ZoneInfo("America/New_York"))
    noon = now.replace(hour=12, minute=0, second=0, microsecond=0)
    if now.hour < 12:
        noon -= timedelta(days=1)

    choices: list[tuple[str, int]] = []
    for uid, run in runs.items():
        date = date_from_run(run[0])
        if date < noon:
            continue
        benchmarks = dict(run[0].summary.get("benchmarks", {}))
        finished = any("winner" in value for value in benchmarks.values())
        progress_text = "Finished" if finished else "In Progress"
        choices.append((f"{uid} - {get_validator_name(uid)} ({progress_text})", uid))
    return choices

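
# Re-sync the metagraph and rebuild the wandb run and identity caches.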
def refresh():
    metagraph.sync(subtensor=subtensor)
    fetch_wandb_data()
    fetch_identities()
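
# Build the leaderboard table for one validator, refreshing the cached data if
# it is older than REFRESH_RATE.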
def get_data(validator_uid: int) -> gr.Dataframe:
    global last_refresh

    now = datetime.now(tz=ZoneInfo("America/New_York"))
    if (now - last_refresh).total_seconds() > REFRESH_RATE:
        refresh()
        last_refresh = now
        print(f"Refreshing Leaderboard at {now.strftime('%Y-%m-%d %H:%M:%S')}")

    entries: dict[int, LeaderboardEntry] = {}

    # Runs are ordered newest first; use the most recent run that has benchmark data.
    for run in runs[validator_uid]:
        has_data = False
        for summary_key, summary_value in run.summary.items():
            if summary_key != "benchmarks":
                continue
            for key, value in summary_value.items():
                has_data = True
                uid = int(key)
                generation_time = value["generation_time"]
                baseline_generation_time = value["baseline_generation_time"]
                similarity = min(1, value["similarity"])
                entries[uid] = LeaderboardEntry(
                    uid=uid,
                    winner="winner" in value,
                    repository=run.summary["submissions"][str(uid)]["repository"],
                    score=calculate_score(baseline_generation_time, generation_time, similarity),
                    similarity=similarity,
                    baseline_generation_time=baseline_generation_time,
                    generation_time=generation_time,
                    size=value["size"],
                    vram_used=value["vram_used"],
                    watts_used=value["watts_used"],
                    hotkey=value["hotkey"],
                )
        if has_data:
            break

    # Winners first, then by score; sizes are shown in GB and power in watts.
    sorted_entries = [(
        entry.uid,
        f"<span style='color: {'springgreen' if entry.winner else 'red'}'>{entry.winner}</span>",
        entry.repository,
        round(entry.score, 3),
        f"{entry.generation_time:.3f}s",
        f"{entry.similarity:.3f}",
        f"{entry.size / 1_000_000_000:.3f}GB",
        f"{entry.vram_used / 1_000_000_000:.3f}GB",
        f"{entry.watts_used:.3f}W",
        entry.hotkey,
    ) for entry in sorted(entries.values(), key=lambda entry: (entry.winner, entry.score), reverse=True)]

    return gr.Dataframe(
        sorted_entries,
        headers=["Uid", "Winner", "Model", "Score", "Gen Time", "Similarity", "Size", "VRAM Usage", "Power Usage", "Hotkey"],
        datatype=["number", "markdown", "markdown", "number", "markdown", "number", "markdown", "markdown", "markdown", "markdown"],
        label=f"Last updated: {last_refresh.strftime('%Y-%m-%d %I:%M:%S %p')} EST",
        interactive=False,
    )
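
# Currently selected validator UID, kept in sync by the dropdown's change event.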
dropdown_value = SOURCE_VALIDATOR_UID
def set_dropdown_value(value: int):
    global dropdown_value
    dropdown_value = value
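
# Assemble the Gradio UI: cover image, title, validator dropdown, leaderboard
# table, and generation-time graph.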
def main():
    refresh()

    with demo:
        gr.Image(
            "cover.png",
            show_label=False,
            show_download_button=False,
            interactive=False,
            show_fullscreen_button=False,
            show_share_button=False,
            container=False,
        )

        gr.Markdown(
            """
            <center>
            <h1 style="font-size: 50px"> SN39 EdgeMaxxing Leaderboard </h1>
            This leaderboard for SN39 tracks the results and top model submissions from current and previous contests.
            </center>
            """)

        with gr.Accordion("Contest #1 Submission Leader: New Dream SDXL on NVIDIA RTX 4090s"):
            dropdown = gr.Dropdown(
                get_choices(),
                value=SOURCE_VALIDATOR_UID,
                interactive=True,
                label="Source Validator",
            )

            # The table and graph re-render on a timer and whenever a new source
            # validator is selected from the dropdown.
            table = get_data(dropdown.value)
            table.attach_load_event(lambda _: get_data(dropdown_value), REFRESH_RATE, [table])
            dropdown.change(lambda uid: get_data(uid), [dropdown], [table])

            graph = gr.Plot()
            graph.attach_load_event(lambda _: create_graph(runs[dropdown_value]), REFRESH_RATE, [graph])
            dropdown.change(lambda uid: create_graph(runs[uid]), [dropdown], [graph])

            dropdown.change(set_dropdown_value, [dropdown])  # TODO: hacky way to keep dropdown_value in sync

    demo.queue().launch()

if __name__ == "__main__":
    main()