davidmezzetti commited on
Commit
4ee078c
1 Parent(s): cf62ef7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -36
app.py CHANGED
@@ -1,13 +1,14 @@
1
  """
2
  Baseball statistics application with txtai and Streamlit.
3
 
4
- Install txtai and streamlit to run:
5
  pip install txtai streamlit
6
  """
7
 
8
  import datetime
9
  import os
10
 
 
11
  import numpy as np
12
  import pandas as pd
13
  import streamlit as st
@@ -57,15 +58,12 @@ class Stats:
57
 
58
  raise NotImplementedError
59
 
60
- def sort(self, rows):
61
  """
62
- Sorts rows stored as a DataFrame.
63
-
64
- Args:
65
- rows: input DataFrame
66
 
67
  Returns:
68
- sorted DataFrame
69
  """
70
 
71
  raise NotImplementedError
@@ -116,30 +114,41 @@ class Stats:
116
  vectors = {f'{row["yearID"]}{row["playerID"]}': self.transform(row) for _, row in self.stats.iterrows()}
117
  data = {f'{row["yearID"]}{row["playerID"]}': dict(row) for _, row in self.stats.iterrows()}
118
 
119
- embeddings = Embeddings({
120
- "transform": self.transform,
121
- })
 
 
122
 
123
  embeddings.index((uid, vectors[uid], None) for uid in vectors)
124
 
125
  return vectors, data, embeddings
126
 
127
- def years(self, player):
128
  """
129
- Looks up the years active for a player along with the player's best statistical year.
130
 
131
  Args:
132
  player: player name
133
 
134
  Returns:
135
- start, end, best
136
  """
137
 
138
  if player in self.names:
139
- df = self.sort(self.stats[self.stats["playerID"] == self.names[player]])
140
- return int(df["yearID"].min()), int(df["yearID"].max()), int(df["yearID"].iloc[0])
 
 
 
 
 
 
141
 
142
- return 1871, datetime.datetime.today().year, 1950
 
 
 
143
 
144
  def search(self, player=None, year=None, row=None, limit=10):
145
  """
@@ -196,10 +205,42 @@ class Stats:
196
 
197
 
198
  class Batting(Stats):
 
 
 
 
199
  def loadcolumns(self):
200
  return [
201
- "birthMonth", "age", "weight", "height", "yearID", "G", "AB", "R", "H", "1B", "2B", "3B", "HR", "RBI", "SB", "CS",
202
- "BB", "SO", "IBB", "HBP", "SH", "SF", "GIDP", "POS", "AVG", "OBP", "TB", "SLG", "OPS", "OPS+"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  ]
204
 
205
  def load(self):
@@ -230,8 +271,8 @@ class Batting(Stats):
230
 
231
  return batting
232
 
233
- def sort(self, rows):
234
- return rows.sort_values(by="OPS+", ascending=False)
235
 
236
  def vector(self, row):
237
  row["TB"] = row["1B"] + 2 * row["2B"] + 3 * row["3B"] + 4 * row["HR"]
@@ -255,7 +296,7 @@ class Batting(Stats):
255
  """
256
 
257
  positions = {}
258
- for x, row in fielding.iterrows():
259
  uid = f'{row["yearID"]}{row["playerID"]}'
260
  position = row["POS"] if row["POS"] else 0
261
  if position == "P":
@@ -294,12 +335,46 @@ class Batting(Stats):
294
  uid = f'{row["yearID"]}{row["playerID"]}'
295
  return positions[uid][0] if uid in positions else 0
296
 
 
297
  class Pitching(Stats):
 
 
 
 
298
  def loadcolumns(self):
299
  return [
300
- "birthMonth", "age", "weight", "height", "yearID", "W", "L", "G", "GS", "CG", "SHO", "SV", "IPouts",
301
- "H", "ER", "HR", "BB", "SO", "BAOpp", "ERA", "IBB", "WP", "HBP", "BK", "BFP", "GF", "R", "SH", "SF",
302
- "GIDP", "WHIP", "WADJ"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  ]
304
 
305
  def load(self):
@@ -316,16 +391,16 @@ class Pitching(Stats):
316
  # Calculated columns
317
  pitching["age"] = pitching["yearID"] - pitching["birthYear"]
318
  pitching["WHIP"] = (pitching["BB"] + pitching["H"]) / (pitching["IPouts"] / 3)
319
- pitching["WADJ"] =(pitching["W"] + pitching["SV"]) / (pitching["ERA"] + pitching["WHIP"])
320
 
321
  return pitching
322
 
323
- def sort(self, rows):
324
- return rows.sort_values(by="WADJ", ascending=False)
325
 
326
  def vector(self, row):
327
  row["WHIP"] = (row["BB"] + row["H"]) / (row["IPouts"] / 3) if row["IPouts"] else None
