AlexNijjar commited on
Commit
45713ec
1 Parent(s): 94052c1

Implement validator source selection

Browse files
Files changed (1) hide show
  1. app.py +100 -70
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import time
3
  from dataclasses import dataclass
@@ -6,20 +7,20 @@ from zoneinfo import ZoneInfo
6
 
7
  import gradio as gr
8
  import plotly.graph_objects as go
9
- import schedule
10
  import wandb
11
  from substrateinterface import Keypair
12
- from wandb.apis.public import Run, Runs
13
 
14
- wandb_api = wandb.Api()
15
- demo = gr.Blocks(css=".typewriter {font-family: 'JMH Typewriter', sans-serif;}")
16
-
17
- SOURCE_VALIDATOR_UID = int(os.environ["SOURCE_VALIDATOR_UID"])
18
  WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
19
-
 
20
  BASELINE = 0.0
21
  GRAPH_HISTORY_DAYS = 30
22
- MAX_GRAPH_ENTRIES = 5
 
 
 
 
23
 
24
 
25
  @dataclass
@@ -59,13 +60,10 @@ def is_valid_run(run: Run):
59
  return False
60
 
61
 
62
- def get_graph_entries(runs: Runs) -> dict[int, GraphEntry]:
63
- graph_entries: dict[int, GraphEntry] = {}
64
 
65
  for run in reversed(runs[:GRAPH_HISTORY_DAYS]):
66
- if not is_valid_run(run):
67
- continue
68
-
69
  date = datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%S")
70
 
71
  for key, value in run.summary.items():
@@ -76,24 +74,25 @@ def get_graph_entries(runs: Runs) -> dict[int, GraphEntry]:
76
  score = value["score"]
77
  model = value["model"]
78
 
79
- if uid not in graph_entries:
80
- graph_entries[uid] = GraphEntry([date], [score], [model], score)
81
  else:
82
- if score > graph_entries[uid].max_score:
83
- graph_entries[uid].max_score = score
84
 
85
- data = graph_entries[uid]
86
  data.dates.append(date)
87
  data.scores.append(data.max_score)
88
  data.models.append(model)
89
 
90
- return dict(sorted(graph_entries.items(), key=lambda entry: entry[1].max_score, reverse=True)[:MAX_GRAPH_ENTRIES])
91
 
92
 
93
- def create_graph(graph_entries: dict[int, GraphEntry]):
 
94
  fig = go.Figure()
95
 
96
- for uid, data in graph_entries.items():
97
  fig.add_trace(go.Scatter(
98
  x=data.dates,
99
  y=data.scores,
@@ -107,7 +106,7 @@ def create_graph(graph_entries: dict[int, GraphEntry]):
107
  ),
108
  ))
109
 
110
- date_range = max(graph_entries.values(), key=lambda entry: len(entry.dates)).dates
111
 
