""" Baseball statistics application with txtai and Streamlit. Install txtai and streamlit (>= 1.23) to run: pip install txtai streamlit """ import datetime import os import altair as alt import numpy as np import pandas as pd import streamlit as st from txtai.embeddings import Embeddings class Stats: """ Base stats class. Contains methods for loading, indexing and searching baseball stats. """ def __init__(self): """ Creates a new Stats instance. """ # Load columns self.columns = self.loadcolumns() # Load stats data self.stats = self.load() # Load names self.names = self.loadnames() # Build index self.vectors, self.data, self.embeddings = self.index() def loadcolumns(self): """ Returns a list of data columns. Returns: list of columns """ raise NotImplementedError def load(self): """ Loads and returns raw stats. Returns: stats """ raise NotImplementedError def metric(self): """ Primary metric column. Returns: metric column name """ raise NotImplementedError def vector(self, row): """ Build a vector for input row. Args: row: input row Returns: row vector """ raise NotImplementedError def loadnames(self): """ Loads a name - player id dictionary. Returns: {player name: player id} """ # Get unique names names = {} rows = self.stats[["nameFirst", "nameLast", "playerID"]].drop_duplicates() for _, row in rows.iterrows(): # Name key key = f"{row['nameFirst']} {row['nameLast']}" suffix = f" ({row['playerID']})" if key in names else "" # Save name key - player id pair names[f"{key}{suffix}"] = row["playerID"] return names def index(self): """ Builds an embeddings index to stats data. Returns vectors, input data and embeddings index. Returns: vectors, data, embeddings """ # Build data dictionary vectors = {f'{row["yearID"]}{row["playerID"]}': self.transform(row) for _, row in self.stats.iterrows()} data = {f'{row["yearID"]}{row["playerID"]}': dict(row) for _, row in self.stats.iterrows()} embeddings = Embeddings( { "transform": self.transform, } ) embeddings.index((uid, vectors[uid], None) for uid in vectors) return vectors, data, embeddings def metrics(self, player): """ Looks up a player's active years, best statistical year and key metrics. Args: player: player name Returns: active, best, metrics """ if player in self.names: # Get player stats stats = self.stats[self.stats["playerID"] == self.names[player]] # Build key metrics metrics = stats[["yearID", self.metric()]] # Get best year, sort by primary metric best = int(stats.sort_values(by=self.metric(), ascending=False)["yearID"].iloc[0]) # Get years active, best year, along with metric trends return metrics["yearID"].tolist(), best, metrics return range(1871, datetime.datetime.today().year), 1950, None def search(self, player=None, year=None, row=None, limit=10): """ Runs an embeddings search. This method takes either a player-year or stats row as input. Args: player: player name to search year: year to search row: row of stats to search limit: max results to return Returns: list of results """ if row: query = self.vector(row) else: # Lookup player key and build vector id query = f"{year}{self.names.get(player)}" query = self.vectors.get(query) results, ids = [], set() if query is not None: for uid, _ in self.embeddings.search(query, limit * 5): # Only add unique players if uid[4:] not in ids: result = self.data[uid].copy() result["link"] = f'https://www.baseball-reference.com/players/{result["nameLast"].lower()[0]}/{result["bbrefID"]}.shtml' result["yearID"] = str(result["yearID"]) results.append(result) ids.add(uid[4:]) if len(ids) >= limit: break return results def transform(self, row): """ Transforms a stats row into a vector. Args: row: stats row Returns: vector """ if isinstance(row, np.ndarray): return row return np.array([0.0 if not row[x] or np.isnan(row[x]) else row[x] for x in self.columns]) class Batting(Stats): """ Batting stats. """ def loadcolumns(self): return [ "birthMonth", "yearID", "age", "height", "weight", "G", "AB", "R", "H", "1B", "2B", "3B", "HR", "RBI", "SB", "CS", "BB", "SO", "IBB", "HBP", "SH", "SF", "GIDP", "POS", "AVG", "OBP", "TB", "SLG", "OPS", "OPS+", ] def load(self): # Retrieve raw data from GitHub players = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/People.csv") batting = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Batting.csv") fielding = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Fielding.csv") # Merge player data in batting = pd.merge(players, batting, how="inner", on=["playerID"]) # Require player to have at least 350 plate appearances. batting = batting[(batting["AB"] + batting["BB"]) >= 350] # Derive primary player positions positions = self.positions(fielding) # Calculated columns batting["age"] = batting["yearID"] - batting["birthYear"] batting["POS"] = batting.apply(lambda row: self.position(positions, row), axis=1) batting["AVG"] = batting["H"] / batting["AB"] batting["OBP"] = (batting["H"] + batting["BB"]) / (batting["AB"] + batting["BB"]) batting["1B"] = batting["H"] - batting["2B"] - batting["3B"] - batting["HR"] batting["TB"] = batting["1B"] + 2 * batting["2B"] + 3 * batting["3B"] + 4 * batting["HR"] batting["SLG"] = batting["TB"] / batting["AB"] batting["OPS"] = batting["OBP"] + batting["SLG"] batting["OPS+"] = 100 + (batting["OPS"] - batting["OPS"].mean()) * 100 return batting def metric(self): return "OPS+" def vector(self, row): row["TB"] = row["1B"] + 2 * row["2B"] + 3 * row["3B"] + 4 * row["HR"] row["AVG"] = row["H"] / row["AB"] row["OBP"] = (row["H"] + row["BB"]) / (row["AB"] + row["BB"]) row["SLG"] = row["TB"] / row["AB"] row["OPS"] = row["OBP"] + row["SLG"] row["OPS+"] = 100 + (row["OPS"] - self.stats["OPS"].mean()) * 100 return self.transform(row) def positions(self, fielding): """ Derives primary positions for players. Args: fielding: fielding data Returns: {player id: (position, number of games)} """ positions = {} for _, row in fielding.iterrows(): uid = f'{row["yearID"]}{row["playerID"]}' position = row["POS"] if row["POS"] else 0 if position == "P": position = 1 elif position == "C": position = 2 elif position == "1B": position = 3 elif position == "2B": position = 4 elif position == "3B": position = 5 elif position == "SS": position = 6 elif position == "OF": position = 7 # Save position if not set or player played more at this position if uid not in positions or positions[uid][1] < row["G"]: positions[uid] = (position, row["G"]) return positions def position(self, positions, row): """ Looks up primary position for player row. Arg: positions: all player positions row: player row Returns: primary player positions """ uid = f'{row["yearID"]}{row["playerID"]}' return positions[uid][0] if uid in positions else 0 class Pitching(Stats): """ Pitching stats. """ def loadcolumns(self): return [ "birthMonth", "yearID", "age", "height", "weight", "W", "L", "G", "GS", "CG", "SHO", "SV", "IPouts", "H", "ER", "HR", "BB", "SO", "BAOpp", "ERA", "IBB", "WP", "HBP", "BK", "BFP", "GF", "R", "SH", "SF", "GIDP", "WHIP", "WADJ", ] def load(self): # Retrieve raw data from GitHub players = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/People.csv") pitching = pd.read_csv("https://raw.githubusercontent.com/chadwickbureau/baseballdatabank/master/core/Pitching.csv") # Merge player data in pitching = pd.merge(players, pitching, how="inner", on=["playerID"]) # Require player to have 20 appearances pitching = pitching[pitching["G"] >= 20] # Calculated columns pitching["age"] = pitching["yearID"] - pitching["birthYear"] pitching["WHIP"] = (pitching["BB"] + pitching["H"]) / (pitching["IPouts"] / 3) pitching["WADJ"] = (pitching["W"] + pitching["SV"]) / (pitching["ERA"] + pitching["WHIP"]) return pitching def metric(self): return "WADJ" def vector(self, row): row["WHIP"] = (row["BB"] + row["H"]) / (row["IPouts"] / 3) if row["IPouts"] else None row["WADJ"] = (row["W"] + row["SV"]) / (row["ERA"] + row["WHIP"]) if row["ERA"] and row["WHIP"] else None return self.transform(row) class Application: """ Main application. """ def __init__(self): """ Creates a new application. """ # Batting stats self.batting = Batting() # Pitching stats self.pitching = Pitching() def run(self): """ Runs a Streamlit application. """ st.title("⚾ Baseball Statistics") st.markdown( """ This application finds the best matching historical players using vector search with [txtai](https://github.com/neuml/txtai). Raw data is from the [Baseball Databank](https://github.com/chadwickbureau/baseballdatabank) GitHub project. """ ) self.player() def player(self): """ Player tab. """ st.markdown("Match by player-season. Each player search defaults to the best season sorted by OPS or Wins Adjusted.") category = st.radio("Stat", ["Batting", "Pitching"], horizontal=True, key="playerstat") stats, default = (self.batting, "Babe Ruth") if category == "Batting" else (self.pitching, "Cy Young") # Player name names = sorted(stats.names) player = st.selectbox("Player", names, names.index(default)) # Player metrics active, best, metrics = stats.metrics(player) # Player year year = int(st.select_slider("Year", active, best) if len(active) > 1 else active[0]) # Display metrics chart if len(active) > 1: self.chart(category, metrics) # Run search results = stats.search(player, year) # Display results self.table(results, ["nameFirst", "nameLast", "teamID"] + stats.columns[1:] + ["link"]) def chart(self, category, metrics): """ Displays a metric chart. Args: category: Batting or Pitching metrics: player metrics to plot """ # Key metric metric = self.batting.metric() if category == "Batting" else self.pitching.metric() # Cast year to string metrics["yearID"] = metrics["yearID"].astype(str) # Metric over years chart = ( alt.Chart(metrics) .mark_line(interpolate="monotone", point=True, strokeWidth=2.5, opacity=0.75) .encode( x=alt.X("yearID", title="").scale(padding=0), y=alt.Y(metric).scale(zero=False, padding=0), ) ) # Create metric median rule line rule = alt.Chart(metrics).mark_rule(color="gray", strokeDash=[3, 5], opacity=0.5).encode(y=f"median({metric})") # Layered chart configuration chart = (chart + rule).encode(y=alt.Y(title=metric)).properties(height=200).configure_axis(grid=False) # Draw chart st.altair_chart(chart + rule, theme="streamlit", use_container_width=True) def table(self, results, columns): """ Displays a list of results as a table. Args: results: list of results columns: column names """ if results: st.dataframe(pd.DataFrame(results)[columns]) else: st.write("Player-Year not found") @st.cache_resource(show_spinner=False) def create(): """ Creates and caches a Streamlit application. Returns: Application """ return Application() if __name__ == "__main__": os.environ["TOKENIZERS_PARALLELISM"] = "false" # Create and run application app = create() app.run()