"""
FileName: app.py
Author: Benhao Huang
Create Date: 2023/11/19
Description: this file is used to display our project and add visualization elements to the game, using Streamlit
"""
import time
import pandas as pd
from copy import deepcopy
import torch
# import torch
import numpy as np
import streamlit as st
from scipy.signal import convolve # this is used to check if any player wins
from streamlit import session_state
from streamlit_server_state import server_state, server_state_lock
from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet
from Gomoku_Bot import Gomoku_bot
from Gomoku_Bot import Board as Gomoku_bot_board
import matplotlib.pyplot as plt
from const import (
_BLACK, # 1, for human
_WHITE, # 2 , for AI
_BLANK,
_PLAYER_COLOR,
_PLAYER_SYMBOL,
_ROOM_COLOR,
_VERTICAL,
_NEW,
_HORIZONTAL,
_DIAGONAL_UP_LEFT,
_DIAGONAL_UP_RIGHT,
_BOARD_SIZE,
_BOARD_SIZE_1D,
_AI_AID_INFO
)
from ai import (
BOS_TOKEN_ID,
generate_gpt2,
load_model,
)
# Load the model returned by ai.load_model() once, at module level.
# NOTE(review): within this file `gpt2` is referenced only from commented-out
# code in draw_board -- confirm it is still needed before keeping the load.
gpt2 = load_model()
# Utils
class Room:
    """
    State of one game room: the board, turn/score bookkeeping and the AI engines.

    Attributes are upper-case where the rest of the file reads them that way
    (ROOM_ID, BOARD, PLAYER, TURN, HISTORY, WINNER, TIME, ...).
    """

    def __init__(self, room_id) -> None:
        self.ROOM_ID = room_id
        self.BOARD = Board(
            width=_BOARD_SIZE,
            height=_BOARD_SIZE,
            n_in_row=5,
            players=[_BLACK, _WHITE],
        )
        # Black moves first; TURN starts on the room's first player.
        self.PLAYER = _BLACK
        self.TURN = self.PLAYER
        # Score tuple: index 0 counts White wins, index 1 counts Black wins.
        self.HISTORY = (0, 0)
        self.WINNER = _BLANK
        # Timestamp that sync_room compares between local and shared copies.
        self.TIME = time.time()
        # Separate board object handed to the 'Gomoku Bot' engine.
        self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)

        # Engines selectable from the UI, keyed by their display name.
        engines = {}
        engines['Pure MCTS'] = MCTSpure(c_puct=5, n_playout=1000)
        policy_net = PolicyValueNet(
            _BOARD_SIZE,
            _BOARD_SIZE,
            'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth',
        )
        engines['AlphaZero'] = alphazero(
            policy_net.policy_value_fn, c_puct=5, n_playout=100
        )
        engines['Gomoku Bot'] = Gomoku_bot(self.gomoku_bot_board, -1)
        self.MCTS_dict = engines

        # Opponent and assistant both start out as the AlphaZero engine.
        self.MCTS = engines['AlphaZero']
        self.last_mcts = self.MCTS
        self.AID_MCTS = engines['AlphaZero']

        # Flat cell indices played so far, seeded with BOS_TOKEN_ID.
        self.COORDINATE_1D = [BOS_TOKEN_ID]
        # Most recent move as a flat index; -1 until a move has been made.
        self.current_move = -1
        # Per-move simulation timings reported by the opponent engine.
        self.simula_time_list = []
def change_turn(cur):
    """Return the id of the other player: 1 becomes 2 and 2 becomes 1."""
    parity = cur % 2
    return parity + 1
# Initialize the game: create the default "local" room and the per-session
# flags the first time they are found missing.
if "ROOM" not in session_state:
    session_state["ROOM"] = Room("local")
if "OWNER" not in session_state:
    session_state["OWNER"] = False
if "USE_AIAID" not in session_state:
    session_state["USE_AIAID"] = False

# Check server health: make sure the shared room registry exists, creating
# it under the lock when it does not.
if "ROOMS" not in server_state:
    with server_state_lock["ROOMS"]:
        server_state.ROOMS = {}