112
  fig.add_trace(go.Scatter(
113
  x=date_range,
@@ -127,62 +126,89 @@ def create_graph(graph_entries: dict[int, GraphEntry]):
127
  template="plotly_dark"
128
  )
129
 
130
- gr.Plot(fig)
131
 
132
 
133
- def refresh_leaderboard():
134
- now = datetime.now(tz=ZoneInfo("America/New_York"))
135
- print(f"Refreshing Leaderboard at {now.strftime('%Y-%m-%d %H:%M:%S')}")
136
 
137
- demo.clear()
 
 
 
 
138
 
139
- with demo:
140
- with gr.Accordion("Contest #1 Submission Leader: New Dream SDXL on NVIDIA RTX 4090s"):
141
- runs: Runs = wandb_api.runs(
142
- WANDB_RUN_PATH,
143
- filters={"config.type": "validator", "config.uid": SOURCE_VALIDATOR_UID},
144
- order="-created_at",
145
- )
 
 
 
 
 
 
 
 
146
 
147
- entries: dict[int, LeaderboardEntry] = {}
 
148
 
149
- for run in runs:
150
- if not is_valid_run(run):
151
- continue
 
152
 
153
- has_data = False
154
- for key, value in run.summary.items():
155
- if key.startswith("_"):
156
- continue
157
 
158
- has_data = True
159
 
160
- try:
161
- uid = int(key)
 
 
162
 
163
- entries[uid] = LeaderboardEntry(
164
- uid=uid,
165
- rank=value["rank"],
166
- model=value["model"],
167
- score=value["score"],
168
- hotkey=value["hotkey"],
169
- previous_day_winner=value["multiday_winner"],
170
- )
171
- except Exception:
172
- continue
173
 
174
- if has_data:
175
- break
 
 
 
 
 
 
 
 
 
 
176
 
177
- leaderboard: list[tuple] = [
178
- (entry.rank + 1, entry.uid, entry.model, entry.score, entry.hotkey, entry.previous_day_winner)
179
- for entry in sorted(entries.values(), key=lambda entry: (entry.score, entry.rank), reverse=True)
180
- ]
181
 
182
- create_graph(get_graph_entries(runs))
183
 
184
- gr.components.Dataframe(
185
- value=leaderboard,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  headers=["Rank", "Uid", "Model", "Score", "Hotkey", "Previous day winner"],
187
  datatype=["number", "number", "markdown", "number", "markdown", "bool"],
188
  elem_id="leaderboard-table",
@@ -190,16 +216,20 @@ def refresh_leaderboard():
190
  visible=True,
191
  )
192
 
 
 
 
 
193
 
194
- def main():
195
- refresh_leaderboard()
196
- schedule.every(30).minutes.do(refresh_leaderboard)
197
 
 
 
198
  demo.launch(prevent_thread_lock=True)
199
 
200
  while True:
201
- schedule.run_pending()
202
- time.sleep(1)
203
 
 
 
204
 
205
- main()
 
1
+ import json
2
  import os
3
  import time
4
  from dataclasses import dataclass
 
7
 
8
  import gradio as gr
9
  import plotly.graph_objects as go
 
10
  import wandb
11
  from substrateinterface import Keypair
12
+ from wandb.apis.public import Run
13
 
 
 
 
 
14
  WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
15
+ SOURCE_VALIDATOR_UID = int(os.environ["SOURCE_VALIDATOR_UID"])
16
+ REFRESH_RATE = 60 * 30 # 30 minutes
17
  BASELINE = 0.0
18
  GRAPH_HISTORY_DAYS = 30
19
+ MAX_GRAPH_ENTRIES = 10
20
+
21
+ wandb_api = wandb.Api()
22
+ demo = gr.Blocks(css=".typewriter {font-family: 'JMH Typewriter', sans-serif;}")
23
+ runs: dict[int, list[Run]] = {}
24
 
25
 
26
  @dataclass
 
60
  return False
61
 
62
 
63
+ def get_graph_entries(runs: list[Run]) -> dict[int, GraphEntry]:
64
+ entries: dict[int, GraphEntry] = {}
65
 
66
  for run in reversed(runs[:GRAPH_HISTORY_DAYS]):
 
 
 
67
  date = datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%S")
68
 
69
  for key, value in run.summary.items():
 
74
  score = value["score"]
75
  model = value["model"]
76
 
77
+ if uid not in entries:
78
+ entries[uid] = GraphEntry([date], [score], [model], score)
79
  else:
80
+ if score > entries[uid].max_score:
81
+ entries[uid].max_score = score
82
 
83
+ data = entries[uid]
84
  data.dates.append(date)
85
  data.scores.append(data.max_score)
86
  data.models.append(model)
87
 
88
+ return dict(sorted(entries.items(), key=lambda entry: entry[1].max_score, reverse=True)[:MAX_GRAPH_ENTRIES])
89
 
90
 
91
+ def create_graph(runs: list[Run]) -> go.Figure:
92
+ entries = get_graph_entries(runs)
93
  fig = go.Figure()
94
 
95
+ for uid, data in entries.items():
96
  fig.add_trace(go.Scatter(
97
  x=data.dates,
98
  y=data.scores,
 
106
  ),
107
  ))
108
 
109
+ date_range = max(entries.values(), key=lambda entry: len(entry.dates)).dates
110
 
