|
""" |
|
Baseball statistics application with txtai and Streamlit. |
|
|
|
Install txtai and streamlit (>= 1.23) to run: |
|
pip install txtai streamlit |
|
""" |
|
|
|
import datetime |
|
import os |
|
|
|
import altair as alt |
|
import numpy as np |
|
import pandas as pd |
|
import streamlit as st |
|
|
|
from txtai.embeddings import Embeddings |
|
|
|
|
|
class Stats:
    """
    Base stats class. Contains methods for loading, indexing and searching baseball stats.

    Subclasses provide the data by implementing loadcolumns(), load(), metric() and vector().
    """

    def __init__(self):
        """
        Creates a new Stats instance.

        The load order matters: names and the index are both derived from the loaded stats,
        and the index also needs the column list.
        """

        # Columns used to build a vector for each stats row
        self.columns = self.loadcolumns()

        # Raw stats
        self.stats = self.load()

        # Player name to player id lookup
        self.names = self.loadnames()

        # Row vectors, row data and embeddings index, all keyed by year + player id
        self.vectors, self.data, self.embeddings = self.index()

    def loadcolumns(self):
        """
        Returns a list of data columns.

        Returns:
            list of columns
        """

        raise NotImplementedError

    def load(self):
        """
        Loads and returns raw stats.

        Returns:
            stats
        """

        raise NotImplementedError

    def metric(self):
        """
        Primary metric column.

        Returns:
            metric column name
        """

        raise NotImplementedError

    def vector(self, row):
        """
        Build a vector for input row.

        Args:
            row: input row

        Returns:
            row vector
        """

        raise NotImplementedError

    def loadnames(self):
        """
        Loads a name - player id dictionary.

        The first player seen with a given name is stored under the plain "first last" key.
        Later players sharing that name are stored as "first last (player id)".

        Returns:
            {player name: player id}
        """

        names = {}

        # One entry per distinct (first, last, id) combination
        rows = self.stats[["nameFirst", "nameLast", "playerID"]].drop_duplicates()
        for first, last, playerid in rows.itertuples(index=False, name=None):
            key = f"{first} {last}"

            # Disambiguate repeated names with the player id
            if key in names:
                key = f"{key} ({playerid})"

            names[key] = playerid

        return names

    def index(self):
        """
        Builds an embeddings index to stats data. Returns vectors, input data and embeddings index.

        Returns:
            vectors, data, embeddings
        """

        # Build both lookups in a single pass over the rows.
        # When a year + player id repeats, the last row wins in both dictionaries.
        vectors, data = {}, {}
        for _, row in self.stats.iterrows():
            uid = f'{row["yearID"]}{row["playerID"]}'
            vectors[uid] = self.transform(row)
            data[uid] = dict(row)

        # The vectors are already computed; transform passes numpy arrays through unchanged
        embeddings = Embeddings(
            {
                "transform": self.transform,
            }
        )

        embeddings.index((uid, vector, None) for uid, vector in vectors.items())

        return vectors, data, embeddings

    def metrics(self, player):
        """
        Looks up a player's active years, best statistical year and key metrics.

        Args:
            player: player name

        Returns:
            active, best, metrics. For an unknown player this is a default year range
            starting at 1871, 1950 and None.
        """

        if player not in self.names:
            return range(1871, datetime.datetime.today().year), 1950, None

        metric = self.metric()

        # All rows for this player
        stats = self.stats[self.stats["playerID"] == self.names[player]]

        # Year and primary metric only; copied so callers can modify it freely
        metrics = stats[["yearID", metric]].copy()

        # Best year is the one with the highest primary metric
        best = int(stats.sort_values(by=metric, ascending=False)["yearID"].iloc[0])

        return metrics["yearID"].tolist(), best, metrics

    def search(self, player=None, year=None, row=None, limit=10):
        """
        Runs an embeddings search. This method takes either a player-year or stats row as input.

        Args:
            player: player name to search
            year: year to search
            row: row of stats to search
            limit: max results to return

        Returns:
            list of results, at most one per player
        """

        if row:
            query = self.vector(row)
        else:
            # Look up the stored vector for this player-year; None when there is no match
            query = self.vectors.get(f"{year}{self.names.get(player)}")

        results, ids = [], set()
        if query is not None:
            # Request extra results since several seasons of the same player may match
            for uid, _ in self.embeddings.search(query, limit * 5):
                # uid is year + player id; the slice assumes a 4 character year prefix
                playerid = uid[4:]

                if playerid not in ids:
                    result = self.data[uid].copy()
                    result["link"] = f'https://www.baseball-reference.com/players/{result["nameLast"].lower()[0]}/{result["bbrefID"]}.shtml'
                    result["yearID"] = str(result["yearID"])
                    results.append(result)
                    ids.add(playerid)

                if len(ids) >= limit:
                    break

        return results

    def transform(self, row):
        """
        Transforms a stats row into a vector.

        Args:
            row: stats row

        Returns:
            vector with one element per column in self.columns. Missing, falsy and NaN
            values become 0.0. numpy arrays are returned unchanged.
        """

        if isinstance(row, np.ndarray):
            return row

        return np.array([0.0 if not row[x] or np.isnan(row[x]) else row[x] for x in self.columns])
