AlexNijjar
commited on
Commit
•
45713ec
1
Parent(s):
94052c1
Implement validator source selection
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import os
|
2 |
import time
|
3 |
from dataclasses import dataclass
|
@@ -6,20 +7,20 @@ from zoneinfo import ZoneInfo
|
|
6 |
|
7 |
import gradio as gr
|
8 |
import plotly.graph_objects as go
|
9 |
-
import schedule
|
10 |
import wandb
|
11 |
from substrateinterface import Keypair
|
12 |
-
from wandb.apis.public import Run
|
13 |
|
14 |
-
wandb_api = wandb.Api()
|
15 |
-
demo = gr.Blocks(css=".typewriter {font-family: 'JMH Typewriter', sans-serif;}")
|
16 |
-
|
17 |
-
SOURCE_VALIDATOR_UID = int(os.environ["SOURCE_VALIDATOR_UID"])
|
18 |
WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
|
19 |
-
|
|
|
20 |
BASELINE = 0.0
|
21 |
GRAPH_HISTORY_DAYS = 30
|
22 |
-
MAX_GRAPH_ENTRIES =
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
@dataclass
|
@@ -59,13 +60,10 @@ def is_valid_run(run: Run):
|
|
59 |
return False
|
60 |
|
61 |
|
62 |
-
def get_graph_entries(runs:
|
63 |
-
|
64 |
|
65 |
for run in reversed(runs[:GRAPH_HISTORY_DAYS]):
|
66 |
-
if not is_valid_run(run):
|
67 |
-
continue
|
68 |
-
|
69 |
date = datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%S")
|
70 |
|
71 |
for key, value in run.summary.items():
|
@@ -76,24 +74,25 @@ def get_graph_entries(runs: Runs) -> dict[int, GraphEntry]:
|
|
76 |
score = value["score"]
|
77 |
model = value["model"]
|
78 |
|
79 |
-
if uid not in
|
80 |
-
|
81 |
else:
|
82 |
-
if score >
|
83 |
-
|
84 |
|
85 |
-
data =
|
86 |
data.dates.append(date)
|
87 |
data.scores.append(data.max_score)
|
88 |
data.models.append(model)
|
89 |
|
90 |
-
return dict(sorted(
|
91 |
|
92 |
|
93 |
-
def create_graph(
|
|
|
94 |
fig = go.Figure()
|
95 |
|
96 |
-
for uid, data in
|
97 |
fig.add_trace(go.Scatter(
|
98 |
x=data.dates,
|
99 |
y=data.scores,
|
@@ -107,7 +106,7 @@ def create_graph(graph_entries: dict[int, GraphEntry]):
|
|
107 |
),
|
108 |
))
|
109 |
|
110 |
-
date_range = max(
|
111 |
|
112 |
fig.add_trace(go.Scatter(
|
113 |
x=date_range,
|
@@ -127,62 +126,89 @@ def create_graph(graph_entries: dict[int, GraphEntry]):
|
|
127 |
template="plotly_dark"
|
128 |
)
|
129 |
|
130 |
-
|
131 |
|
132 |
|
133 |
-
def
|
134 |
-
|
135 |
-
print(f"Refreshing Leaderboard at {now.strftime('%Y-%m-%d %H:%M:%S')}")
|
136 |
|
137 |
-
|
|
|
|
|
|
|
|
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
-
|
|
|
148 |
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
152 |
|
153 |
-
|
154 |
-
for key, value in run.summary.items():
|
155 |
-
if key.startswith("_"):
|
156 |
-
continue
|
157 |
|
158 |
-
has_data = True
|
159 |
|
160 |
-
|
161 |
-
|
|
|
|
|
162 |
|
163 |
-
entries[uid] = LeaderboardEntry(
|
164 |
-
uid=uid,
|
165 |
-
rank=value["rank"],
|
166 |
-
model=value["model"],
|
167 |
-
score=value["score"],
|
168 |
-
hotkey=value["hotkey"],
|
169 |
-
previous_day_winner=value["multiday_winner"],
|
170 |
-
)
|
171 |
-
except Exception:
|
172 |
-
continue
|
173 |
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
|
182 |
-
|
183 |
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
headers=["Rank", "Uid", "Model", "Score", "Hotkey", "Previous day winner"],
|
187 |
datatype=["number", "number", "markdown", "number", "markdown", "bool"],
|
188 |
elem_id="leaderboard-table",
|
@@ -190,16 +216,20 @@ def refresh_leaderboard():
|
|
190 |
visible=True,
|
191 |
)
|
192 |
|
|
|
|
|
|
|
|
|
193 |
|
194 |
-
def main():
|
195 |
-
refresh_leaderboard()
|
196 |
-
schedule.every(30).minutes.do(refresh_leaderboard)
|
197 |
|
|
|
|
|
198 |
demo.launch(prevent_thread_lock=True)
|
199 |
|
200 |
while True:
|
201 |
-
|
202 |
-
time.sleep(1)
|
203 |
|
|
|
|
|
204 |
|
205 |
-
|
|
|
1 |
+
import json
|
2 |
import os
|
3 |
import time
|
4 |
from dataclasses import dataclass
|
|
|
7 |
|
8 |
import gradio as gr
|
9 |
import plotly.graph_objects as go
|
|
|
10 |
import wandb
|
11 |
from substrateinterface import Keypair
|
12 |
+
from wandb.apis.public import Run
|
13 |
|
|
|
|
|
|
|
|
|
14 |
WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
|
15 |
+
SOURCE_VALIDATOR_UID = int(os.environ["SOURCE_VALIDATOR_UID"])
|
16 |
+
REFRESH_RATE = 60 * 30 # 30 minutes
|
17 |
BASELINE = 0.0
|
18 |
GRAPH_HISTORY_DAYS = 30
|
19 |
+
MAX_GRAPH_ENTRIES = 10
|
20 |
+
|
21 |
+
wandb_api = wandb.Api()
|
22 |
+
demo = gr.Blocks(css=".typewriter {font-family: 'JMH Typewriter', sans-serif;}")
|
23 |
+
runs: dict[int, list[Run]] = {}
|
24 |
|
25 |
|
26 |
@dataclass
|
|
|
60 |
return False
|
61 |
|
62 |
|
63 |
+
def get_graph_entries(runs: list[Run]) -> dict[int, GraphEntry]:
|
64 |
+
entries: dict[int, GraphEntry] = {}
|
65 |
|
66 |
for run in reversed(runs[:GRAPH_HISTORY_DAYS]):
|
|
|
|
|
|
|
67 |
date = datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%S")
|
68 |
|
69 |
for key, value in run.summary.items():
|
|
|
74 |
score = value["score"]
|
75 |
model = value["model"]
|
76 |
|
77 |
+
if uid not in entries:
|
78 |
+
entries[uid] = GraphEntry([date], [score], [model], score)
|
79 |
else:
|
80 |
+
if score > entries[uid].max_score:
|
81 |
+
entries[uid].max_score = score
|
82 |
|
83 |
+
data = entries[uid]
|
84 |
data.dates.append(date)
|
85 |
data.scores.append(data.max_score)
|
86 |
data.models.append(model)
|
87 |
|
88 |
+
return dict(sorted(entries.items(), key=lambda entry: entry[1].max_score, reverse=True)[:MAX_GRAPH_ENTRIES])
|
89 |
|
90 |
|
91 |
+
def create_graph(runs: list[Run]) -> go.Figure:
|
92 |
+
entries = get_graph_entries(runs)
|
93 |
fig = go.Figure()
|
94 |
|
95 |
+
for uid, data in entries.items():
|
96 |
fig.add_trace(go.Scatter(
|
97 |
x=data.dates,
|
98 |
y=data.scores,
|
|
|
106 |
),
|
107 |
))
|
108 |
|
109 |
+
date_range = max(entries.values(), key=lambda entry: len(entry.dates)).dates
|
110 |
|
111 |
fig.add_trace(go.Scatter(
|
112 |
x=date_range,
|
|
|
126 |
template="plotly_dark"
|
127 |
)
|
128 |
|
129 |
+
return fig
|
130 |
|
131 |
|
132 |
+
def create_leaderboard(runs: list[Run]) -> list[tuple]:
|
133 |
+
entries: dict[int, LeaderboardEntry] = {}
|
|
|
134 |
|
135 |
+
for run in runs:
|
136 |
+
has_data = False
|
137 |
+
for key, value in run.summary.items():
|
138 |
+
if key.startswith("_"):
|
139 |
+
continue
|
140 |
|
141 |
+
has_data = True
|
142 |
+
|
143 |
+
try:
|
144 |
+
uid = int(key)
|
145 |
+
|
146 |
+
entries[uid] = LeaderboardEntry(
|
147 |
+
uid=uid,
|
148 |
+
rank=value["rank"],
|
149 |
+
model=value["model"],
|
150 |
+
score=value["score"],
|
151 |
+
hotkey=value["hotkey"],
|
152 |
+
previous_day_winner=value["multiday_winner"],
|
153 |
+
)
|
154 |
+
except Exception:
|
155 |
+
continue
|
156 |
|
157 |
+
if has_data:
|
158 |
+
break
|
159 |
|
160 |
+
leaderboard: list[tuple] = [
|
161 |
+
(entry.rank + 1, entry.uid, entry.model, entry.score, entry.hotkey, entry.previous_day_winner)
|
162 |
+
for entry in sorted(entries.values(), key=lambda entry: (entry.score, entry.rank), reverse=True)
|
163 |
+
]
|
164 |
|
165 |
+
return leaderboard
|
|
|
|
|
|
|
166 |
|
|
|
167 |
|
168 |
+
def get_run_validator_uid(run: Run) -> int:
|
169 |
+
json_config = json.loads(run.json_config)
|
170 |
+
uid = int(json_config["uid"]["value"])
|
171 |
+
return uid
|
172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
+
def fetch_wandb_data():
|
175 |
+
wandb_runs = wandb_api.runs(
|
176 |
+
WANDB_RUN_PATH,
|
177 |
+
filters={"config.type": "validator"},
|
178 |
+
order="-created_at",
|
179 |
+
)
|
180 |
+
|
181 |
+
global runs
|
182 |
+
runs.clear()
|
183 |
+
for run in wandb_runs:
|
184 |
+
if not is_valid_run(run):
|
185 |
+
continue
|
186 |
|
187 |
+
uid = get_run_validator_uid(run)
|
188 |
+
if uid not in runs:
|
189 |
+
runs[uid] = []
|
190 |
+
runs[uid].append(run)
|
191 |
|
192 |
+
runs = dict(sorted(runs.items(), key=lambda item: item[0]))
|
193 |
|
194 |
+
|
195 |
+
def refresh():
|
196 |
+
fetch_wandb_data()
|
197 |
+
demo.clear()
|
198 |
+
with demo:
|
199 |
+
with gr.Accordion("Contest #1 Submission Leader: New Dream SDXL on NVIDIA RTX 4090s"):
|
200 |
+
choices = list(runs.keys())
|
201 |
+
dropdown = gr.Dropdown(
|
202 |
+
choices,
|
203 |
+
value=SOURCE_VALIDATOR_UID,
|
204 |
+
interactive=True,
|
205 |
+
label="Source Validator"
|
206 |
+
)
|
207 |
+
|
208 |
+
graph = gr.Plot()
|
209 |
+
|
210 |
+
leaderboard = gr.components.Dataframe(
|
211 |
+
create_leaderboard(runs[dropdown.value]),
|
212 |
headers=["Rank", "Uid", "Model", "Score", "Hotkey", "Previous day winner"],
|
213 |
datatype=["number", "number", "markdown", "number", "markdown", "bool"],
|
214 |
elem_id="leaderboard-table",
|
|
|
216 |
visible=True,
|
217 |
)
|
218 |
|
219 |
+
demo.load(lambda uid: create_graph(runs[uid]), [dropdown], [graph])
|
220 |
+
|
221 |
+
dropdown.change(lambda uid: create_graph(runs[uid]), [dropdown], [graph])
|
222 |
+
dropdown.change(lambda uid: create_leaderboard(runs[uid]), [dropdown], [leaderboard])
|
223 |
|
|
|
|
|
|
|
224 |
|
225 |
+
if __name__ == "__main__":
|
226 |
+
refresh()
|
227 |
demo.launch(prevent_thread_lock=True)
|
228 |
|
229 |
while True:
|
230 |
+
time.sleep(REFRESH_RATE)
|
|
|
231 |
|
232 |
+
now = datetime.now(tz=ZoneInfo("America/New_York"))
|
233 |
+
print(f"Refreshing Leaderboard at {now.strftime('%Y-%m-%d %H:%M:%S')}")
|
234 |
|
235 |
+
refresh()
|