AlexNijjar commited on
Commit
75d58eb
1 Parent(s): d28c674

Implement leaderboard graph

Browse files

TODO: get actual baseline score

Files changed (2) hide show
  1. app.py +94 -1
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,8 +1,11 @@
1
  import os
2
  import time
3
  from dataclasses import dataclass
 
 
4
 
5
  import gradio as gr
 
6
  import schedule
7
  import wandb
8
  from substrateinterface import Keypair
@@ -14,6 +17,9 @@ demo = gr.Blocks(css=".typewriter {font-family: 'JMH Typewriter', sans-serif;}")
14
  SOURCE_VALIDATOR_UID = int(os.environ["SOURCE_VALIDATOR_UID"])
15
  WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
16
 
 
 
 
17
 
18
  @dataclass
19
  class LeaderboardEntry:
@@ -25,6 +31,14 @@ class LeaderboardEntry:
25
  rank: int
26
 
27
 
 
 
 
 
 
 
 
 
28
  def is_valid_run(run: Run):
29
  required_config_keys = ["hotkey", "uid", "contest", "signature"]
30
 
@@ -44,8 +58,85 @@ def is_valid_run(run: Run):
44
  return False
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def refresh_leaderboard():
48
- print("Refreshing Leaderboard")
 
49
 
50
  demo.clear()
51
 
@@ -92,6 +183,8 @@ def refresh_leaderboard():
92
  for entry in sorted(entries.values(), key=lambda entry: (entry.score, entry.rank), reverse=True)
93
  ]
94
 
 
 
95
  gr.components.Dataframe(
96
  value=leaderboard,
97
  headers=["Rank", "Uid", "Model", "Score", "Hotkey", "Previous day winner"],
 
1
  import os
2
  import time
3
  from dataclasses import dataclass
4
+ from datetime import datetime
5
+ from zoneinfo import ZoneInfo
6
 
7
  import gradio as gr
8
+ import plotly.graph_objects as go
9
  import schedule
10
  import wandb
11
  from substrateinterface import Keypair
 
17
  SOURCE_VALIDATOR_UID = int(os.environ["SOURCE_VALIDATOR_UID"])
18
  WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
19
 
20
+ GRAPH_HISTORY_DAYS = 30
21
+ MAX_GRAPH_ENTRIES = 5
22
+
23
 
24
  @dataclass
25
  class LeaderboardEntry:
 
31
  rank: int
32
 
33
 
34
+ @dataclass
35
+ class GraphEntry:
36
+ dates: list[datetime]
37
+ scores: list[float]
38
+ models: list[str]
39
+ max_score: float
40
+
41
+
42
  def is_valid_run(run: Run):
43
  required_config_keys = ["hotkey", "uid", "contest", "signature"]
44
 
 
58
  return False
59
 
60
 
61
+ def get_baseline():
62
+ return 0.1 # TODO replace with actual baseline
63
+
64
+
65
+ def get_graph_entries(runs: Runs) -> dict[int, GraphEntry]:
66
+ graph_entries: dict[int, GraphEntry] = {}
67
+
68
+ for run in reversed(runs[:GRAPH_HISTORY_DAYS]):
69
+ if not is_valid_run(run):
70
+ continue
71
+
72
+ date = datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%S")
73
+
74
+ for key, value in run.summary.items():
75
+ if key.startswith("_"):
76
+ continue
77
+
78
+ uid = int(key)
79
+ score = value["score"]
80
+ model = value["model"]
81
+
82
+ if uid not in graph_entries:
83
+ graph_entries[uid] = GraphEntry([date], [score], [model], score)
84
+ else:
85
+ if score > graph_entries[uid].max_score:
86
+ graph_entries[uid].max_score = score
87
+
88
+ data = graph_entries[uid]
89
+ data.dates.append(date)
90
+ data.scores.append(data.max_score)
91
+ data.models.append(model)
92
+
93
+ return dict(sorted(graph_entries.items(), key=lambda entry: entry[1].max_score, reverse=True)[:MAX_GRAPH_ENTRIES])
94
+
95
+
96
+ def create_graph(graph_entries: dict[int, GraphEntry]):
97
+ fig = go.Figure()
98
+
99
+ for uid, data in graph_entries.items():
100
+ fig.add_trace(go.Scatter(
101
+ x=data.dates,
102
+ y=data.scores,
103
+ customdata=data.models,
104
+ mode="lines+markers",
105
+ name=uid,
106
+ hovertemplate=(
107
+ "<b>Date:</b> %{x|%Y-%m-%d}<br>" +
108
+ "<b>Score:</b> %{y}<br>" +
109
+ "<b>Model:</b> %{customdata}<br>"
110
+ ),
111
+ ))
112
+
113
+ date_range = max(graph_entries.values(), key=lambda entry: len(entry.dates)).dates
114
+
115
+ baseline = get_baseline()
116
+ fig.add_trace(go.Scatter(
117
+ x=date_range,
118
+ y=[baseline] * len(date_range),
119
+ line=dict(color="#ff0000", width=3),
120
+ mode="lines",
121
+ name="Baseline",
122
+ ))
123
+
124
+ background_color = gr.themes.default.colors.slate.c800
125
+
126
+ fig.update_layout(
127
+ title="Score Improvements",
128
+ yaxis_title="Score",
129
+ plot_bgcolor=background_color,
130
+ paper_bgcolor=background_color,
131
+ template="plotly_dark"
132
+ )
133
+
134
+ gr.Plot(fig)
135
+
136
+
137
  def refresh_leaderboard():
138
+ now = datetime.now(tz=ZoneInfo("America/New_York"))
139
+ print(f"Refreshing Leaderboard at {now.strftime('%Y-%m-%d %H:%M:%S')}")
140
 
141
  demo.clear()
142
 
 
183
  for entry in sorted(entries.values(), key=lambda entry: (entry.score, entry.rank), reverse=True)
184
  ]
185
 
186
+ create_graph(get_graph_entries(runs))
187
+
188
  gr.components.Dataframe(
189
  value=leaderboard,
190
  headers=["Rank", "Uid", "Model", "Score", "Hotkey", "Previous day winner"],
requirements.txt CHANGED
@@ -3,3 +3,4 @@ gradio
3
  wandb
4
  substrate-interface
5
  schedule
 
 
3
  wandb
4
  substrate-interface
5
  schedule
6
+ plotly