|
|
|
|
|
class Batting(Stats):
    """
    Batting stats.
    """

    # Numeric codes for fielding positions. 0 is used for missing or unrecognized positions.
    POSITIONS = {"P": 1, "C": 2, "1B": 3, "2B": 4, "3B": 5, "SS": 6, "OF": 7}

    def loadcolumns(self):
        """
        Returns the batting columns used to build vectors.

        Returns:
            list of columns
        """

        return [
            "birthMonth",
            "yearID",
            "age",
            "height",
            "weight",
            "G",
            "AB",
            "R",
            "H",
            "1B",
            "2B",
            "3B",
            "HR",
            "RBI",
            "SB",
            "CS",
            "BB",
            "SO",
            "IBB",
            "HBP",
            "SH",
            "SF",
            "GIDP",
            "POS",
            "AVG",
            "OBP",
            "TB",
            "SLG",
            "OPS",
            "OPS+",
        ]

    def load(self):
        """
        Downloads player, batting and fielding data and builds the batting stats table.

        Returns:
            batting stats with derived columns added
        """

        # Retrieve raw data from GitHub
        players = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/People.csv")
        batting = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Batting.csv")
        fielding = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Fielding.csv")

        # Merge player data in
        batting = pd.merge(players, batting, how="inner", on=["playerID"])

        # Require at least 350 at bats + walks.
        # Copy so the derived columns below are added to an independent frame, not a slice.
        batting = batting[(batting["AB"] + batting["BB"]) >= 350].copy()

        # Derive primary positions keyed by year + player id
        positions = self.positions(fielding)

        # Calculated columns
        batting["age"] = batting["yearID"] - batting["birthYear"]
        batting["POS"] = batting.apply(lambda row: self.position(positions, row), axis=1)
        batting["AVG"] = batting["H"] / batting["AB"]
        batting["OBP"] = (batting["H"] + batting["BB"]) / (batting["AB"] + batting["BB"])
        batting["1B"] = batting["H"] - batting["2B"] - batting["3B"] - batting["HR"]
        batting["TB"] = batting["1B"] + 2 * batting["2B"] + 3 * batting["3B"] + 4 * batting["HR"]
        batting["SLG"] = batting["TB"] / batting["AB"]
        batting["OPS"] = batting["OBP"] + batting["SLG"]

        # OPS relative to the mean of the loaded rows, centered on 100
        batting["OPS+"] = 100 + (batting["OPS"] - batting["OPS"].mean()) * 100

        return batting

    def metric(self):
        """
        Primary metric column.

        Returns:
            "OPS+"
        """

        return "OPS+"

    def vector(self, row):
        """
        Build a vector for input row. Derived columns are calculated and stored on row.

        Ratios with a zero denominator are stored as None, which transform maps to 0.0.

        Args:
            row: input row, must support item assignment

        Returns:
            row vector
        """

        atbats = row["AB"]
        appearances = row["AB"] + row["BB"]

        row["TB"] = row["1B"] + 2 * row["2B"] + 3 * row["3B"] + 4 * row["HR"]
        row["AVG"] = row["H"] / atbats if atbats else None
        row["OBP"] = (row["H"] + row["BB"]) / appearances if appearances else None
        row["SLG"] = row["TB"] / atbats if atbats else None

        # OPS and OPS+ are only defined when both of their inputs are
        defined = row["OBP"] is not None and row["SLG"] is not None
        row["OPS"] = row["OBP"] + row["SLG"] if defined else None
        row["OPS+"] = 100 + (row["OPS"] - self.stats["OPS"].mean()) * 100 if defined else None

        return self.transform(row)

    def positions(self, fielding):
        """
        Derives primary positions for players.

        The primary position for a player-year is the fielding row with the most games.

        Args:
            fielding: fielding data

        Returns:
            {year + player id: (position code, number of games)}
        """

        positions = {}
        for year, playerid, code, games in zip(fielding["yearID"], fielding["playerID"], fielding["POS"], fielding["G"]):
            uid = f"{year}{playerid}"

            # Missing and unrecognized positions map to 0 so the column stays numeric
            position = self.POSITIONS.get(code, 0)

            # Keep the entry with the most games; on ties the first one seen is kept
            if uid not in positions or positions[uid][1] < games:
                positions[uid] = (position, games)

        return positions

    def position(self, positions, row):
        """
        Looks up primary position for player row.

        Args:
            positions: all player positions
            row: player row

        Returns:
            primary player position code, 0 when there is no entry for the player-year
        """

        uid = f'{row["yearID"]}{row["playerID"]}'
        return positions[uid][0] if uid in positions else 0