def handle_oppo_model_selection():
    """
    Switch the opponent engine to st.session_state['selected_oppo_model'].

    Choosing 'Gomoku Bot' remembers the engine that was active (in
    ROOM.last_mcts) and installs the bot. Choosing any other model installs
    that engine from ROOM.MCTS_dict and seeds it with a deep copy of the
    remembered engine's search-tree root.

    The tree is only transplanted when the remembered engine actually exposes
    an ``mcts`` attribute: another_round() can leave ROOM.last_mcts pointing
    at the 'Gomoku Bot' engine, and dereferencing ``.mcts`` unconditionally
    would then fail.
    """
    room = session_state.ROOM
    choice = st.session_state['selected_oppo_model']

    if choice == 'Gomoku Bot':
        # The bot uses a different mechanism, so keep the previous engine
        # around for a later switch back.
        room.last_mcts = room.MCTS
        room.MCTS = room.MCTS_dict['Gomoku Bot']
        return

    engine = room.MCTS_dict[choice]
    previous_tree = getattr(room.last_mcts, 'mcts', None)
    if previous_tree is not None:
        engine.mcts._root = deepcopy(previous_tree._root)
    room.MCTS = engine
    room.last_mcts = engine
def handle_aid_model_selection():
    """
    Apply st.session_state['selected_aid_model'] to the assistant settings.

    'None' turns the assistant off (USE_AIAID = False). Any other value turns
    it on and installs the matching engine from ROOM.MCTS_dict as
    ROOM.AID_MCTS, seeded with a deep copy of the opponent's search-tree root.

    When the current opponent has no ``mcts`` attribute (the 'Gomoku Bot'
    engine is called differently elsewhere in this file), the engine
    remembered in ROOM.last_mcts is used as the source instead; if neither
    has a tree, the assistant keeps its own root rather than failing.
    """
    choice = st.session_state['selected_aid_model']
    if choice == 'None':
        session_state.USE_AIAID = False
        return

    session_state.USE_AIAID = True
    room = session_state.ROOM
    assistant = room.MCTS_dict[choice]

    # Use the same tree node as the opponent engine when one is available.
    source = room.MCTS if hasattr(room.MCTS, 'mcts') else room.last_mcts
    source_tree = getattr(source, 'mcts', None)
    if source_tree is not None:
        assistant.mcts._root = deepcopy(source_tree._root)
    room.AID_MCTS = assistant
# Remembered selectbox choices; they are compared with the widget values
# further down to detect a change of selection.
# NOTE(review): the assistant selectbox below starts at index 0 ('None'),
# so this 'AlphaZero' default differs from the widget on the first run.
if 'selected_oppo_model' not in st.session_state:
    st.session_state['selected_oppo_model'] = 'AlphaZero'  # default value
if 'selected_aid_model' not in st.session_state:
    st.session_state['selected_aid_model'] = 'AlphaZero'  # default value
# Layout: placeholders are created in display order and filled in later.
TITLE = st.empty()
Model_Switch = st.empty()
TITLE.header("🤖 AI 3603 Gomoku")
selected_oppo_option = Model_Switch.selectbox(
    'Select Opponent Model',
    ['Pure MCTS', 'AlphaZero', 'Gomoku Bot'],
    index=1,
    key='oppo_model',
)
# Only run the switch handler when the widget value differs from the
# remembered selection.
if st.session_state['selected_oppo_model'] != selected_oppo_option:
    st.session_state['selected_oppo_model'] = selected_oppo_option
    handle_oppo_model_selection()
