"""
FileName: mcts_Gumbel_Alphazero.py
Author: Jiaxin Li
Create Date: 2023/11/21
Description: The implement of Gumbel MCST
Edit History:
Debug: the dim of output: probs
"""
import copy
import sys
import time

import numpy as np

from .config.options import *
from .config.utils import *
def softmax(x):
probs = np.exp(x - np.max(x))
probs /= np.sum(probs)
return probs
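# This appears to implement the monotonic transform from the Gumbel AlphaZero paper
# ("Policy Improvement by Planning with Gumbel"):
#     sigma(q) = (c_visit + max_b N(b)) * c_scale * q,
# here with c_visit = 50 and c_scale = 1.0.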
def _sigma_mano(y, Nb):
return (50 + Nb) * 1.0 * y
class TreeNode(object):
"""A node in the MCTS tree.
Each node keeps track of its own value Q, prior probability P, and
its visit-count-adjusted prior score u.
"""
def __init__(self, parent, prior_p):
self._parent = parent
self._children = {} # a map from action to TreeNode
self._n_visits = 0
self._Q = 0
self._u = 0
self._v = 0
self._p = prior_p
def expand(self, action_priors):
"""Expand tree by creating new children.
action_priors: a list of tuples of actions and their prior probability
according to the policy function.
"""
for action, prob in action_priors:
if action not in self._children:
self._children[action] = TreeNode(self, prob)
def select(self, v_pi):
"""Select action among children that gives maximum
(pi'(a) - N(a) \ (1 + \sum_b N(b)))
Return: A tuple of (action, next_node)
"""
# if opts.split == "train":
# v_pi = v_pi.detach().numpy()
# print(v_pi)
max_N_b = np.max(np.array([act_node[1]._n_visits for act_node in self._children.items()]))
if opts.split == "train":
pi_ = softmax(np.array([act_node[1].get_pi(v_pi, max_N_b) for act_node in self._children.items()])).reshape(
len(list(self._children.items())), -1)
else:
pi_ = softmax(np.array([act_node[1].get_pi(v_pi, max_N_b) for act_node in self._children.items()])).reshape(
len(list(self._children.items())), -1)
# print(pi_.shape)
N_a = np.array([act_node[1]._n_visits / (1 + self._n_visits) for act_node in self._children.items()]).reshape(
pi_.shape[0], -1)
# print(N_a.shape)
max_index = np.argmax(pi_ - N_a)
# print((pi_ - N_a).shape)
return list(self._children.items())[max_index]
def update(self, leaf_value):
"""Update node values from leaf evaluation.
leaf_value: the value of subtree evaluation from the current player's
perspective.
"""
# Count visit.
self._n_visits += 1
# Update Q, a running average of values for all visits.
if opts.split == "train":
self._Q = self._Q + (1.0 * (leaf_value - self._Q) / self._n_visits)
else:
self._Q += (1.0 * (leaf_value - self._Q) / self._n_visits)
def update_recursive(self, leaf_value):
"""Like a call to update(), but applied recursively for all ancestors.
"""
# If it is not root, this node's parent should be updated first.
if self._parent:
self._parent.update_recursive(-leaf_value)
self.update(leaf_value)
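    # Score used by the improved policy: logits + sigma(completed Q). A child that has
    # never been visited falls back to the parent's value estimate v_pi as its completed Q.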
def get_pi(self, v_pi, max_N_b):
if self._n_visits == 0:
Q_completed = v_pi
else:
Q_completed = self._Q
return self._p + _sigma_mano(Q_completed, max_N_b)
def get_value(self, c_puct):
"""Calculate and return the value for this node.
It is a combination of leaf evaluations Q, and this node's prior
adjusted for its visit count, u.
c_puct: a number in (0, inf) controlling the relative impact of
value Q, and prior probability P, on this node's score.
"""
        self._u = (c_puct * self._p *
                   np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
return self._Q + self._u
def is_leaf(self):
"""Check if leaf node (i.e. no nodes below this have been expanded)."""
return self._children == {}
def is_root(self):
return self._parent is None
class Gumbel_MCTS(object):
"""An implementation of Monte Carlo Tree Search."""
def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
"""
policy_value_fn: a function that takes in a board state and outputs
a list of (action, probability) tuples and also a score in [-1, 1]
(i.e. the expected value of the end game score from the current
player's perspective) for the current player.
c_puct: a number in (0, inf) that controls how quickly exploration
converges to the maximum-value policy. A higher value means
relying on the prior more.
"""
self._root = TreeNode(None, 1.0)
self._policy = policy_value_fn
self._c_puct = c_puct
self._n_playout = n_playout
def Gumbel_playout(self, child_node, child_state):
"""Run a single playout from the child of the root to the leaf, getting a value at
the leaf and propagating it back through its parents.
State is modified in-place, so a copy must be provided.
This mothod of select is a non-root selet.
"""
node = child_node
state = child_state
        while not node.is_leaf():
# Greedily select next move.
action, node = node.select(node._v)
state.do_move(action)
# Evaluate the leaf using a network which outputs a list of
# (action, probability) tuples p and also a score v in [-1, 1]
# for the current player.
action_probs, leaf_value = self._policy(state)
# leaf_value = leaf_value.detach().numpy()[0][0]
leaf_value = leaf_value.detach().numpy()
node._v = leaf_value
# Check for end of game.
end, winner = state.game_end()
if not end:
node.expand(action_probs)
else:
            # For an end state, return the "true" leaf_value.
if winner == -1: # tie
leaf_value = 0.0
else:
leaf_value = (
1.0 if winner == state.get_current_player() else -1.0
)
# Update value and visit count of nodes in this traversal.
node.update_recursive(-leaf_value)
    def top_k(self, x, k):
        """Return the indices of the k largest entries of x (unordered)."""
        return np.argpartition(x, -k)[..., -k:]
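    # Gumbel-Top-k trick: adding independent Gumbel(0, 1) noise z = -log(-log(u)) to the
    # logits and taking the k largest entries draws k items without replacement from the
    # categorical distribution softmax(logits).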
def sample_k(self, logits, k):
u = np.random.uniform(size=np.shape(logits))
z = -np.log(-np.log(u))
return self.top_k(logits + z, k), z
def get_move_probs(self, state, temp=1e-3, m_action=16):
"""Run all playouts sequentially and return the available actions and
their corresponding probabilities.
state: the current game state
temp: temperature parameter in (0, 1] controls the level of exploration
"""
        # TODO: this still needs revision (1).
        # The logits are tentatively taken to be the prior probabilities p.
        start_time = time.time()
        # Expand the root node.
act_probs, leaf_value = self._policy(state)
act_probs = list(act_probs)
# leaf_value = leaf_value.detach().numpy()[0][0]
leaf_value = leaf_value.detach().numpy()
# print(list(act_probs))
porbs = [prob for act, prob in (act_probs)]
self._root.expand(act_probs)
n = self._n_playout
m = min(m_action, int(len(porbs) / 2))
        # Sample Gumbel noise and pick the top-m actions without replacement,
        # following the selection rule argmax (logits + g).
        A_topm, g = self.sample_k(porbs, m)
        # For each sampled action, apply it to a copy of the state and keep the
        # resulting child states in a list.
root_childs = list(self._root._children.items())
child_state_m = []
for i in range(m):
state_copy = copy.deepcopy(state)
action, node = root_childs[A_topm[i]]
state_copy.do_move(action)
child_state_m.append(state_copy)
print(porbs)
print("depend on:", np.array(porbs)[A_topm])
print(f"A_topm_{m}", A_topm)
print("m ", m)
if m > 1:
            # Number of simulations per selected action in each round.
N = int(n / (np.log(m) * m))
else:
N = n
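        # For example, with n = 10000 total simulations (the default n_playout) and
        # m = 16 candidate actions, the first round uses
        # N = int(10000 / (ln(16) * 16)) = 225 simulations per action.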
        # Sequential halving with Gumbel.
while m >= 1:
            # Simulate each of the currently selected actions.
for i in range(m):
action_state = child_state_m[i]
action, node = root_childs[A_topm[i]]
for j in range(N):
action_state_copy = copy.deepcopy(action_state)
                    # Simulate the selected action: descend to a leaf of this subtree,
                    # evaluate it with the network, and back up the value.
self.Gumbel_playout(node, action_state_copy)
            # Halve the number of candidate actions for the next round.
            m = m // 2
            if m != 1:
                # Not the last round: double the per-action simulation budget.
                n = n - N
                N *= 2
            else:
                # Last round: a single action remains, so spend all remaining simulations on it.
                N = n
            # Re-select the best half of the previous candidates without replacement,
            # scored by g + logits + sigma(q_hat).
# print([action_node[1]._Q for action_node in self._root._children.items() ])
q_hat = np.array([action_node[1]._Q for action_node in self._root._children.items()])
assert (np.sum(q_hat[A_topm] == 0) == 0)
print("depend on:", np.array(porbs)[A_topm] + np.array(g)[A_topm] + q_hat[A_topm])
print(f"A_topm_{m}", A_topm)
A_index = self.top_k(np.array(porbs)[A_topm] + np.array(g)[A_topm] + q_hat[A_topm], m)
A_topm = np.array(A_topm)[A_index]
child_state_m = np.array(child_state_m)[A_index]
        # Finally, return the improved policy pi' = softmax(logits + sigma(completed Q)).
max_N_b = np.max(np.array([act_node[1]._n_visits for act_node in self._root._children.items()]))
final_act_probs = softmax(
np.array([act_node[1].get_pi(leaf_value, max_N_b) for act_node in self._root._children.items()]))
action = (np.array([act_node[0] for act_node in self._root._children.items()]))
print("final_act_prbs", final_act_probs)
print("move :", action)
print("final_action", np.array(list(self._root._children.items()))[A_topm][0][0])
print("argmax_prob", np.argmax(final_act_probs))
need_time = time.time() - start_time
print(f" Gumbel Alphazero sum_time: {need_time}, total_simulation: {self._n_playout}")
return np.array(list(self._root._children.items()))[A_topm][0][0], action, final_act_probs, need_time
def update_with_move(self, last_move):
"""Step forward in the tree, keeping everything we already know
about the subtree.
"""
if last_move in self._root._children:
self._root = self._root._children[last_move]
self._root._parent = None
else:
self._root = TreeNode(None, 1.0)
def __str__(self):
return "MCTS"
class Gumbel_MCTSPlayer(object):
"""AI player based on MCTS"""
def __init__(self, policy_value_function,
c_puct=5, n_playout=2000, is_selfplay=0, m_action=16):
self.mcts = Gumbel_MCTS(policy_value_function, c_puct, n_playout)
self._is_selfplay = is_selfplay
self.m_action = m_action
def set_player_ind(self, p):
self.player = p
def reset_player(self):
self.mcts.update_with_move(-1)
def get_action(self, board, temp=1e-3, return_prob=0, return_time=False):
sensible_moves = board.availables
# the pi vector returned by MCTS as in the alphaGo Zero paper
move_probs = np.zeros(board.width * board.height)
if len(sensible_moves) > 0:
            # Select an action with sequential halving with Gumbel in the search tree
            # and return the corresponding improved policy.
            move, acts, probs, simul_mean_time = self.mcts.get_move_probs(board, temp, self.m_action)
            # Reset the search tree.
            self.mcts.update_with_move(-1)
move_probs[list(acts)] = probs
move_probs = np.zeros(move_probs.shape[0])
move_probs[move] = 1
print("final prob:", move_probs)
print("arg_max:", np.argmax(move_probs))
print("max", np.max(move_probs))
print("move", move)
            # With training, move_probs should end up with one entry close to 1,
            # i.e. an (almost) deterministic policy.
            # Note: this policy can disagree with the move returned by MCTS; the
            # discrepancy is suspected to come from how the policy distribution is computed.
if return_time:
if return_prob:
return move, move_probs, simul_mean_time
else:
return move, simul_mean_time
else:
if return_prob:
return move, move_probs
else:
return move
else:
print("WARNING: the board is full")
def __str__(self):
return "MCTS {}".format(self.player)