Train a Terrible Tic-Tac-Toe AI
This project demonstrates how to build and train a neural network to play Tic-Tac-Toe using PyTorch. The model learns optimal moves from a dataset of all possible game states and their corresponding best moves.
0. Set the device
# Train on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Training on {device}.")
1. Board Representation and Conversion
We represent the Tic-Tac-Toe board as a 3x3 list of lists, where each cell can be 'x', 'o', or None. To feed this into our neural network, we convert it into a tensor:
def board_to_tensor(board):
    """Encode a board as a flat float32 tensor.

    Each cell is mapped 'x' -> 1, 'o' -> -1, None -> 0; the rows are then
    concatenated in order into a single 1-D tensor.
    """
    cell_values = {None: 0, 'o': -1, 'x': 1}
    encoded_rows = []
    for row in board:
        encoded_rows.append([cell_values[cell] for cell in row])
    grid = torch.tensor(encoded_rows, dtype=torch.float32)
    return torch.flatten(grid)
This function maps 'x' to 1, 'o' to -1, and empty cells to 0, then flattens the board into a 1D tensor.
2. Dataset Creation
We create a custom PyTorch Dataset to hold our game states and their corresponding best moves:
class TicTacToeDataset(Dataset):
    """Pairs each stored board with the move stored at the same index."""

    def __init__(self, boards, moves):
        # Two parallel, index-aligned collections.
        self.boards = boards
        self.moves = moves

    def __len__(self):
        """Number of samples, taken from the boards collection."""
        return len(self.boards)

    def __getitem__(self, idx):
        """Return the ``(board, move)`` pair at position ``idx``."""
        return self.boards[idx], self.moves[idx]
3. Neural Network Architecture
Our Tic-Tac-Toe neural network is a simple feedforward network:
class TicTacToeNN(nn.Module):
    """Feed-forward network that scores the 9 board cells for a given board.

    Architecture: Linear(9, 128) -> ReLU -> Linear(128, 64) -> ReLU ->
    Linear(64, 9).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 9)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Score every cell for a batch of encoded boards.

        Args:
            x: tensor whose last dimension has size 9; softmax is taken over
                dim 1, so a 2-D ``(batch, 9)`` input is expected.

        Returns:
            In training mode, the raw logits of the last layer.
            nn.CrossEntropyLoss applies log-softmax itself, so feeding it
            already-normalised probabilities would squash the gradients.
            In eval mode, the softmax of those logits, i.e. one probability
            per cell that sums to 1 along dim 1.
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        logits = self.fc3(x)
        if self.training:
            return logits
        return self.softmax(logits)
It takes a flattened board (9 inputs) and outputs probabilities for each of the 9 possible moves.
4. Data Generation
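We enumerate, by brute force, every 3x3 arrangement of 'x', 'o', and empty cells that still has at least one empty cell (no check is made that a position could actually occur in a real game), and label each one with a move chosen by a hand-written heuristic. The find_best_move function implements that heuristic: win if possible, otherwise block the opponent, otherwise try to set up a fork, otherwise take the center, a corner, or an edge.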
# Every way of filling nine cells with 'x', 'o' or empty (3**9 combinations).
possible_items = ["x", "o", None]
all_boards = [list(cells) for cells in itertools.product(possible_items, repeat=9)]
# Keep only the boards that still have somewhere to play.
valid_boards = [cells for cells in all_boards if None in cells]
# Cut each flat 9-cell list into three rows of three cells.
boards = [
    [cells[start:start + 3] for start in range(0, 9, 3)]
    for cells in valid_boards
]
boards[:9]
def find_best_move(board):
    """Pick a move for 'x' on ``board`` from a fixed priority list of heuristics.

    The first rule that matches wins; cells are always probed in row-major
    order (row 0 left to right, then row 1, then row 2):

      1. an empty cell that completes a line for 'x' (win);
      2. an empty cell that would complete a line for 'o' (block);
      3. an empty cell after which ``count_forks('x')`` is greater than 1;
      4. an empty cell after which ``count_forks('o')`` is greater than 1;
      5. the center, then the corners, then the edges.

    Args:
        board: 3x3 list of lists whose cells are 'x', 'o' or None. Candidate
            marks are written into it temporarily while probing and are
            always removed again before returning.

    Returns:
        A ``(row, col)`` tuple, or None when the board has no empty cell.
    """
    # Fallback preference: center first, then corners, then edges.
    fallback_order = [
        (1, 1),
        (0, 0), (0, 2), (2, 0), (2, 2),
        (0, 1), (1, 0), (1, 2), (2, 1),
    ]

    def columns():
        """Return the three columns of the board as lists."""
        return [[board[r][c] for r in range(3)] for c in range(3)]

    def check_win(player):
        """Return True when ``player`` fills a whole row, column or diagonal."""
        lines = list(board) + columns()
        lines.append([board[i][i] for i in range(3)])
        lines.append([board[i][2 - i] for i in range(3)])
        return any(all(cell == player for cell in line) for line in lines)

    def count_forks(player):
        """Return the heuristic "fork" score of the current board for ``player``.

        Adds one for every row or column that holds exactly one ``player``
        mark and two empty cells, and, when ``player`` owns the center, one
        for each diagonal on which ``player`` holds one corner while the
        opposite corner is empty.

        NOTE(review): the row/column rule counts lines with a single mark,
        which differs from the usual "two in a line" idea of a threat. It is
        kept unchanged because it defines the training labels.
        """
        forks = sum(
            1
            for line in list(board) + columns()
            if line.count(player) == 1 and line.count(None) == 2
        )
        if board[1][1] == player:
            for (r1, c1), (r2, c2) in (((0, 0), (2, 2)), ((0, 2), (2, 0))):
                first, second = board[r1][c1], board[r2][c2]
                if (first == player and second is None) or \
                        (first is None and second == player):
                    forks += 1
        return forks

    def first_cell_where(player, predicate):
        """Return the first empty cell where placing ``player`` satisfies ``predicate``."""
        for row in range(3):
            for col in range(3):
                if board[row][col] is not None:
                    continue
                board[row][col] = player
                try:
                    matched = predicate(player)
                finally:
                    # Always undo the trial mark so the caller's board is intact.
                    board[row][col] = None
                if matched:
                    return (row, col)
        return None

    def makes_fork(player):
        return count_forks(player) > 1

    # Win, block, own fork, opponent fork, in that order.
    for player, predicate in (
        ('x', check_win),
        ('o', check_win),
        ('x', makes_fork),
        ('o', makes_fork),
    ):
        move = first_cell_where(player, predicate)
        if move is not None:
            return move

    for r, c in fallback_order:
        if board[r][c] is None:
            return (r, c)
    return None
# Label every board with the heuristic's chosen cell, stored as a [row, col] list.
moves = [list(find_best_move(grid)) for grid in boards]
moves[:9]
5. Set up the Dataloader
Make sure every tensor (boards and target moves) is moved to the selected device.
# Encode every board, stack the encodings into one tensor (one row per board)
# and move it to the device in a single transfer instead of one copy per board.
tensor_boards = torch.stack([board_to_tensor(board) for board in boards]).to(device)
# Target class for each board: the chosen cell numbered row-major (row * 3 + col).
tensor_moves = torch.tensor([move[0] * 3 + move[1] for move in moves], device=device)
dataset = TicTacToeDataset(tensor_boards, tensor_moves)
# The shuffling sampler draws its index permutation on the CPU, so it needs a
# CPU generator; a CUDA generator here raises a device-mismatch error on GPU.
g = torch.Generator()
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, generator=g)
6. Training Loop
We use the Adam optimizer and Cross-Entropy Loss to train our model:
model = TicTacToeNN().to(device)
# nn.CrossEntropyLoss applies log-softmax itself and expects raw logits.
# NOTE(review): confirm the model does not already apply softmax in training mode.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 100
model.train()
for epoch in range(epochs):
    # Distinct batch names: reusing `boards` / `moves` here would overwrite
    # the module-level dataset lists of the same names.
    for batch_boards, batch_moves in dataloader:
        batch_boards = batch_boards.to(device)
        batch_moves = batch_moves.to(device)
        optimizer.zero_grad()
        outputs = model(batch_boards)
        loss = criterion(outputs, batch_moves)
        loss.backward()
        optimizer.step()
    # Progress report every 10 epochs; the value shown is the loss of the
    # last mini-batch of that epoch, not an epoch average.
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")
7. Model Evaluation
After training, we can evaluate our model on a test board:
# Sample position: 'x' can complete the bottom row by playing the bottom-left cell.
test_board = [
    [None, "o", "o"],
    [None, "o", None],
    [None, "x", "x"],
]
# unsqueeze(0) adds the leading batch dimension the model's softmax(dim=1) needs.
test_tensor = board_to_tensor(test_board).unsqueeze(0).to(device)
model.eval()
with torch.no_grad():
    prediction = model(test_tensor)
# Highest-scoring cell index, converted back to [row, col].
best_move_index = prediction.argmax().item()
best_move = list(divmod(best_move_index, 3))
print(f"Best move for the test board: {best_move}")
This is a dumb project, and it won't work
There are many problems with this setup. In particular, the model will often suggest illegal moves: it never checks which cells are actually free, and simply returns whichever cell it scores as most likely overall.