ROUND_INFO = st.empty()
# Vertical spacing above the board. The literal had been split across two
# lines (an unterminated string); restored as an HTML line break, which is
# what unsafe_allow_html=True suggests -- TODO confirm against the original.
st.markdown("<br>", unsafe_allow_html=True)
# One placeholder per board cell, laid out as rows of equal-width columns.
BOARD_PLATE = [
    [cell.empty() for cell in st.columns([1] * _BOARD_SIZE)]
    for _ in range(_BOARD_SIZE)
]
LOG = st.empty()
# Sidebar: score header and the two score columns.
SCORE_TAG = st.sidebar.empty()
SCORE_PLATE = st.sidebar.columns(2)
# History scores
SCORE_TAG.subheader("Scores")
PLAY_MODE_INFO = st.sidebar.container()
MULTIPLAYER_TAG = st.sidebar.empty()
# Placeholders for the sidebar controls; they are filled in elsewhere
# (RESTART and ANOTHER_ROUND become buttons, AIAID the assistant selectbox).
with st.sidebar.container():
    ANOTHER_ROUND = st.empty()
    RESTART = st.empty()
    AIAID = st.empty()
    EXIT = st.empty()
selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0, key='aid_model')
# Only run the assistant handler when the widget value differs from the
# remembered selection.
if st.session_state['selected_aid_model'] != selected_aid_option:
    st.session_state['selected_aid_model'] = selected_aid_option
    handle_aid_model_selection()
GAME_INFO = st.sidebar.container()
# Placeholder used by draw_board for the AI timing message.
message = st.empty()
PLAY_MODE_INFO.write("---\n\n**You are Black, AI agent is White.**")
# Static rules-and-credits text written into the GAME_INFO sidebar container.
GAME_INFO.markdown(
    """
---
# Freestyle Gomoku game. 🎲
- no restrictions 🚫
- no regrets 😎
- no regrets 😎
- swap players after one round is over 🔁
Powered by an AlphaZero approach with our own improvements! 🚀 For the specific details, please check out our report.
##### Adapted and improved by us! 🌟 Our Github repo
""",
    unsafe_allow_html=True,
)
def restart() -> None:
    """
    Restart the game.

    Replaces the current room with a brand-new Room that keeps the same room
    id, and sets the remembered opponent model back to 'AlphaZero'.
    """
    current_id = session_state.ROOM.ROOM_ID
    fresh_room = Room(current_id)
    session_state.ROOM = fresh_room
    st.session_state['selected_oppo_model'] = 'AlphaZero'
# Fill the RESTART placeholder with a button that calls restart() on click.
RESTART.button(
    "Reset",
    on_click=restart,
    help="Clear the board as well as the scores",
)
# Draw the board
def gomoku():
    """
    Draw the board.
    Handle the main logic.

    NOTE(review): indentation was lost in this copy of the file. The helper
    definitions that follow, and the closing game_control() / update_info()
    calls, are presumably the body of this function -- confirm.
    """
    # Continue new round
def another_round() -> None:
    """
    Continue new round.

    Builds the next round's room from a deep copy of the current one: the
    board, the bot board and all engines are recreated, while scores
    (HISTORY), ROOM_ID, PLAYER and the timing list are carried over. The
    opponent engine is the one named by
    st.session_state['selected_oppo_model'].

    The new room is assembled completely before it replaces
    session_state.ROOM, so a failure while constructing an engine leaves the
    previous room untouched.

    NOTE(review): PLAYER is intentionally left unchanged here (the original
    assigned it to itself), so the first player is not swapped between
    rounds even though the button help text says so.
    """
    room = deepcopy(session_state.ROOM)

    # Same constructor arguments as Room.__init__, including `players`,
    # which the original call here omitted.
    room.BOARD = Board(
        width=_BOARD_SIZE,
        height=_BOARD_SIZE,
        n_in_row=5,
        players=[_BLACK, _WHITE],
    )
    room.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
    room.MCTS_dict = {
        'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
        'AlphaZero': alphazero(
            PolicyValueNet(
                _BOARD_SIZE,
                _BOARD_SIZE,
                'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth',
            ).policy_value_fn,
            c_puct=5,
            n_playout=100,
        ),
        'Gomoku Bot': Gomoku_bot(room.gomoku_bot_board, -1),
    }
    room.MCTS = room.MCTS_dict[st.session_state['selected_oppo_model']]
    room.last_mcts = room.MCTS

    room.TURN = room.PLAYER
    room.WINNER = _BLANK  # 0
    room.COORDINATE_1D = [BOS_TOKEN_ID]

    session_state.ROOM = room
