AlexNijjar committed on
Commit
9764560
1 Parent(s): 1912048

Switch to generation time

Browse files
Files changed (1) hide show
  1. app.py +49 -19
app.py CHANGED
@@ -10,11 +10,13 @@ import plotly.graph_objects as go
10
  import wandb
11
  from substrateinterface import Keypair
12
  from wandb.apis.public import Run
 
13
 
14
  WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
15
  SOURCE_VALIDATOR_UID = int(os.environ["SOURCE_VALIDATOR_UID"])
 
 
16
  REFRESH_RATE = 60 * 30 # 30 minutes
17
- BASELINE = 0.0
18
  GRAPH_HISTORY_DAYS = 30
19
  MAX_GRAPH_ENTRIES = 10
20
 
@@ -28,6 +30,8 @@ class LeaderboardEntry:
28
  uid: int
29
  model: str
30
  score: float
 
 
31
  hotkey: str
32
  previous_day_winner: bool
33
  rank: int
@@ -36,9 +40,11 @@ class LeaderboardEntry:
36
  @dataclass
37
  class GraphEntry:
38
  dates: list[datetime]
 
 
39
  scores: list[float]
40
  models: list[str]
41
- max_score: float
42
 
43
 
44
  def is_valid_run(run: Run):
@@ -60,6 +66,13 @@ def is_valid_run(run: Run):
60
  return False
61
 
62
 
 
 
 
 
 
 
 
63
  def get_graph_entries(runs: list[Run]) -> dict[int, GraphEntry]:
64
  entries: dict[int, GraphEntry] = {}
65
 
@@ -70,22 +83,30 @@ def get_graph_entries(runs: list[Run]) -> dict[int, GraphEntry]:
70
  if key.startswith("_"):
71
  continue
72
 
 
 
 
73
  uid = int(key)
74
- score = value["score"]
 
 
 
75
  model = value["model"]
76
 
77
  if uid not in entries:
78
- entries[uid] = GraphEntry([date], [score], [model], score)
79
  else:
80
- if score > entries[uid].max_score:
81
- entries[uid].max_score = score
82
 
83
  data = entries[uid]
84
  data.dates.append(date)
85
- data.scores.append(data.max_score)
 
 
86
  data.models.append(model)
87
 
88
- return dict(sorted(entries.items(), key=lambda entry: entry[1].max_score, reverse=True)[:MAX_GRAPH_ENTRIES])
89
 
90
 
91
  def create_graph(runs: list[Run]) -> go.Figure:
@@ -95,14 +116,16 @@ def create_graph(runs: list[Run]) -> go.Figure:
95
  for uid, data in entries.items():
96
  fig.add_trace(go.Scatter(
97
  x=data.dates,
98
- y=data.scores,
99
- customdata=data.models,
100
  mode="lines+markers",
101
  name=uid,
102
  hovertemplate=(
103
  "<b>Date:</b> %{x|%Y-%m-%d}<br>" +
104
- "<b>Score:</b> %{y}<br>" +
105
- "<b>Model:</b> %{customdata}<br>"
 
 
106
  ),
107
  ))
108
 
@@ -110,7 +133,7 @@ def create_graph(runs: list[Run]) -> go.Figure:
110
 