|
|
|
|
|
class Pitching(Stats):
    """
    Pitching stats.
    """

    def loadcolumns(self):
        """
        Returns the pitching columns used to build vectors.

        Returns:
            list of columns
        """

        return [
            "birthMonth",
            "yearID",
            "age",
            "height",
            "weight",
            "W",
            "L",
            "G",
            "GS",
            "CG",
            "SHO",
            "SV",
            "IPouts",
            "H",
            "ER",
            "HR",
            "BB",
            "SO",
            "BAOpp",
            "ERA",
            "IBB",
            "WP",
            "HBP",
            "BK",
            "BFP",
            "GF",
            "R",
            "SH",
            "SF",
            "GIDP",
            "WHIP",
            "WADJ",
        ]

    def load(self):
        """
        Downloads player and pitching data and builds the pitching stats table.

        Returns:
            pitching stats with derived columns added
        """

        # Retrieve raw data from GitHub
        players = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/People.csv")
        pitching = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Pitching.csv")

        # Merge player data in
        pitching = pd.merge(players, pitching, how="inner", on=["playerID"])

        # Require at least 20 appearances.
        # Copy so the derived columns below are added to an independent frame, not a slice.
        pitching = pitching[pitching["G"] >= 20].copy()

        # Calculated columns
        pitching["age"] = pitching["yearID"] - pitching["birthYear"]

        # IPouts / 3 converts outs recorded to innings
        pitching["WHIP"] = (pitching["BB"] + pitching["H"]) / (pitching["IPouts"] / 3)
        pitching["WADJ"] = (pitching["W"] + pitching["SV"]) / (pitching["ERA"] + pitching["WHIP"])

        return pitching

    def metric(self):
        """
        Primary metric column.

        Returns:
            "WADJ"
        """

        return "WADJ"

    def vector(self, row):
        """
        Build a vector for input row. Derived columns are calculated and stored on row.

        WHIP is None when IPouts is falsy. WADJ is None when either ERA or WHIP is falsy.
        transform maps None to 0.0.

        Args:
            row: input row, must support item assignment

        Returns:
            row vector
        """

        outs = row["IPouts"]
        row["WHIP"] = (row["BB"] + row["H"]) / (outs / 3) if outs else None

        if row["ERA"] and row["WHIP"]:
            row["WADJ"] = (row["W"] + row["SV"]) / (row["ERA"] + row["WHIP"])
        else:
            row["WADJ"] = None

        return self.transform(row)