# Room status sync
def sync_room() -> bool:
    """
    Reconcile the session's room with the copy held in server_state.ROOMS.

    Returns False when the room id is unknown to the server (the session
    falls back to a fresh "local" room) or when both copies carry the same
    TIME. Returns True after copying in whichever direction the newer TIME
    dictates.
    """
    room_id = session_state.ROOM.ROOM_ID

    if room_id not in server_state.ROOMS.keys():
        session_state.ROOM = Room("local")
        return False

    shared_stamp = server_state.ROOMS[room_id].TIME
    local_stamp = session_state.ROOM.TIME

    if shared_stamp == local_stamp:
        return False

    if shared_stamp < local_stamp:
        # Local copy is newer: publish it.
        # Only acquire the lock when writing to the server state.
        with server_state_lock["ROOMS"]:
            rooms = server_state.ROOMS
            rooms[room_id] = session_state.ROOM
            server_state.ROOMS = rooms
        return True

    # Shared copy is newer: adopt it.
    session_state.ROOM = server_state.ROOMS[room_id]
    return True
# Check if winner emerge from move
def check_win() -> int:
    """
    Use convolution to check if any player wins.

    Each player's stones are turned into a 0/1 mask and convolved with the
    four direction kernels; a response of 5 or more means five of that
    player's stones are aligned. Black is checked first, then White.

    The earlier version convolved the raw board and compared the extrema
    with 5 * _BLACK and 5 * _WHITE. That only works for a +1/-1 encoding;
    with the 1/2 encoding used here (see change_turn) White could never be
    detected, and mixed lines could add up to a false Black win.

    NOTE(review): assumes each kernel holds five ones along its direction,
    as the original "== 5" comparison implies -- confirm in const.py.

    Returns:
        _BLACK or _WHITE for the winner, _BLANK when nobody has five in a row.
    """
    board = np.asarray(session_state.ROOM.BOARD.board_map)
    kernels = (_VERTICAL, _HORIZONTAL, _DIAGONAL_UP_LEFT, _DIAGONAL_UP_RIGHT)

    for player in (_BLACK, _WHITE):
        stones = (board == player).astype(int)
        if any(
            np.max(convolve(stones, kernel, mode="same")) >= 5
            for kernel in kernels
        ):
            return player
    return _BLANK
# Triggers the board response on click
def handle_click(x, y):
    """
    Controls whether to pass on / continue current board / may start new round

    Clicks on an occupied cell are ignored. In a shared room where it is not
    this client's turn the room is re-synchronised instead. Otherwise, while
    no winner is recorded, the move is played for the side whose turn it is.
    """
    room = session_state.ROOM

    # Occupied cell: nothing to do.
    if room.BOARD.board_map[x][y] != _BLANK:
        return

    # Shared room and not our colour's turn: just refresh from the server.
    if (
        room.ROOM_ID in server_state.ROOMS.keys()
        and _ROOM_COLOR[session_state.OWNER]
        != server_state.ROOMS[room.ROOM_ID].TURN
    ):
        sync_room()
        return

    # A finished round accepts no further moves.
    if room.WINNER != _BLANK:
        return

    # normal play situation
    board = room.BOARD
    move = board.location_to_move((x, y))
    room.current_move = move
    board.do_move(move)

    # Mirror the move on the Gomoku Bot's own board.
    bot_row, bot_col = divmod(move, _BOARD_SIZE)
    room.MCTS_dict["Gomoku Bot"].board.put(bot_row, bot_col)

    board.board_map[x][y] = room.TURN
    room.COORDINATE_1D.append(x * _BOARD_SIZE + y)
    room.TURN = change_turn(room.TURN)

    win, winner = board.game_end()
    if win:
        room.WINNER = winner

    # Score tuple is (White wins, Black wins); adds 0 while nobody has won.
    white_wins, black_wins = room.HISTORY
    room.HISTORY = (
        white_wins + int(room.WINNER == _WHITE),
        black_wins + int(room.WINNER == _BLACK),
    )
    room.TIME = time.time()