111
  fig.add_trace(go.Scatter(
112
  x=date_range,
 
126
  template="plotly_dark"
127
  )
128
 
129
+ return fig
130
 
131
 
132
+ def create_leaderboard(runs: list[Run]) -> list[tuple]:
133
+ entries: dict[int, LeaderboardEntry] = {}
 
134
 
135
+ for run in runs:
136
+ has_data = False
137
+ for key, value in run.summary.items():
138
+ if key.startswith("_"):
139
+ continue
140
 
141
+ has_data = True
142
+
143
+ try:
144
+ uid = int(key)
145
+
146
+ entries[uid] = LeaderboardEntry(
147
+ uid=uid,
148
+ rank=value["rank"],
149
+ model=value["model"],
150
+ score=value["score"],
151
+ hotkey=value["hotkey"],
152
+ previous_day_winner=value["multiday_winner"],
153
+ )
154
+ except Exception:
155
+ continue
156
 
157
+ if has_data:
158
+ break
159
 
160
+ leaderboard: list[tuple] = [
161
+ (entry.rank + 1, entry.uid, entry.model, entry.score, entry.hotkey, entry.previous_day_winner)
162
+ for entry in sorted(entries.values(), key=lambda entry: (entry.score, entry.rank), reverse=True)
163
+ ]
164
 
165
+ return leaderboard
 
 
 
166
 
 
167
 
168
+ def get_run_validator_uid(run: Run) -> int:
169
+ json_config = json.loads(run.json_config)
170
+ uid = int(json_config["uid"]["value"])
171
+ return uid
172
 
 
 
 
 
 
 
 
 
 
 
173
 
174
+ def fetch_wandb_data():
175
+ wandb_runs = wandb_api.runs(
176
+ WANDB_RUN_PATH,
177
+ filters={"config.type": "validator"},
178
+ order="-created_at",
179
+ )
180
+
181
+ global runs
182
+ runs.clear()
183
+ for run in wandb_runs:
184
+ if not is_valid_run(run):
185
+ continue
186
 
187
+ uid = get_run_validator_uid(run)
188
+ if uid not in runs:
189
+ runs[uid] = []
190
+ runs[uid].append(run)
191
 
192
+ runs = dict(sorted(runs.items(), key=lambda item: item[0]))
193
 
194
+
195
+ def refresh():
196
+ fetch_wandb_data()
197
+ demo.clear()
198
+ with demo:
199
+ with gr.Accordion("Contest #1 Submission Leader: New Dream SDXL on NVIDIA RTX 4090s"):
200
+ choices = list(runs.keys())
201
+ dropdown = gr.Dropdown(
202
+ choices,
203
+ value=SOURCE_VALIDATOR_UID,
204
+ interactive=True,
205
+ label="Source Validator"
206
+ )
207
+
208
+ graph = gr.Plot()
209
+
210
+ leaderboard = gr.components.Dataframe(
211
+ create_leaderboard(runs[dropdown.value]),
212
  headers=["Rank", "Uid", "Model", "Score", "Hotkey", "Previous day winner"],
213
  datatype=["number", "number", "markdown", "number", "markdown", "bool"],
214
  elem_id="leaderboard-table",
 
216
  visible=True,
217
  )
218
 
219
+ demo.load(lambda uid: create_graph(runs[uid]), [dropdown], [graph])
220
+
221
+ dropdown.change(lambda uid: create_graph(runs[uid]), [dropdown], [graph])
222
+ dropdown.change(lambda uid: create_leaderboard(runs[uid]), [dropdown], [leaderboard])
223
 
 
 
 
224
 
225
+ if __name__ == "__main__":
226
+ refresh()
227
  demo.launch(prevent_thread_lock=True)
228
 
229
  while True:
230
+ time.sleep(REFRESH_RATE)
 
231
 
232
+ now = datetime.now(tz=ZoneInfo("America/New_York"))
233
+ print(f"Refreshing Leaderboard at {now.strftime('%Y-%m-%d %H:%M:%S')}")
234
 
235
+ refresh()