sjz committed on
Commit 9cefce7 (1 parent: aae2a37)

update code

Gomoku_MCTS/mcts_Gumbel_Alphazero.py ADDED
@@ -0,0 +1,390 @@
+ """
+ FileName: mcts_Gumbel_Alphazero.py
+ Author: Jiaxin Li
+ Create Date: 2023/11/21
+ Description: Implementation of Gumbel MCTS (sequential halving with Gumbel noise at the root)
+ Edit History:
+ - Debug: the dim of output: probs
+ """
+
+ import numpy as np
+ import copy
+ import time
+
+ from config.options import *
+ import sys
+ from config.utils import *
+
+
+ def softmax(x):
+     probs = np.exp(x - np.max(x))
+     probs /= np.sum(probs)
+     return probs
+
+
+ def _sigma_mano(y, Nb):
+     # Monotonically increasing transform of Q-values,
+     # sigma(q) = (c_visit + max_b N(b)) * c_scale * q, with c_visit = 50 and c_scale = 1.0.
+     return (50 + Nb) * 1.0 * y
+
+
+ class TreeNode(object):
+     """A node in the MCTS tree.
+
+     Each node keeps track of its own value Q, prior probability P, and
+     its visit-count-adjusted prior score u.
+     """
+
+     def __init__(self, parent, prior_p):
+         self._parent = parent
+         self._children = {}  # a map from action to TreeNode
+         self._n_visits = 0
+         self._Q = 0
+         self._u = 0
+         self._v = 0  # network value estimate, set when this node is evaluated
+         self._p = prior_p
+
+     def expand(self, action_priors):
+         """Expand tree by creating new children.
+         action_priors: a list of (action, prior probability) tuples
+         according to the policy function.
+         """
+         for action, prob in action_priors:
+             if action not in self._children:
+                 self._children[action] = TreeNode(self, prob)
+
+     def select(self, v_pi):
+         """Select the child that maximizes pi'(a) - N(a) / (1 + sum_b N(b)),
+         the non-root selection rule of Gumbel MCTS.
+         Return: a tuple of (action, next_node)
+         """
+         max_N_b = np.max(np.array([act_node[1]._n_visits for act_node in self._children.items()]))
+
+         pi_ = softmax(np.array([act_node[1].get_pi(v_pi, max_N_b)
+                                 for act_node in self._children.items()])).reshape(len(self._children), -1)
+
+         N_a = np.array([act_node[1]._n_visits / (1 + self._n_visits)
+                         for act_node in self._children.items()]).reshape(pi_.shape[0], -1)
+
+         max_index = np.argmax(pi_ - N_a)
+
+         return list(self._children.items())[max_index]
+
+     def update(self, leaf_value):
+         """Update node values from leaf evaluation.
+         leaf_value: the value of subtree evaluation from the current player's
+         perspective.
+         """
+         # Count visit.
+         self._n_visits += 1
+         # Update Q, a running average of values for all visits.
+         self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits
+
+     def update_recursive(self, leaf_value):
+         """Like a call to update(), but applied recursively for all ancestors."""
+         # If it is not root, this node's parent should be updated first.
+         if self._parent:
+             self._parent.update_recursive(-leaf_value)
+         self.update(leaf_value)
+
+     def get_pi(self, v_pi, max_N_b):
+         # "Completed" Q: an unvisited child falls back to the parent's network value.
+         if self._n_visits == 0:
+             Q_completed = v_pi
+         else:
+             Q_completed = self._Q
+         return self._p + _sigma_mano(Q_completed, max_N_b)
+
+     def get_value(self, c_puct):
+         """Calculate and return the value for this node.
+         It is a combination of leaf evaluations Q, and this node's prior
+         adjusted for its visit count, u.
+         c_puct: a number in (0, inf) controlling the relative impact of
+         value Q, and prior probability P, on this node's score.
+         """
+         self._u = (c_puct * self._p *
+                    np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
+         return self._Q + self._u
+
+     def is_leaf(self):
+         """Check if leaf node (i.e. no nodes below this have been expanded)."""
+         return self._children == {}
+
+     def is_root(self):
+         return self._parent is None
+
+
+ class Gumbel_MCTS(object):
+     """An implementation of Monte Carlo Tree Search with Gumbel root action selection."""
+
+     def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
+         """
+         policy_value_fn: a function that takes in a board state and outputs
+             a list of (action, probability) tuples and also a score in [-1, 1]
+             (i.e. the expected value of the end-game score from the current
+             player's perspective) for the current player.
+         c_puct: a number in (0, inf) that controls how quickly exploration
+             converges to the maximum-value policy. A higher value means
+             relying on the prior more.
+         """
+         self._root = TreeNode(None, 1.0)
+         self._policy = policy_value_fn
+         self._c_puct = c_puct
+         self._n_playout = n_playout
+
+     def Gumbel_playout(self, child_node, child_state):
+         """Run a single playout from a child of the root to a leaf, get a value
+         at the leaf and propagate it back through its parents.
+         State is modified in-place, so a copy must be provided.
+         This method uses the non-root selection rule.
+         """
+         node = child_node
+         state = child_state
+
+         while True:
+             if node.is_leaf():
+                 break
+             # Greedily select next move.
+             action, node = node.select(node._v)
+             state.do_move(action)
+
+         # Evaluate the leaf using a network which outputs a list of
+         # (action, probability) tuples p and also a score v in [-1, 1]
+         # for the current player.
+         action_probs, leaf_value = self._policy(state)
+         leaf_value = leaf_value.detach().numpy()[0][0]
+         node._v = leaf_value
+
+         # Check for end of game.
+         end, winner = state.game_end()
+         if not end:
+             node.expand(action_probs)
+         else:
+             # For an end state, return the "true" leaf value.
+             if winner == -1:  # tie
+                 leaf_value = 0.0
+             else:
+                 leaf_value = (
+                     1.0 if winner == state.get_current_player() else -1.0
+                 )
+
+         # Update value and visit count of nodes in this traversal.
+         node.update_recursive(-leaf_value)
+
+     def top_k(self, x, k):
+         # Indices of the k largest entries of x (unordered).
+         # Note: the pivot must be -k; partitioning at k does not isolate the top k.
+         return np.argpartition(x, -k)[..., -k:]
+
+     def sample_k(self, logits, k):
+         # Gumbel-top-k trick: adding Gumbel noise z = -log(-log(u)) to the logits
+         # and keeping the k largest entries samples k actions without replacement.
+         u = np.random.uniform(size=np.shape(logits))
+         z = -np.log(-np.log(u))
+         return self.top_k(logits + z, k), z
+
+     def get_move_probs(self, state, temp=1e-3, m_action=16):
+         """Run all playouts sequentially and return the available actions and
+         their corresponding probabilities.
+         state: the current game state
+         temp: temperature parameter in (0, 1] that controls the level of exploration
+         m_action: number of root actions initially sampled with Gumbel noise
+         """
+         start_time = time.time()
+
+         # Expand the root node.
+         act_probs, leaf_value = self._policy(state)
+         act_probs = list(act_probs)
+         leaf_value = leaf_value.detach().numpy()[0][0]
+
+         probs = [prob for act, prob in act_probs]
+         self._root.expand(act_probs)
+
+         n = self._n_playout
+         m = max(1, min(m_action, int(len(probs) / 2)))
+
+         # Sample m distinct actions with Gumbel noise, i.e. the top m of logits + g.
+         A_topm, g = self.sample_k(probs, m)
+
+         # Apply each sampled action to a copy of the state and keep the
+         # resulting child states in a list.
+         root_childs = list(self._root._children.items())
+
+         child_state_m = []
+         for i in range(m):
+             state_copy = copy.deepcopy(state)
+             action, node = root_childs[A_topm[i]]
+             state_copy.do_move(action)
+             child_state_m.append(state_copy)
+
+         # Per-action simulation budget for the first round
+         # (guard m == 1, where log(m) would be zero).
+         N = int(n / (np.log(m) * m)) if m > 1 else n
+
+         # Sequential halving with Gumbel.
+         while m >= 1:
+             # Simulate each of the m candidate actions N times.
+             for i in range(m):
+                 action_state = child_state_m[i]
+                 action, node = root_childs[A_topm[i]]
+                 for j in range(N):
+                     action_state_copy = copy.deepcopy(action_state)
+                     # Simulate the chosen action: descend to a leaf of this
+                     # subtree, predict v with the network, and back up.
+                     self.Gumbel_playout(node, action_state_copy)
+
+             # Halve the number of candidate actions each round.
+             m = m // 2
+
+             if m != 1:
+                 # Not the last round: double the per-action simulation budget.
+                 n = n - N
+                 N *= 2
+             else:
+                 # Last round: one action remains, spend the remaining budget.
+                 N = n
+
+             # Re-select the best half of the surviving candidates with
+             # g + logits + sigma(q_hat).
+             q_hat = np.array([action_node[1]._Q for action_node in self._root._children.items()])
+
+             # Sanity check: every surviving candidate has been visited.
+             assert np.sum(q_hat[A_topm] == 0) == 0
+
+             A_index = self.top_k(np.array(probs)[A_topm] + np.array(g)[A_topm] + q_hat[A_topm], m)
+             A_topm = np.array(A_topm)[A_index]
+             child_state_m = np.array(child_state_m)[A_index]
+
+         # Return the final policy, pi' = softmax(logits + sigma(completed Q)).
+         max_N_b = np.max(np.array([act_node[1]._n_visits for act_node in self._root._children.items()]))
+
+         final_act_probs = softmax(np.array([act_node[1].get_pi(leaf_value, max_N_b)
+                                             for act_node in self._root._children.items()]))
+         action = np.array([act_node[0] for act_node in self._root._children.items()])
+
+         need_time = time.time() - start_time
+         print(f" Gumbel Alphazero sum_time: {need_time}, total_simulation: {self._n_playout}")
+
+         return np.array(list(self._root._children.items()))[A_topm][0][0], action, final_act_probs, need_time
+
+     def update_with_move(self, last_move):
+         """Step forward in the tree, keeping everything we already know
+         about the subtree.
+         """
+         if last_move in self._root._children:
+             self._root = self._root._children[last_move]
+             self._root._parent = None
+         else:
+             self._root = TreeNode(None, 1.0)
+
+     def __str__(self):
+         return "MCTS"
+
+
+ class Gumbel_MCTSPlayer(object):
+     """AI player based on Gumbel MCTS"""
+
+     def __init__(self, policy_value_function,
+                  c_puct=5, n_playout=2000, is_selfplay=0, m_action=16):
+         self.mcts = Gumbel_MCTS(policy_value_function, c_puct, n_playout)
+         self._is_selfplay = is_selfplay
+         self.m_action = m_action
+
+     def set_player_ind(self, p):
+         self.player = p
+
+     def reset_player(self):
+         self.mcts.update_with_move(-1)
+
+     def get_action(self, board, temp=1e-3, return_prob=0, return_time=False):
+         sensible_moves = board.availables
+         # the pi vector returned by MCTS, as in the AlphaGo Zero paper
+         move_probs = np.zeros(board.width * board.height)
+
+         if len(sensible_moves) > 0:
+             # Choose the move with sequential halving with Gumbel inside the
+             # search tree, and also return the resulting policy.
+             move, acts, probs, simul_mean_time = self.mcts.get_move_probs(board, temp, self.m_action)
+
+             # Reset the search tree.
+             self.mcts.update_with_move(-1)
+
+             move_probs[list(acts)] = probs
+
+             if return_time:
+                 if return_prob:
+                     return move, move_probs, simul_mean_time
+                 else:
+                     return move, simul_mean_time
+             else:
+                 if return_prob:
+                     return move, move_probs
+                 else:
+                     return move
+         else:
+             print("WARNING: the board is full")
+
+     def __str__(self):
+         return "MCTS {}".format(self.player)
Gomoku_MCTS/mcts_alphaZero.py CHANGED
@@ -156,35 +156,39 @@ class MCTS(object):
         start_time_averge = 0
 
         ### test multi-thread
-        lock = threading.Lock()
-        with ThreadPoolExecutor(max_workers=4) as executor:
-            for n in range(self._n_playout):
-                start_time = time.time()
-
-                state_copy = copy.deepcopy(state)
-                executor.submit(self._playout, state_copy, lock)
-                start_time_averge += (time.time() - start_time)
+        # lock = threading.Lock()
+        # with ThreadPoolExecutor(max_workers=4) as executor:
+        #     for n in range(self._n_playout):
+        #         start_time = time.time()
+
+        #         state_copy = copy.deepcopy(state)
+        #         executor.submit(self._playout, state_copy, lock)
+        #         start_time_averge += (time.time() - start_time)
         ### end test multi-thread
 
-        # t = time.time()
-        # for n in range(self._n_playout):
-        #     start_time = time.time()
+        t = time.time()
+        for n in range(self._n_playout):
+            start_time = time.time()
 
-        #     state_copy = copy.deepcopy(state)
-        #     self._playout(state_copy)
-        #     start_time_averge += (time.time() - start_time)
+            state_copy = copy.deepcopy(state)
+            self._playout(state_copy)
+            start_time_averge += (time.time() - start_time)
+        total_time = time.time() - t
         # print('!!time!!:', time.time() - t)
 
-        # print(f" My MCTS sum_time: {start_time_averge }, total_simulation: {self._n_playout}")
+        print(f" My MCTS sum_time: {total_time}, total_simulation: {self._n_playout}")
 
 
         # calc the move probabilities based on visit counts at the root node
         act_visits = [(act, node._n_visits)
                       for act, node in self._root._children.items()]
+
         acts, visits = zip(*act_visits)
+
         act_probs = softmax(1.0/temp * np.log(np.array(visits) + 1e-10))
 
-        return acts, act_probs
+
+        return 0, acts, act_probs, total_time
 
     def update_with_move(self, last_move):
         """Step forward in the tree, keeping everything we already know
@@ -214,12 +218,12 @@ class MCTSPlayer(object):
     def reset_player(self):
        self.mcts.update_with_move(-1)
 
-    def get_action(self, board, temp=1e-3, return_prob=0):
+    def get_action(self, board, temp=1e-3, return_prob=0, return_time=False):
        sensible_moves = board.availables
        # the pi vector returned by MCTS as in the alphaGo Zero paper
        move_probs = np.zeros(board.width*board.height)
        if len(sensible_moves) > 0:
-            acts, probs = self.mcts.get_move_probs(board, temp)
+            _, acts, probs, simul_mean_time = self.mcts.get_move_probs(board, temp)
            move_probs[list(acts)] = probs
            if self._is_selfplay:
                # add Dirichlet Noise for exploration (needed for
@@ -238,11 +242,22 @@ class MCTSPlayer(object):
                self.mcts.update_with_move(-1)
            # location = board.move_to_location(move)
            # print("AI move: %d,%d\n" % (location[0], location[1]))
 
-            if return_prob:
-                return move, move_probs
+            if return_time:
+                if return_prob:
+                    return move, move_probs, simul_mean_time
+                else:
+                    return move, simul_mean_time
            else:
-                return move
+                if return_prob:
+                    return move, move_probs
+                else:
+                    return move
        else:
            print("WARNING: the board is full")
 
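With these changes, MCTS.get_move_probs returns a 4-tuple whose leading element is a placeholder move (0 for the plain AlphaZero player, so its shape matches the Gumbel player's (move, acts, probs, time)), and MCTSPlayer.get_action accepts a return_time flag. A hypothetical call site, assuming a board and policy_value_fn supplied by the surrounding project rather than this diff:

# Hypothetical usage; `policy_value_fn` and `board` come from the surrounding
# project (PolicyValueNet and Board) and are assumed here.
player = MCTSPlayer(policy_value_fn, c_puct=5, n_playout=400)

move = player.get_action(board)  # default behaviour is unchanged

# New: also retrieve the probability vector and the measured search time.
move, move_probs, simul_time = player.get_action(
    board, return_prob=1, return_time=True
)
print(f"AI chose {move}; search took {simul_time:.3f}s in total")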
pages/Player_VS_AI.py CHANGED
@@ -1,7 +1,7 @@
 """
 FileName: app.py
 Author: Benhao Huang
- Create Date: 2023/11/18
+ Create Date: 2023/11/19
 Description: this file is used to display our project and add visualization elements to the game, using Streamlit
 """
 
@@ -46,6 +46,7 @@ class Room:
         self.WINNER = _BLANK
         self.TIME = time.time()
         self.MCTS = MCTSpure(c_puct=5, n_playout=10)
+        self.MCTS = alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE).policy_value_fn, c_puct=5, n_playout=10)
         self.COORDINATE_1D = [_BOARD_SIZE_1D + 1]
         self.current_move = -1
         self.simula_time_list = []
@@ -60,6 +61,8 @@ if "ROOM" not in session_state:
     session_state.ROOM = Room("local")
 if "OWNER" not in session_state:
     session_state.OWNER = False
+if "USE_AIAID" not in session_state:
+    session_state.USE_AIAID = False
 
 # Check server health
 if "ROOMS" not in server_state:
@@ -88,6 +91,7 @@ MULTIPLAYER_TAG = st.sidebar.empty()
 with st.sidebar.container():
     ANOTHER_ROUND = st.empty()
     RESTART = st.empty()
+    AIAID = st.empty()
     EXIT = st.empty()
 GAME_INFO = st.sidebar.container()
 message = st.empty()
@@ -213,6 +217,14 @@ def gomoku():
             winner = _BLANK
         return winner
 
+    def ai_aid() -> None:
+        """
+        Toggle AI Aid.
+        """
+        session_state.USE_AIAID = not session_state.USE_AIAID
+        print('Use AI Aid: ', session_state.USE_AIAID)
+        draw_board(False)
+
     # Triggers the board response on click
     def handle_click(x, y):
         """
@@ -257,7 +269,11 @@
     # Draw board
     def draw_board(response: bool):
         """construct each buttons for all cells of the board"""
-
+        if session_state.USE_AIAID:
+            _, acts, probs, simul_mean_time = session_state.ROOM.MCTS.mcts.get_move_probs(session_state.ROOM.BOARD)
+            sorted_acts_probs = sorted(zip(acts, probs), key=lambda x: x[1], reverse=True)
+            top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
+            top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
        if response and session_state.ROOM.TURN == _BLACK:  # human turn
            print("Your turn")
            # construction of clickable buttons
@@ -276,13 +292,23 @@
                        on_click=forbid_click
                    )
                else:
-                    # enable click for other cells available for human choices
-                    BOARD_PLATE[i][j].button(
-                        _PLAYER_SYMBOL[cell],
-                        key=f"{i}:{j}",
-                        on_click=handle_click,
-                        args=(i, j),
-                    )
+                    if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts:
+                        # annotate the AI's top-five suggestions with their probability
+                        prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
+                        BOARD_PLATE[i][j].button(
+                            _PLAYER_SYMBOL[cell] + f"({round(prob, 2)})",
+                            key=f"{i}:{j}",
+                            on_click=handle_click,
+                            args=(i, j),
+                        )
+                    else:
+                        # enable click for other cells available for human choices
+                        BOARD_PLATE[i][j].button(
+                            _PLAYER_SYMBOL[cell],
+                            key=f"{i}:{j}",
+                            on_click=handle_click,
+                            args=(i, j),
+                        )
 
 
        elif response and session_state.ROOM.TURN == _WHITE:  # AI turn
@@ -292,7 +318,7 @@
            print("AI's turn")
            print("Below are current board under AI's view")
            print(session_state.ROOM.BOARD.board_map)
-            move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD)
+            move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
            session_state.ROOM.simula_time_list.append(simul_time)
            print("AI takes move: ", move)
            session_state.ROOM.current_move = move
@@ -321,13 +347,24 @@
                        on_click=forbid_click
                    )
                else:
-                    # enable click for other cells available for human choices
-                    BOARD_PLATE[i][j].button(
-                        _PLAYER_SYMBOL[cell],
-                        key=f"{i}:{j}",
-                        on_click=handle_click,
-                        args=(i, j),
-                    )
+                    if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts:
+                        # annotate the AI's top-five suggestions with their probability
+                        prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
+                        BOARD_PLATE[i][j].button(
+                            _PLAYER_SYMBOL[cell] + f"({round(prob, 2)})",
+                            key=f"{i}:{j}",
+                            on_click=handle_click,
+                            args=(i, j),
+                        )
+                    else:
+                        # enable click for other cells available for human choices
+                        BOARD_PLATE[i][j].button(
+                            _PLAYER_SYMBOL[cell],
+                            key=f"{i}:{j}",
+                            on_click=handle_click,
+                            args=(i, j),
+                        )
+
 
        message.markdown(
            'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
@@ -355,10 +392,17 @@
        print("Game over")
        for i, row in enumerate(session_state.ROOM.BOARD.board_map):
            for j, cell in enumerate(row):
-                BOARD_PLATE[i][j].write(
-                    _PLAYER_SYMBOL[cell],
-                    key=f"{i}:{j}",
-                )
+                if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts:
+                    prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
+                    BOARD_PLATE[i][j].write(
+                        _PLAYER_SYMBOL[cell] + f"({round(prob, 2)})",
+                        key=f"{i}:{j}",
+                    )
+                else:
+                    BOARD_PLATE[i][j].write(
+                        _PLAYER_SYMBOL[cell],
+                        key=f"{i}:{j}",
+                    )
 
    # Game process control
    def game_control():
@@ -401,6 +445,11 @@
        st.line_chart(chart_data)
 
    # The main game loop
+    AIAID.button(
+        "Use AI Aid",
+        on_click=ai_aid,
+        help="Use AI Aid to help you make moves",
+    )
    game_control()
    update_info()
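The new AI-aid overlay calls get_move_probs on the current board, sorts the returned policy, and labels the five most probable cells with their probabilities. The ranking step in isolation looks like the sketch below; the acts/probs values are invented, and _BOARD_SIZE = 15 is an assumption about the page's board width.

# Standalone sketch of the AI-aid ranking; in the page, acts and probs come
# from session_state.ROOM.MCTS.mcts.get_move_probs(session_state.ROOM.BOARD).
_BOARD_SIZE = 15  # assumed board width

acts = [12, 40, 7, 88, 23, 56]                # flattened board indices
probs = [0.05, 0.30, 0.02, 0.25, 0.18, 0.20]  # policy mass per action

sorted_acts_probs = sorted(zip(acts, probs), key=lambda x: x[1], reverse=True)
top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]

for act, prob in zip(top_five_acts, top_five_probs):
    row, col = divmod(act, _BOARD_SIZE)
    print(f"hint: cell ({row}, {col}) labelled ({round(prob, 2)})")

Board cells are keyed by i * _BOARD_SIZE + j in the diff, which is why the flattened index decodes with divmod.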