def forbid_click(x, y):
    """Click callback for occupied cells: show an error naming the cell."""
    notice = "({}, {}) has been occupied!!)".format(x, y)
    st.error(notice, icon="🚨")
# Draw board
def draw_board(response: bool):
    """
    construct each buttons for all cells of the board

    Args:
        response: True while the round is live, so cells are rendered as
            clickable buttons and the AI answers on its turn. False renders
            the board as static text.

    On the human's turn (TURN == _BLACK) the board is drawn as buttons, with
    the assistant's five most probable moves annotated when USE_AIAID is on.
    On the AI's turn (TURN == _WHITE) the opponent engine picks a move, the
    move is applied, the board is redrawn, the turn flips and winner/score
    are updated. When the round is over the cells are written as plain
    symbols.

    Assistant hints always start out empty and are only computed while the
    game is running, so the button loop never touches an undefined
    suggestion list after a game-ending AI move. The game_end() result is
    read once per AI move instead of once per cell.
    """
    room = session_state.ROOM

    def _aid_suggestions():
        """Return {flat move index: probability} for the assistant's top five moves."""
        # Search on a copy so the assistant's stored tree is not modified.
        copy_mcts = deepcopy(room.AID_MCTS.mcts)
        _, acts_aid, probs_aid, _ = copy_mcts.get_move_probs(room.BOARD)
        ranked = sorted(
            zip(acts_aid, probs_aid), key=lambda pair: pair[1], reverse=True
        )
        return {int(act): prob for act, prob in ranked[:5]}

    def _render_buttons(latest, hints):
        """Draw one button per cell; `latest` is the (row, col) shown as the newest move."""
        played = set(room.COORDINATE_1D)
        for i, row in enumerate(room.BOARD.board_map):
            for j, cell in enumerate(row):
                flat = i * _BOARD_SIZE + j
                if flat in played:
                    if i == latest[0] and j == latest[1]:
                        label, callback = _PLAYER_SYMBOL[_NEW], handle_click
                    else:
                        # disable click for cells that are already taken
                        label, callback = _PLAYER_SYMBOL[cell], forbid_click
                elif flat in hints:
                    # free cell recommended by the assistant: show its probability
                    label = _PLAYER_SYMBOL[cell] + f"({round(hints[flat], 2)})"
                    callback = handle_click
                else:
                    # free cell available for human choices
                    label, callback = _PLAYER_SYMBOL[cell], handle_click
                BOARD_PLATE[i][j].button(
                    label,
                    key=f"{i}:{j}",
                    on_click=callback,
                    args=(i, j),
                )

    hints = {}
    if (
        session_state.USE_AIAID
        and room.WINNER == _BLANK
        and room.TURN == _BLACK
    ):
        hints = _aid_suggestions()

    if response and room.TURN == _BLACK:  # human turn
        print("Your turn")
        latest = (
            room.current_move // _BOARD_SIZE,
            room.current_move % _BOARD_SIZE,
        )
        _render_buttons(latest, hints)

    elif response and room.TURN == _WHITE:  # AI turn
        message.empty()
        with st.spinner('🔮✨ Waiting for AI response... ⏳🚀'):
            time.sleep(0.1)
            print("AI's turn")
            print("Below are current board under AI's view")
            # The Gomoku Bot engine is called without the board argument.
            if st.session_state['selected_oppo_model'] != 'Gomoku Bot':
                move, simul_time = room.MCTS.get_action(
                    room.BOARD, return_time=True
                )
            else:
                move, simul_time = room.MCTS.get_action(return_time=True)

        room.simula_time_list.append(simul_time)
        print("AI takes move: ", move)
        room.current_move = move
        gpt_i, gpt_j = move // _BOARD_SIZE, move % _BOARD_SIZE
        print("AI's move is located at ({}, {}) :".format(gpt_i, gpt_j))
        move = room.BOARD.location_to_move((gpt_i, gpt_j))
        print("Location to move: ", move)

        # MCTS BOARD
        room.BOARD.do_move(move)
        # Gomoku Bot BOARD
        room.MCTS_dict["Gomoku Bot"].board.put(
            move // _BOARD_SIZE, move % _BOARD_SIZE
        )
        room.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)

        win, winner = room.BOARD.game_end()

        # Suggestions for the human's reply are pointless once the game ended.
        hints = {}
        if session_state.USE_AIAID and not win:
            hints = _aid_suggestions()
        _render_buttons((gpt_i, gpt_j), hints)

        message.markdown(
            'AI agent has calculated its strategy, which takes {:.3e}s per simulation.'.format(
                simul_time),
            unsafe_allow_html=True
        )
        LOG.subheader("Logs")

        # change turn
        room.TURN = change_turn(room.TURN)
        if win:
            room.WINNER = winner
        # Score tuple is (White wins, Black wins); adds 0 while nobody has won.
        room.HISTORY = (
            room.HISTORY[0] + int(room.WINNER == _WHITE),
            room.HISTORY[1] + int(room.WINNER == _BLACK),
        )
        room.TIME = time.time()

    if not response or room.WINNER != _BLANK:
        if room.WINNER != _BLANK:
            print("Game over")
        # Replace the buttons with static symbols once the round is finished.
        for i, row in enumerate(room.BOARD.board_map):
            for j, cell in enumerate(row):
                BOARD_PLATE[i][j].write(_PLAYER_SYMBOL[cell])