|
|
|
|
|
class Application:
    """
    Main application.
    """

    def __init__(self):
        """
        Creates a new application.
        """

        # Batting stats
        self.batting = Batting()

        # Pitching stats
        self.pitching = Pitching()

    def run(self):
        """
        Runs a Streamlit application.
        """

        st.title("⚾ Baseball Statistics")
        st.markdown(
            """
            This application finds the best matching historical players using vector search with [txtai](https://github.com/neuml/txtai).
            Raw data is from the [Baseball Databank](https://github.com/chadwickbureau/baseballdatabank) GitHub project.
        """
        )

        self.player()

    def player(self):
        """
        Player tab.
        """

        st.markdown("Match by player-season. Each player search defaults to the best season sorted by OPS or Wins Adjusted.")

        category = st.radio("Stat", ["Batting", "Pitching"], horizontal=True, key="playerstat")
        stats, default = (self.batting, "Babe Ruth") if category == "Batting" else (self.pitching, "Cy Young")

        # Player name selection, falls back to the first name if the default is not in the data
        names = sorted(stats.names)
        player = st.selectbox("Player", names, self._default_index(names, default))

        # Player active years, best year and metrics
        active, best, metrics = stats.metrics(player)

        # A slider is only shown when there is more than one year to choose from
        year = int(st.select_slider("Year", active, best) if len(active) > 1 else active[0])

        # metrics is None when the player is unknown, in which case there is nothing to plot
        if len(active) > 1 and metrics is not None:
            self.chart(category, metrics)

        # Run search
        results = stats.search(player, year)

        # Display name and team, all vector columns except the first, then the link
        self.table(results, ["nameFirst", "nameLast", "teamID"] + stats.columns[1:] + ["link"])

    def chart(self, category, metrics):
        """
        Displays a metric chart.

        Args:
            category: Batting or Pitching
            metrics: player metrics to plot, not modified
        """

        # Key metric
        metric = self.batting.metric() if category == "Batting" else self.pitching.metric()

        # Cast year to string on a new frame, the caller's frame is left untouched
        metrics = metrics.assign(yearID=metrics["yearID"].astype(str))

        # Metric over years
        line = (
            alt.Chart(metrics)
            .mark_line(interpolate="monotone", point=True, strokeWidth=2.5, opacity=0.75)
            .encode(
                x=alt.X("yearID", title="").scale(padding=0),
                y=alt.Y(metric).scale(zero=False, padding=0),
            )
        )

        # Metric median rule line
        rule = alt.Chart(metrics).mark_rule(color="gray", strokeDash=[3, 5], opacity=0.5).encode(y=f"median({metric})")

        # Layered chart configuration
        chart = (line + rule).encode(y=alt.Y(title=metric)).properties(height=200).configure_axis(grid=False)

        # The layered chart already contains the rule, so it is drawn as is
        st.altair_chart(chart, theme="streamlit", use_container_width=True)

    def table(self, results, columns):
        """
        Displays a list of results as a table.

        Args:
            results: list of results
            columns: column names
        """

        if results:
            st.dataframe(pd.DataFrame(results)[columns])
        else:
            st.write("Player-Year not found")

    @staticmethod
    def _default_index(names, default):
        """
        Finds the list position of a default selection.

        Args:
            names: list of names
            default: name to select by default

        Returns:
            index of default in names, 0 when it is not present
        """

        try:
            return names.index(default)
        except ValueError:
            return 0
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def create():
    """
    Builds the Application. The result is cached through st.cache_resource.

    Returns:
        Application
    """

    application = Application()
    return application
|
|
|
|
|
if __name__ == "__main__":
    # Set before the application is created so it is in place when stats are loaded and indexed
    os.environ.update(TOKENIZERS_PARALLELISM="false")

    # Create (or reuse the cached) application and render it
    app = create()
    app.run()
|
|