# -*- coding: utf-8 -*-
"""
Monte Carlo Tree Search in AlphaGo Zero style, which uses a policy-value
network to guide the tree search and evaluate the leaf nodes

@author: Junxiao Song
"""

import numpy as np
import copy
import time
from concurrent.futures import ThreadPoolExecutor
import threading


def softmax(x):
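    # Shift by the max before exponentiating for numerical stability.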
    probs = np.exp(x - np.max(x))
    probs /= np.sum(probs)
    return probs


class TreeNode(object):
    """A node in the MCTS tree.

    Each node keeps track of its own value Q, prior probability P, and
    its visit-count-adjusted prior score u.
    """

    def __init__(self, parent, prior_p):
        self._parent = parent
        self._children = {}  # a map from action to TreeNode
        self._n_visits = 0
        self._Q = 0
        self._u = 0
        self._P = prior_p

    def expand(self, action_priors):
        """Expand tree by creating new children.
        action_priors: a list of tuples of actions and their prior probability
            according to the policy function.
        """
        for action, prob in action_priors:
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)

    def select(self, c_puct):
        """Select action among children that gives maximum action value Q
        plus bonus u(P).
        Return: A tuple of (action, next_node)
        """
        return max(self._children.items(),
                   key=lambda act_node: act_node[1].get_value(c_puct))

    def update(self, leaf_value):
        """Update node values from leaf evaluation.
        leaf_value: the value of subtree evaluation from the current player's
            perspective.
        """
        # Count visit.
        self._n_visits += 1
        # Update Q, a running average of values for all visits.
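        # Incremental mean: Q_new = Q_old + (leaf_value - Q_old) / n_visits.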
        self._Q += 1.0*(leaf_value - self._Q) / self._n_visits

    def update_recursive(self, leaf_value):
        """Like a call to update(), but applied recursively for all ancestors.
        """
        # If it is not root, this node's parent should be updated first.
        if self._parent:
            self._parent.update_recursive(-leaf_value)
        self.update(leaf_value)

    def get_value(self, c_puct):
        """Calculate and return the value for this node.
        It is a combination of leaf evaluations Q, and this node's prior
        adjusted for its visit count, u.
        c_puct: a number in (0, inf) controlling the relative impact of
            value Q, and prior probability P, on this node's score.
        """
        self._u = (c_puct * self._P *
                   np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
        return self._Q + self._u

    def is_leaf(self):
        """Check if leaf node (i.e. no nodes below this have been expanded)."""
        return self._children == {}

    def is_root(self):
        return self._parent is None


class MCTS(object):
    """An implementation of Monte Carlo Tree Search."""

    def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
        """
        policy_value_fn: a function that takes in a board state and outputs
            a list of (action, probability) tuples and also a score in [-1, 1]
            (i.e. the expected value of the end game score from the current
            player's perspective) for the current player.
        c_puct: a number in (0, inf) that controls how quickly exploration
            converges to the maximum-value policy. A higher value means
            relying on the prior more.
        n_playout: the number of playouts to run for each move.
        """
        self._root = TreeNode(None, 1.0)
        self._policy = policy_value_fn
        self._c_puct = c_puct
        self._n_playout = n_playout

    def _playout(self, state, lock=None):
        """Run a single playout from the root to the leaf, getting a value at
        the leaf and propagating it back through its parents.
        State is modified in-place, so a copy must be provided.
        """
        node = self._root
        if lock is not None:
            lock.acquire()
        while True:
            if node.is_leaf():
                break
            # Greedily select next move.
            action, node = node.select(self._c_puct)
            state.do_move(action)
        if lock is not None:
            lock.release()
        # Evaluate the leaf using a network which outputs a list of
        # (action, probability) tuples p and also a score v in [-1, 1]
        # for the current player.
        action_probs, leaf_value = self._policy(state)
        # Check for end of game.
        end, winner = state.game_end()
        if lock is not None:
            lock.acquire()
        if not end:
            node.expand(action_probs)
        else:
            # For the end state, return the "true" leaf_value.
            if winner == -1:  # tie
                leaf_value = 0.0
            else:
                leaf_value = (
                    1.0 if winner == state.get_current_player() else -1.0
                )
        # Update value and visit count of nodes in this traversal.
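        # leaf_value is from the perspective of the player to move at the leaf
        # state; each node stores the value for the player who just moved into
        # it, hence the sign flip here.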
        node.update_recursive(-leaf_value)
        if lock is not None:
            lock.release()

    def get_move_probs(self, state, temp=1e-3):
        """Run all playouts sequentially and return the available actions,
        their corresponding probabilities, and the total search time.
        state: the current game state
        temp: temperature parameter in (0, 1] that controls the level of
            exploration
        """
        start_time_average = 0
        ### test multi-thread
        # lock = threading.Lock()
        # with ThreadPoolExecutor(max_workers=4) as executor:
        #     for n in range(self._n_playout):
        #         start_time = time.time()
        #         state_copy = copy.deepcopy(state)
        #         executor.submit(self._playout, state_copy, lock)
        #         start_time_average += (time.time() - start_time)
        ### end test multi-thread
        t = time.time()
        for n in range(self._n_playout):
            start_time = time.time()
            state_copy = copy.deepcopy(state)
            self._playout(state_copy)
            start_time_average += (time.time() - start_time)
        total_time = time.time() - t
        # print('!!time!!:', time.time() - t)
        print(f" My MCTS sum_time: {total_time}, total_simulation: {self._n_playout}")
        # calc the move probabilities based on visit counts at the root node
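        # pi(a) is proportional to N(a)^(1/temp); it is computed as a softmax
        # of (1/temp) * log(N(a)) for numerical stability.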
        act_visits = [(act, node._n_visits)
                      for act, node in self._root._children.items()]
        acts, visits = zip(*act_visits)
        act_probs = softmax(1.0/temp * np.log(np.array(visits) + 1e-10))
        return 0, acts, act_probs, total_time

    def update_with_move(self, last_move):
        """Step forward in the tree, keeping everything we already know
        about the subtree.
        """
        if last_move in self._root._children:
            self._root = self._root._children[last_move]
            self._root._parent = None
        else:
            self._root = TreeNode(None, 1.0)

    def __str__(self):
        return "MCTS"


class MCTSPlayer(object):
    """AI player based on MCTS"""

    def __init__(self, policy_value_function,
                 c_puct=5, n_playout=2000, is_selfplay=0):
        self.mcts = MCTS(policy_value_function, c_puct, n_playout)
        self._is_selfplay = is_selfplay

    def set_player_ind(self, p):
        self.player = p

    def reset_player(self):
        self.mcts.update_with_move(-1)

    def get_action(self, board, temp=1e-3, return_prob=0, return_time=False):
        sensible_moves = board.availables
        # the pi vector returned by MCTS as in the AlphaGo Zero paper
        move_probs = np.zeros(board.width*board.height)
        if len(sensible_moves) > 0:
            _, acts, probs, simul_mean_time = self.mcts.get_move_probs(board, temp)
            move_probs[list(acts)] = probs
            if self._is_selfplay:
                # add Dirichlet noise for exploration (needed for
                # self-play training)
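                # The move is sampled from 0.75*pi + 0.25*Dirichlet(0.3).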
                move = np.random.choice(
                    acts,
                    p=0.75*probs + 0.25*np.random.dirichlet(0.3*np.ones(len(probs)))
                )
                # update the root node and reuse the search tree
                self.mcts.update_with_move(move)
            else:
                # with the default temp=1e-3, it is almost equivalent
                # to choosing the move with the highest prob
                move = np.random.choice(acts, p=probs)
                # reset the root node
                self.mcts.update_with_move(-1)
                # location = board.move_to_location(move)
                # print("AI move: %d,%d\n" % (location[0], location[1]))
            if return_time:
                if return_prob:
                    return move, move_probs, simul_mean_time
                else:
                    return move, simul_mean_time
            else:
                if return_prob:
                    return move, move_probs
                else:
                    return move
        else:
            print("WARNING: the board is full")

    def __str__(self):
        return "MCTS {}".format(self.player)