AlexNijjar
commited on
Commit
•
9764560
1
Parent(s):
1912048
Switch to generation time
Browse files
app.py
CHANGED
@@ -10,11 +10,13 @@ import plotly.graph_objects as go
|
|
10 |
import wandb
|
11 |
from substrateinterface import Keypair
|
12 |
from wandb.apis.public import Run
|
|
|
13 |
|
14 |
WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
|
15 |
SOURCE_VALIDATOR_UID = int(os.environ["SOURCE_VALIDATOR_UID"])
|
|
|
|
|
16 |
REFRESH_RATE = 60 * 30 # 30 minutes
|
17 |
-
BASELINE = 0.0
|
18 |
GRAPH_HISTORY_DAYS = 30
|
19 |
MAX_GRAPH_ENTRIES = 10
|
20 |
|
@@ -28,6 +30,8 @@ class LeaderboardEntry:
|
|
28 |
uid: int
|
29 |
model: str
|
30 |
score: float
|
|
|
|
|
31 |
hotkey: str
|
32 |
previous_day_winner: bool
|
33 |
rank: int
|
@@ -36,9 +40,11 @@ class LeaderboardEntry:
|
|
36 |
@dataclass
|
37 |
class GraphEntry:
|
38 |
dates: list[datetime]
|
|
|
|
|
39 |
scores: list[float]
|
40 |
models: list[str]
|
41 |
-
|
42 |
|
43 |
|
44 |
def is_valid_run(run: Run):
|
@@ -60,6 +66,13 @@ def is_valid_run(run: Run):
|
|
60 |
return False
|
61 |
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
def get_graph_entries(runs: list[Run]) -> dict[int, GraphEntry]:
|
64 |
entries: dict[int, GraphEntry] = {}
|
65 |
|
@@ -70,22 +83,30 @@ def get_graph_entries(runs: list[Run]) -> dict[int, GraphEntry]:
|
|
70 |
if key.startswith("_"):
|
71 |
continue
|
72 |
|
|
|
|
|
|
|
73 |
uid = int(key)
|
74 |
-
|
|
|
|
|
|
|
75 |
model = value["model"]
|
76 |
|
77 |
if uid not in entries:
|
78 |
-
entries[uid] = GraphEntry([date], [score], [model],
|
79 |
else:
|
80 |
-
if
|
81 |
-
entries[uid].
|
82 |
|
83 |
data = entries[uid]
|
84 |
data.dates.append(date)
|
85 |
-
data.
|
|
|
|
|
86 |
data.models.append(model)
|
87 |
|
88 |
-
return dict(sorted(entries.items(), key=lambda entry: entry[1].
|
89 |
|
90 |
|
91 |
def create_graph(runs: list[Run]) -> go.Figure:
|
@@ -95,14 +116,16 @@ def create_graph(runs: list[Run]) -> go.Figure:
|
|
95 |
for uid, data in entries.items():
|
96 |
fig.add_trace(go.Scatter(
|
97 |
x=data.dates,
|
98 |
-
y=data.
|
99 |
-
customdata=data.models,
|
100 |
mode="lines+markers",
|
101 |
name=uid,
|
102 |
hovertemplate=(
|
103 |
"<b>Date:</b> %{x|%Y-%m-%d}<br>" +
|
104 |
-
"<b>
|
105 |
-
"<b>
|
|
|
|
|
106 |
),
|
107 |
))
|
108 |
|
@@ -110,7 +133,7 @@ def create_graph(runs: list[Run]) -> go.Figure:
|
|
110 |
|
111 |
fig.add_trace(go.Scatter(
|
112 |
x=date_range,
|
113 |
-
y=[
|
114 |
line=dict(color="#ff0000", width=3),
|
115 |
mode="lines",
|
116 |
name="Baseline",
|
@@ -119,8 +142,8 @@ def create_graph(runs: list[Run]) -> go.Figure:
|
|
119 |
background_color = gr.themes.default.colors.slate.c800
|
120 |
|
121 |
fig.update_layout(
|
122 |
-
title="
|
123 |
-
yaxis_title="
|
124 |
plot_bgcolor=background_color,
|
125 |
paper_bgcolor=background_color,
|
126 |
template="plotly_dark"
|
@@ -142,12 +165,17 @@ def create_leaderboard(runs: list[Run]) -> list[tuple]:
|
|
142 |
|
143 |
try:
|
144 |
uid = int(key)
|
|
|
|
|
|
|
145 |
|
146 |
entries[uid] = LeaderboardEntry(
|
147 |
uid=uid,
|
148 |
rank=value["rank"],
|
149 |
model=value["model"],
|
150 |
-
score=
|
|
|
|
|
151 |
hotkey=value["hotkey"],
|
152 |
previous_day_winner=value["multiday_winner"],
|
153 |
)
|
@@ -158,7 +186,8 @@ def create_leaderboard(runs: list[Run]) -> list[tuple]:
|
|
158 |
break
|
159 |
|
160 |
leaderboard: list[tuple] = [
|
161 |
-
(entry.rank + 1, entry.uid, entry.model, entry.score, entry.
|
|
|
162 |
for entry in sorted(entries.values(), key=lambda entry: (entry.score, entry.rank), reverse=True)
|
163 |
]
|
164 |
|
@@ -227,8 +256,9 @@ def refresh():
|
|
227 |
|
228 |
leaderboard = gr.components.Dataframe(
|
229 |
create_leaderboard(runs[dropdown.value]),
|
230 |
-
headers=["Rank", "Uid", "Model", "Score", "
|
231 |
-
|
|
|
232 |
elem_id="leaderboard-table",
|
233 |
interactive=False,
|
234 |
visible=True,
|
|
|
10 |
import wandb
|
11 |
from substrateinterface import Keypair
|
12 |
from wandb.apis.public import Run
|
13 |
+
import numpy as np
|
14 |
|
15 |
WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
|
16 |
SOURCE_VALIDATOR_UID = int(os.environ["SOURCE_VALIDATOR_UID"])
|
17 |
+
BASELINE_AVERAGE = float(os.environ["BASELINE_AVERAGE"])
|
18 |
+
|
19 |
REFRESH_RATE = 60 * 30 # 30 minutes
|
|
|
20 |
GRAPH_HISTORY_DAYS = 30
|
21 |
MAX_GRAPH_ENTRIES = 10
|
22 |
|
|
|
30 |
uid: int
|
31 |
model: str
|
32 |
score: float
|
33 |
+
model_average: float
|
34 |
+
similarity: float
|
35 |
hotkey: str
|
36 |
previous_day_winner: bool
|
37 |
rank: int
|
|
|
40 |
@dataclass
|
41 |
class GraphEntry:
|
42 |
dates: list[datetime]
|
43 |
+
generation_times: list[float]
|
44 |
+
similarities: list[float]
|
45 |
scores: list[float]
|
46 |
models: list[str]
|
47 |
+
best_time: float
|
48 |
|
49 |
|
50 |
def is_valid_run(run: Run):
|
|
|
66 |
return False
|
67 |
|
68 |
|
69 |
+
def calculate_score(model_average: float, similarity: float) -> float:
|
70 |
+
return max(
|
71 |
+
0.0,
|
72 |
+
BASELINE_AVERAGE - model_average
|
73 |
+
) * similarity
|
74 |
+
|
75 |
+
|
76 |
def get_graph_entries(runs: list[Run]) -> dict[int, GraphEntry]:
|
77 |
entries: dict[int, GraphEntry] = {}
|
78 |
|
|
|
83 |
if key.startswith("_"):
|
84 |
continue
|
85 |
|
86 |
+
if "score" in value:
|
87 |
+
continue
|
88 |
+
|
89 |
uid = int(key)
|
90 |
+
|
91 |
+
generation_time = value["generation_time"]
|
92 |
+
similarity = min(1, value["similarity"])
|
93 |
+
score = calculate_score(generation_time, similarity)
|
94 |
model = value["model"]
|
95 |
|
96 |
if uid not in entries:
|
97 |
+
entries[uid] = GraphEntry([date], [generation_time], [similarity], [score], [model], generation_time)
|
98 |
else:
|
99 |
+
if generation_time < entries[uid].best_time:
|
100 |
+
entries[uid].best_time = generation_time
|
101 |
|
102 |
data = entries[uid]
|
103 |
data.dates.append(date)
|
104 |
+
data.generation_times.append(data.best_time)
|
105 |
+
data.similarities.append(similarity)
|
106 |
+
data.scores.append(score)
|
107 |
data.models.append(model)
|
108 |
|
109 |
+
return dict(sorted(entries.items(), key=lambda entry: entry[1].best_time)[:MAX_GRAPH_ENTRIES])
|
110 |
|
111 |
|
112 |
def create_graph(runs: list[Run]) -> go.Figure:
|
|
|
116 |
for uid, data in entries.items():
|
117 |
fig.add_trace(go.Scatter(
|
118 |
x=data.dates,
|
119 |
+
y=data.generation_times,
|
120 |
+
customdata=np.stack((data.similarities, data.scores, data.models), axis=-1),
|
121 |
mode="lines+markers",
|
122 |
name=uid,
|
123 |
hovertemplate=(
|
124 |
"<b>Date:</b> %{x|%Y-%m-%d}<br>" +
|
125 |
+
"<b>Generation Time:</b> %{y}s<br>" +
|
126 |
+
"<b>Similarity:</b> %{customdata[0]}<br>" +
|
127 |
+
"<b>Score:</b> %{customdata[1]}<br>" +
|
128 |
+
"<b>Model:</b> %{customdata[2]}<br>"
|
129 |
),
|
130 |
))
|
131 |
|
|
|
133 |
|
134 |
fig.add_trace(go.Scatter(
|
135 |
x=date_range,
|
136 |
+
y=[BASELINE_AVERAGE] * len(date_range),
|
137 |
line=dict(color="#ff0000", width=3),
|
138 |
mode="lines",
|
139 |
name="Baseline",
|
|
|
142 |
background_color = gr.themes.default.colors.slate.c800
|
143 |
|
144 |
fig.update_layout(
|
145 |
+
title="Generation Time Improvements",
|
146 |
+
yaxis_title="Generation Time (s)",
|
147 |
plot_bgcolor=background_color,
|
148 |
paper_bgcolor=background_color,
|
149 |
template="plotly_dark"
|
|
|
165 |
|
166 |
try:
|
167 |
uid = int(key)
|
168 |
+
generation_time = value.get("generation_time", 0)
|
169 |
+
similarity = min(1, value.get("similarity", 0))
|
170 |
+
score = value.get("score", calculate_score(generation_time, similarity))
|
171 |
|
172 |
entries[uid] = LeaderboardEntry(
|
173 |
uid=uid,
|
174 |
rank=value["rank"],
|
175 |
model=value["model"],
|
176 |
+
score=score,
|
177 |
+
model_average=generation_time,
|
178 |
+
similarity=similarity,
|
179 |
hotkey=value["hotkey"],
|
180 |
previous_day_winner=value["multiday_winner"],
|
181 |
)
|
|
|
186 |
break
|
187 |
|
188 |
leaderboard: list[tuple] = [
|
189 |
+
(entry.rank + 1, entry.uid, entry.model, entry.score, f"{entry.model_average:.3f}s", f"{entry.similarity:.3f}",
|
190 |
+
entry.hotkey, entry.previous_day_winner)
|
191 |
for entry in sorted(entries.values(), key=lambda entry: (entry.score, entry.rank), reverse=True)
|
192 |
]
|
193 |
|
|
|
256 |
|
257 |
leaderboard = gr.components.Dataframe(
|
258 |
create_leaderboard(runs[dropdown.value]),
|
259 |
+
headers=["Rank", "Uid", "Model", "Score", "Generation Time", "Similarity", "Hotkey",
|
260 |
+
"Previous day winner"],
|
261 |
+
datatype=["number", "number", "markdown", "number", "markdown", "markdown", "markdown", "bool"],
|
262 |
elem_id="leaderboard-table",
|
263 |
interactive=False,
|
264 |
visible=True,
|