# Game process control
def game_control():
    """
    Draw the board for the current state and offer the next-round button.

    The board is drawn interactively while no winner is recorded and
    statically otherwise. Afterwards (draw_board may have just recorded a
    winner) the "Play Next round!" button is shown when the round has a
    winner or no blank cell is left.

    The full-board test converts board_map to an array first: the original
    ``0 not in board_map`` is only an element-wise test for ndarrays and is
    always True for a list of row lists.
    NOTE(review): blank cells are taken to be _BLANK (0), matching the
    emptiness check in handle_click -- confirm.
    """
    round_over = session_state.ROOM.WINNER != _BLANK
    draw_board(not round_over)

    cells = np.asarray(session_state.ROOM.BOARD.board_map)
    board_full = not np.any(cells == _BLANK)
    if session_state.ROOM.WINNER != _BLANK or board_full:
        ANOTHER_ROUND.button(
            "Play Next round!",
            on_click=another_round,
            help="Clear board and swap first player",
        )
# Infos
def update_info() -> None:
    """
    Refresh the score metrics, the winner banner and the timing chart.

    HISTORY[0] is shown as the "Gomoku-Agent" score and HISTORY[1] as the
    "Black" score. When a winner is recorded, balloons are shown and the
    round banner announces the winning colour. The per-move simulation times
    collected in ROOM.simula_time_list are plotted as a line chart.
    """
    room = session_state.ROOM

    # Additional information
    SCORE_PLATE[0].metric("Gomoku-Agent", room.HISTORY[0])
    SCORE_PLATE[1].metric("Black", room.HISTORY[1])
    if room.WINNER != _BLANK:
        st.balloons()
        ROUND_INFO.write(
            f"#### **{_PLAYER_COLOR[room.WINNER]} WIN!**\n**Click buttons on the left for more plays.**"
        )

    # Two spacer lines above the chart. The literals had been split across
    # lines (unterminated strings); restored as HTML line breaks, which is
    # what unsafe_allow_html=True suggests -- TODO confirm against the original.
    st.markdown("<br>", unsafe_allow_html=True)
    st.markdown("<br>", unsafe_allow_html=True)

    # draw the plot for simulation time
    chart_data = pd.DataFrame(room.simula_time_list, columns=["Simulation Time"])
    st.line_chart(chart_data)
# Run the game step (draw the board, offer the next round), then refresh the
# scores, the winner banner and the timing chart.
game_control()
update_info()
# Script entry point.
if __name__ == "__main__":
    gomoku()