328
- row["WADJ"] =(row["W"] + row["SV"]) / (row["ERA"] + row["WHIP"]) if row["ERA"] and row["WHIP"] else None
329
 
330
  return self.transform(row)
331
 
@@ -352,13 +427,15 @@ class Application:
352
  """
353
 
354
  st.title("⚾ Baseball Statistics")
355
- st.markdown("""
 
356
  This application finds the best matching historical players using vector search with [txtai](https://github.com/neuml/txtai).
357
  Raw data is from the [Baseball Databank](https://github.com/chadwickbureau/baseballdatabank) GitHub project.
358
- """)
 
359
 
360
  self.player()
361
-
362
  def player(self):
363
  """
364
  Player tab.
@@ -373,19 +450,59 @@ class Application:
373
  names = sorted(stats.names)
374
  player = st.selectbox("Player", names, names.index(default))
375
 
 
 
 
376
  # Player year
377
- start, end, best = stats.years(player)
378
- year = st.slider("Year", start, end, best) if start != end else start
 
 
 
379
 
380
  # Run search
381
  results = stats.search(player, year)
382
 
383
  # Display results
384
- self.display(results, ["nameFirst", "nameLast", "teamID"] + stats.columns[1:] + ["link"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
- def display(self, results, columns):
387
  """
388
- Displays a list of results.
389
 
390
  Args:
391
  results: list of results
 
1
  """
2
  Baseball statistics application with txtai and Streamlit.
3
 
4
+ Install txtai and streamlit (>= 1.23) to run:
5
  pip install txtai streamlit
6
  """
7
 
8
  import datetime
9
  import os
10
 
11
+ import altair as alt
12
  import numpy as np
13
  import pandas as pd
14
  import streamlit as st
 
58
 
59
  raise NotImplementedError
60
 
61
+ def metric(self):
62
  """
63
+ Primary metric column.
 
 
 
64
 
65
  Returns:
66
+ metric column name
67
  """
68
 
69
  raise NotImplementedError
 
114
  vectors = {f'{row["yearID"]}{row["playerID"]}': self.transform(row) for _, row in self.stats.iterrows()}
115
  data = {f'{row["yearID"]}{row["playerID"]}': dict(row) for _, row in self.stats.iterrows()}
116
 
117
+ embeddings = Embeddings(
118
+ {
119
+ "transform": self.transform,
120
+ }
121
+ )
122
 
123
  embeddings.index((uid, vectors[uid], None) for uid in vectors)
124
 
125
  return vectors, data, embeddings
126
 
127
+ def metrics(self, player):
128
  """
129
+ Looks up a player's active years, best statistical year and key metrics.
130
 
131
  Args:
132
  player: player name
133
 
134
  Returns:
135
+ active, best, metrics
136
  """
137
 
138
  if player in self.names:
139
+ # Get player stats
140
+ stats = self.stats[self.stats["playerID"] == self.names[player]]
141
+
142
+ # Build key metrics
143
+ metrics = stats[["yearID", self.metric()]]
144
+
145
+ # Get best year, sort by primary metric
146
+ best = int(stats.sort_values(by=self.metric(), ascending=False)["yearID"].iloc[0])
147
 
148
+ # Get years active, best year, along with metric trends
149
+ return metrics["yearID"].tolist(), best, metrics
150
+
151
+ return range(1871, datetime.datetime.today().year), 1950, None
152
 
153
  def search(self, player=None, year=None, row=None, limit=10):
154
  """
 
205
 
206
 
207
  class Batting(Stats):
208
+ """
209
+ Batting stats.
210
+ """
211
+
212
  def loadcolumns(self):
213
  return [
214
+ "birthMonth",
215
+ "yearID",
216
+ "age",
217
+ "height",
218
+ "weight",
219
+ "G",
220
+ "AB",
221
+ "R",
222
+ "H",
223
+ "1B",
224
+ "2B",
225
+ "3B",
226
+ "HR",
227
+ "RBI",
228
+ "SB",
229
+ "CS",
230
+ "BB",
231
+ "SO",
232
+ "IBB",
233
+ "HBP",
234
+ "SH",
235
+ "SF",
236
+ "GIDP",
237
+ "POS",
238
+ "AVG",
239
+ "OBP",
240
+ "TB",
241
+ "SLG",
242
+ "OPS",
243
+ "OPS+",
244
  ]
245
 
246
  def load(self):
 
271
 
272
  return batting
273
 
274
+ def metric(self):
275
+ return "OPS+"
276
 
277
  def vector(self, row):
278
  row["TB"] = row["1B"] + 2 * row["2B"] + 3 * row["3B"] + 4 * row["HR"]
 
296
  """
297
 
298
  positions = {}
299
+ for _, row in fielding.iterrows():
300
  uid = f'{row["yearID"]}{row["playerID"]}'
301
  position = row["POS"] if row["POS"] else 0
302
  if position == "P":
 
335
  uid = f'{row["yearID"]}{row["playerID"]}'
336
  return positions[uid][0] if uid in positions else 0
337
 
338
+
339
  class Pitching(Stats):
340
+ """
341
+ Pitching stats.
342
+ """
343
+
344
  def loadcolumns(self):
345
  return [
346
+ "birthMonth",
347
+ "yearID",
348
+ "age",
349
+ "height",
350
+ "weight",
351
+ "W",
352
+ "L",
353
+ "G",
354
+ "GS",
355
+ "CG",
356
+ "SHO",
357
+ "SV",
358
+ "IPouts",
359
+ "H",
360
+ "ER",
361
+ "HR",
362
+ "BB",
363
+ "SO",
364
+ "BAOpp",
365
+ "ERA",
366
+ "IBB",
367
+ "WP",
368
+ "HBP",
369
+ "BK",
370
+ "BFP",
371
+ "GF",
372
+ "R",
373
+ "SH",
374
+ "SF",
375
+ "GIDP",
376
+ "WHIP",
377
+ "WADJ",
378
  ]
379
 
380
  def load(self):
 
391
  # Calculated columns
392
  pitching["age"] = pitching["yearID"] - pitching["birthYear"]
393
  pitching["WHIP"] = (pitching["BB"] + pitching["H"]) / (pitching["IPouts"] / 3)
394
+ pitching["WADJ"] = (pitching["W"] + pitching["SV"]) / (pitching["ERA"] + pitching["WHIP"])
395
 
396
  return pitching
397
 
398
+ def metric(self):
399
+ return "WADJ"
400
 
401
  def vector(self, row):
402
  row["WHIP"] = (row["BB"] + row["H"]) / (row["IPouts"] / 3) if row["IPouts"] else None
403
+ row["WADJ"] = (row["W"] + row["SV"]) / (row["ERA"] + row["WHIP"]) if row["ERA"] and row["WHIP"] else None
404
 
405
  return self.transform(row)
406
 
 
427
  """
428
 
429
  st.title("⚾ Baseball Statistics")
430
+ st.markdown(
431
+ """
432
  This application finds the best matching historical players using vector search with [txtai](https://github.com/neuml/txtai).
433
  Raw data is from the [Baseball Databank](https://github.com/chadwickbureau/baseballdatabank) GitHub project.
434
+ """
435
+ )
436
 
437
  self.player()
438
+
439
  def player(self):
440
  """
441
  Player tab.
 
450
  names = sorted(stats.names)
451
  player = st.selectbox("Player", names, names.index(default))
452
 
453
+ # Player metrics
454
+ active, best, metrics = stats.metrics(player)
455
+
456
  # Player year
457
+ year = int(st.select_slider("Year", active, best) if len(active) > 1 else active[0])
458
+
459
+ # Display metrics chart
460
+ if len(active) > 1:
461
+ self.chart(category, metrics)
462
 
463
  # Run search
464
  results = stats.search(player, year)
465
 
466
  # Display results
467
+ self.table(results, ["nameFirst", "nameLast", "teamID"] + stats.columns[1:] + ["link"])
468
+
469
+ def chart(self, category, metrics):
470
+ """
471
+ Displays a metric chart.
472
+
473
+ Args:
474
+ category: Batting or Pitching
475
+ metrics: player metrics to plot
476
+ """
477
+
478
+ # Key metric
479
+ metric = self.batting.metric() if category == "Batting" else self.pitching.metric()
480
+
481
+ # Cast year to string
482
+ metrics["yearID"] = metrics["yearID"].astype(str)
483
+
484
+ # Metric over years
485
+ chart = (
486
+ alt.Chart(metrics)
487
+ .mark_line(interpolate="monotone", point=True, strokeWidth=2.5, opacity=0.75)
488
+ .encode(
489
+ x=alt.X("yearID", title="").scale(padding=0),
490
+ y=alt.Y(metric).scale(zero=False, padding=0),
491
+ )
492
+ )
493
+
494
+ # Create metric median rule line
495
+ rule = alt.Chart(metrics).mark_rule(color="gray", strokeDash=[3, 5], opacity=0.5).encode(y=f"median({metric})")
496
+
497
+ # Layered chart configuration
498
+ chart = (chart + rule).encode(y=alt.Y(title=metric)).properties(height=200).configure_axis(grid=False)
499
+
500
+ # Draw chart
501
+ st.altair_chart(chart + rule, theme="streamlit", use_container_width=True)
502
 
503
+ def table(self, results, columns):
504
  """
505
+ Displays a list of results as a table.
506
 
507
  Args:
508
  results: list of results