111
  fig.add_trace(go.Scatter(
112
  x=date_range,
113
- y=[BASELINE] * len(date_range),
114
  line=dict(color="#ff0000", width=3),
115
  mode="lines",
116
  name="Baseline",
@@ -119,8 +142,8 @@ def create_graph(runs: list[Run]) -> go.Figure:
119
  background_color = gr.themes.default.colors.slate.c800
120
 
121
  fig.update_layout(
122
- title="Score Improvements",
123
- yaxis_title="Score",
124
  plot_bgcolor=background_color,
125
  paper_bgcolor=background_color,
126
  template="plotly_dark"
@@ -142,12 +165,17 @@ def create_leaderboard(runs: list[Run]) -> list[tuple]:
142
 
143
  try:
144
  uid = int(key)
 
 
 
145
 
146
  entries[uid] = LeaderboardEntry(
147
  uid=uid,
148
  rank=value["rank"],
149
  model=value["model"],
150
- score=value["score"],
 
 
151
  hotkey=value["hotkey"],
152
  previous_day_winner=value["multiday_winner"],
153
  )
@@ -158,7 +186,8 @@ def create_leaderboard(runs: list[Run]) -> list[tuple]:
158
  break
159
 
160
  leaderboard: list[tuple] = [
161
- (entry.rank + 1, entry.uid, entry.model, entry.score, entry.hotkey, entry.previous_day_winner)
 
162
  for entry in sorted(entries.values(), key=lambda entry: (entry.score, entry.rank), reverse=True)
163
  ]
164
 
@@ -227,8 +256,9 @@ def refresh():
227
 
228
  leaderboard = gr.components.Dataframe(
229
  create_leaderboard(runs[dropdown.value]),
230
- headers=["Rank", "Uid", "Model", "Score", "Hotkey", "Previous day winner"],
231
- datatype=["number", "number", "markdown", "number", "markdown", "bool"],
 
232
  elem_id="leaderboard-table",
233
  interactive=False,
234
  visible=True,
 
10
  import wandb
11
  from substrateinterface import Keypair
12
  from wandb.apis.public import Run
13
+ import numpy as np
14
 
15
  WANDB_RUN_PATH = os.environ["WANDB_RUN_PATH"]
16
  SOURCE_VALIDATOR_UID = int(os.environ["SOURCE_VALIDATOR_UID"])
17
+ BASELINE_AVERAGE = float(os.environ["BASELINE_AVERAGE"])
18
+
19
  REFRESH_RATE = 60 * 30 # 30 minutes
 
20
  GRAPH_HISTORY_DAYS = 30
21
  MAX_GRAPH_ENTRIES = 10
22
 
 
30
  uid: int
31
  model: str
32
  score: float
33
+ model_average: float
34
+ similarity: float
35
  hotkey: str
36
  previous_day_winner: bool
37
  rank: int
 
40
  @dataclass
41
  class GraphEntry:
42
  dates: list[datetime]
43
+ generation_times: list[float]
44
+ similarities: list[float]
45
  scores: list[float]
46
  models: list[str]
47
+ best_time: float
48
 
49
 
50
  def is_valid_run(run: Run):
 
66
  return False
67
 
68
 
69
+ def calculate_score(model_average: float, similarity: float) -> float:
70
+ return max(
71
+ 0.0,
72
+ BASELINE_AVERAGE - model_average
73
+ ) * similarity
74
+
75
+
76
  def get_graph_entries(runs: list[Run]) -> dict[int, GraphEntry]:
77
  entries: dict[int, GraphEntry] = {}
78
 
 
83
  if key.startswith("_"):
84
  continue
85
 
86
+ if "score" in value:
87
+ continue
88
+
89
  uid = int(key)
90
+
91
+ generation_time = value["generation_time"]
92
+ similarity = min(1, value["similarity"])
93
+ score = calculate_score(generation_time, similarity)
94
  model = value["model"]
95
 
96
  if uid not in entries:
97
+ entries[uid] = GraphEntry([date], [generation_time], [similarity], [score], [model], generation_time)
98
  else:
99
+ if generation_time < entries[uid].best_time:
100
+ entries[uid].best_time = generation_time
101
 
102
  data = entries[uid]
103
  data.dates.append(date)
104
+ data.generation_times.append(data.best_time)
105
+ data.similarities.append(similarity)
106
+ data.scores.append(score)
107
  data.models.append(model)
108
 
109
+ return dict(sorted(entries.items(), key=lambda entry: entry[1].best_time)[:MAX_GRAPH_ENTRIES])
110
 
111
 
112
  def create_graph(runs: list[Run]) -> go.Figure:
 
116
  for uid, data in entries.items():
117
  fig.add_trace(go.Scatter(
118
  x=data.dates,
119
+ y=data.generation_times,
120
+ customdata=np.stack((data.similarities, data.scores, data.models), axis=-1),
121
  mode="lines+markers",
122
  name=uid,
123
  hovertemplate=(
124
  "<b>Date:</b> %{x|%Y-%m-%d}<br>" +
125
+ "<b>Generation Time:</b> %{y}s<br>" +
126
+ "<b>Similarity:</b> %{customdata[0]}<br>" +
127
+ "<b>Score:</b> %{customdata[1]}<br>" +
128
+ "<b>Model:</b> %{customdata[2]}<br>"
129
  ),
130
  ))
131
 
 
133
 
134
  fig.add_trace(go.Scatter(
135
  x=date_range,
136
+ y=[BASELINE_AVERAGE] * len(date_range),
137
  line=dict(color="#ff0000", width=3),
138
  mode="lines",
139
  name="Baseline",
 
142
  background_color = gr.themes.default.colors.slate.c800
143
 
144
  fig.update_layout(
145
+ title="Generation Time Improvements",
146
+ yaxis_title="Generation Time (s)",
147
  plot_bgcolor=background_color,
148
  paper_bgcolor=background_color,
149
  template="plotly_dark"
 
165
 
166
  try:
167
  uid = int(key)
168
+ generation_time = value.get("generation_time", 0)
169
+ similarity = min(1, value.get("similarity", 0))
170
+ score = value.get("score", calculate_score(generation_time, similarity))
171
 
172
  entries[uid] = LeaderboardEntry(
173
  uid=uid,
174
  rank=value["rank"],
175
  model=value["model"],
176
+ score=score,
177
+ model_average=generation_time,
178
+ similarity=similarity,
179
  hotkey=value["hotkey"],
180
  previous_day_winner=value["multiday_winner"],
181
  )
 
186
  break
187
 
188
  leaderboard: list[tuple] = [
189
+ (entry.rank + 1, entry.uid, entry.model, entry.score, f"{entry.model_average:.3f}s", f"{entry.similarity:.3f}",
190
+ entry.hotkey, entry.previous_day_winner)
191
  for entry in sorted(entries.values(), key=lambda entry: (entry.score, entry.rank), reverse=True)
192
  ]
193
 
 
256
 
257
  leaderboard = gr.components.Dataframe(
258
  create_leaderboard(runs[dropdown.value]),
259
+ headers=["Rank", "Uid", "Model", "Score", "Generation Time", "Similarity", "Hotkey",
260
+ "Previous day winner"],
261
+ datatype=["number", "number", "markdown", "number", "markdown", "markdown", "markdown", "bool"],
262
  elem_id="leaderboard-table",
263
  interactive=False,
264
  visible=True,