HuskyDoge committed on
Commit
beb9e09
1 Parent(s): 7d23b62
Gomoku_Bot/eval.py CHANGED
@@ -460,11 +460,9 @@ class Evaluate:
460
  for point in fives:
461
  x = point // self.size
462
  y = point % self.size
463
- model_train_matrix[x][y] = max(FIVE, model_train_matrix[x][y])
464
  for point in block_fives:
465
  x = point // self.size
466
  y = point % self.size
467
- model_train_matrix[x][y] = max(BLOCK_FIVE, model_train_matrix[x][y])
468
 
469
  return set(list(fives) + list(block_fives)), model_train_matrix
470
 
@@ -474,12 +472,10 @@ class Evaluate:
474
  for point in fours:
475
  x = point // self.size
476
  y = point % self.size
477
- model_train_matrix[x][y] = max(FOUR, model_train_matrix[x][y])
478
 
479
  for point in block_fours:
480
  x = point // self.size
481
  y = point % self.size
482
- model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
483
 
484
  return set(list(fours) + list(block_fours)), model_train_matrix
485
 
@@ -488,12 +484,10 @@ class Evaluate:
488
  for point in four_fours:
489
  x = point // self.size
490
  y = point % self.size
491
- model_train_matrix[x][y] = max(FOUR_FOUR, model_train_matrix[x][y])
492
 
493
  for point in block_fours:
494
  x = point // self.size
495
  y = point % self.size
496
- model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
497
 
498
  return set(list(four_fours) + list(block_fours)), model_train_matrix
499
 
@@ -504,17 +498,14 @@ class Evaluate:
504
  for point in four_threes:
505
  x = point // self.size
506
  y = point % self.size
507
- model_train_matrix[x][y] = max(FOUR_THREE, model_train_matrix[x][y])
508
 
509
  for point in block_fours:
510
  x = point // self.size
511
  y = point % self.size
512
- model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
513
 
514
  for point in threes:
515
  x = point // self.size
516
  y = point % self.size
517
- model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])
518
 
519
  return set(list(four_threes) + list(block_fours) + list(threes)), model_train_matrix
520
 
@@ -524,17 +515,14 @@ class Evaluate:
524
  for point in three_threes:
525
  x = point // self.size
526
  y = point % self.size
527
- model_train_matrix[x][y] = max(THREE_THREE, model_train_matrix[x][y])
528
 
529
  for point in block_fours:
530
  x = point // self.size
531
  y = point % self.size
532
- model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
533
 
534
  for point in threes:
535
  x = point // self.size
536
  y = point % self.size
537
- model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])
538
 
539
  return set(list(three_threes) + list(block_fours) + list(threes)), model_train_matrix
540
 
@@ -542,43 +530,16 @@ class Evaluate:
542
  for point in threes:
543
  x = point // self.size
544
  y = point % self.size
545
- model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])
546
 
547
  for point in block_fours:
548
  x = point // self.size
549
  y = point % self.size
550
- model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
551
  return set(list(block_fours) + list(threes)), model_train_matrix
552
 
553
  block_threes = points[shapes['BLOCK_THREE']]
554
  two_twos = points[shapes['TWO_TWO']]
555
  twos = points[shapes['TWO']]
556
 
557
- for point in block_threes:
558
- x = point // self.size
559
- y = point % self.size
560
- model_train_matrix[x][y] = max(BLOCK_THREE, model_train_matrix[x][y])
561
-
562
- for point in two_twos:
563
- x = point // self.size
564
- y = point % self.size
565
- model_train_matrix[x][y] = max(TWO_TWO, model_train_matrix[x][y])
566
-
567
- for point in twos:
568
- x = point // self.size
569
- y = point % self.size
570
- model_train_matrix[x][y] = max(TWO, model_train_matrix[x][y])
571
-
572
- for point in block_fours:
573
- x = point // self.size
574
- y = point % self.size
575
- model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
576
-
577
- for point in threes:
578
- x = point // self.size
579
- y = point % self.size
580
- model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])
581
-
582
  mid = list(block_fours) + list(threes) + list(block_threes) + list(two_twos) + list(twos)
583
  res = set(mid[:5])
584
  for i in range(len(model_train_matrix)):
 
460
  for point in fives:
461
  x = point // self.size
462
  y = point % self.size
 
463
  for point in block_fives:
464
  x = point // self.size
465
  y = point % self.size
 
466
 
467
  return set(list(fives) + list(block_fives)), model_train_matrix
468
 
 
472
  for point in fours:
473
  x = point // self.size
474
  y = point % self.size
 
475
 
476
  for point in block_fours:
477
  x = point // self.size
478
  y = point % self.size
 
479
 
480
  return set(list(fours) + list(block_fours)), model_train_matrix
481
 
 
484
  for point in four_fours:
485
  x = point // self.size
486
  y = point % self.size
 
487
 
488
  for point in block_fours:
489
  x = point // self.size
490
  y = point % self.size
 
491
 
492
  return set(list(four_fours) + list(block_fours)), model_train_matrix
493
 
 
498
  for point in four_threes:
499
  x = point // self.size
500
  y = point % self.size
 
501
 
502
  for point in block_fours:
503
  x = point // self.size
504
  y = point % self.size
 
505
 
506
  for point in threes:
507
  x = point // self.size
508
  y = point % self.size
 
509
 
510
  return set(list(four_threes) + list(block_fours) + list(threes)), model_train_matrix
511
 
 
515
  for point in three_threes:
516
  x = point // self.size
517
  y = point % self.size
 
518
 
519
  for point in block_fours:
520
  x = point // self.size
521
  y = point % self.size
 
522
 
523
  for point in threes:
524
  x = point // self.size
525
  y = point % self.size
 
526
 
527
  return set(list(three_threes) + list(block_fours) + list(threes)), model_train_matrix
528
 
 
530
  for point in threes:
531
  x = point // self.size
532
  y = point % self.size
 
533
 
534
  for point in block_fours:
535
  x = point // self.size
536
  y = point % self.size
 
537
  return set(list(block_fours) + list(threes)), model_train_matrix
538
 
539
  block_threes = points[shapes['BLOCK_THREE']]
540
  two_twos = points[shapes['TWO_TWO']]
541
  twos = points[shapes['TWO']]
542
 
543
  mid = list(block_fours) + list(threes) + list(block_threes) + list(two_twos) + list(twos)
544
  res = set(mid[:5])
545
  for i in range(len(model_train_matrix)):
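The eval.py hunks above repeatedly decode a flattened board index with `x = point // self.size` and `y = point % self.size`, and (before this commit) recorded the strongest shape score per cell in `model_train_matrix`. A standalone sketch of that bookkeeping, with placeholder score constants rather than the module's real ones:

```python
# Minimal sketch of the point encoding and per-cell shape-score matrix used in eval.py.
# FIVE / BLOCK_FIVE are placeholder values here, not the repository's constants.
FIVE, BLOCK_FIVE = 10_000_000, 1_000_000
size = 9

def from_point(point):
    """Recover (row, col) from a flattened board index."""
    return point // size, point % size

model_train_matrix = [[0] * size for _ in range(size)]
fives = {4, 40}  # example flattened indices of cells that complete a five
for point in fives:
    x, y = from_point(point)
    # keep the strongest shape seen at this cell
    model_train_matrix[x][y] = max(FIVE, model_train_matrix[x][y])
```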
Gomoku_Bot/gomoku_bot.py CHANGED
@@ -3,7 +3,7 @@ import time
3
 
4
 
5
  class Gomoku_bot:
6
- def __init__(self, board, role, depth=4, enableVCT=True):
7
  self.board = board
8
  self.role = role
9
  self.depth = depth
@@ -14,7 +14,8 @@ class Gomoku_bot:
14
  score = minmax(self.board, self.role, self.depth, self.enableVCT)
15
  end = time.time()
16
  sim_time = end - start
17
- move = score[1]
 
18
  # turn tuple into an int
19
  move = move[0] * self.board.size + move[1]
20
  if return_time:
 
3
 
4
 
5
  class Gomoku_bot:
6
+ def __init__(self, board, role, depth=4, enableVCT=False):
7
  self.board = board
8
  self.role = role
9
  self.depth = depth
 
14
  score = minmax(self.board, self.role, self.depth, self.enableVCT)
15
  end = time.time()
16
  sim_time = end - start
17
+ move = score[1]  # this move uses the top-left corner as (0, 0), while the game's coordinates use the bottom-left corner as (0, 0)
18
+ move = (self.board.size - 1 - move[0], move[1])  # convert the move to the game's coordinate system
19
  # turn tuple into an int
20
  move = move[0] * self.board.size + move[1]
21
  if return_time:
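The added lines above flip the row index because the bot's search uses a board whose origin (0, 0) is the top-left corner, while the game treats (0, 0) as the bottom-left corner. A small sketch of that conversion, assuming a square board with `size` cells per side:

```python
# Sketch of the coordinate conversion performed in Gomoku_bot.get_action (assumed square board).
def bot_move_to_game_index(move, size):
    """Convert a (row, col) move with a top-left origin into the game's flattened
    index, whose origin is the bottom-left corner."""
    row, col = move
    row = size - 1 - row          # mirror the row axis
    return row * size + col       # flatten into an int, as gomoku_bot.py does

# The bot's top-left cell lands in the game's bottom-left row, and vice versa:
assert bot_move_to_game_index((0, 0), 9) == 72
assert bot_move_to_game_index((8, 0), 9) == 0
```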
Gomoku_MCTS/__init__.py CHANGED
@@ -1,7 +1,9 @@
1
  from .mcts_pure import MCTSPlayer as MCTSpure
2
  from .mcts_alphaZero import MCTSPlayer as alphazero
3
- # from .dueling_net import PolicyValueNet
4
- from .policy_value_net_pytorch import PolicyValueNet
 
 
5
  import numpy as np
6
 
7
 
 
1
  from .mcts_pure import MCTSPlayer as MCTSpure
2
  from .mcts_alphaZero import MCTSPlayer as alphazero
3
+ from .policy_value_net_pytorch import PolicyValueNet as PolicyValueNet_old
4
+ from .policy_value_net_pytorch_new import PolicyValueNet as PolicyValueNet_new
5
+ from .dueling_net import PolicyValueNet as duel_PolicyValueNet
6
+ from .mcts_Gumbel_Alphazero import Gumbel_MCTSPlayer
7
  import numpy as np
8
 
9
 
Gomoku_MCTS/checkpoint/2023-12-14-18-16-09_test_teaching_learning_collect_epochs=1000_size=9_model=duel/best_policy.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6278cd8f69e66f42a927df96e1fe3952a6e1e5a41e37f99f315b9bc3febd6d7a
3
+ size 529974
Gomoku_MCTS/checkpoint/2023-12-14-18-16-09_test_teaching_learning_collect_epochs=1000_size=9_model=duel/current_policy.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab24465e4c52e038fda2bafae71550cb89c3f42f59d2ebf12b7a45c2c353eb33
3
+ size 530034
Gomoku_MCTS/checkpoint/2023-12-14-18-17-07_test_teaching_learning_collect_epochs=1000_size=9_model=normal/best_policy.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d43bc56e0bc86c5548d7857b6777b1971d2d14da6f344280b6f81be8595ac710
3
+ size 555837
Gomoku_MCTS/checkpoint/2023-12-14-18-17-07_test_teaching_learning_collect_epochs=1000_size=9_model=normal/current_policy.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ab8cee4afa72bf29d73d0707ad1b92e1d0a24a009721c110c99b2ff5d2f866f
3
+ size 556110
Gomoku_MCTS/checkpoint/2023-12-14-21-19-40_selfplay_epochs=1000_size=9_model=gumbel/best_policy.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8a225a330990d289278d8fa2cbd8bda0a7ea541c3ffd7aac6327d4553ef8683
3
+ size 555837
Gomoku_MCTS/checkpoint/2023-12-14-21-19-40_selfplay_epochs=1000_size=9_model=gumbel/current_policy.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0dfdc692e55ba43f6ae714da93ee6182b9adcf8824c292bb597ef0d003c6d10b
3
+ size 556110
Gomoku_MCTS/config/utils.py CHANGED
@@ -1,7 +1,7 @@
1
  import os, shutil
2
  import torch
3
- from tensorboardX import SummaryWriter
4
- from config.options import *
5
  import torch.distributed as dist
6
  import time
7
 
@@ -42,13 +42,13 @@ def makedir(path):
42
  os.makedirs(path, 0o777)
43
 
44
 
45
- def visualizer():
46
- if get_rank() == 0:
47
- # filewriter_path = config['visual_base']+opts.savepath+'/'
48
- save_path = make_path()
49
- filewriter_path = os.path.join(config['visual_base'], save_path)
50
- if opts.clear_visualizer and os.path.exists(filewriter_path): # delete previous summaries to avoid overlap
51
- shutil.rmtree(filewriter_path)
52
- makedir(filewriter_path)
53
- writer = SummaryWriter(filewriter_path, comment='visualizer')
54
- return writer
 
1
  import os, shutil
2
  import torch
3
+ # from tensorboardX import SummaryWriter
4
+ from .options import *
5
  import torch.distributed as dist
6
  import time
7
 
 
42
  os.makedirs(path, 0o777)
43
 
44
 
45
+ # def visualizer():
46
+ # if get_rank() == 0:
47
+ # # filewriter_path = config['visual_base']+opts.savepath+'/'
48
+ # save_path = make_path()
49
+ # filewriter_path = os.path.join(config['visual_base'], save_path)
50
+ # if opts.clear_visualizer and os.path.exists(filewriter_path): # delete previous summaries to avoid overlap
51
+ # shutil.rmtree(filewriter_path)
52
+ # makedir(filewriter_path)
53
+ # writer = SummaryWriter(filewriter_path, comment='visualizer')
54
+ # return writer
Gomoku_MCTS/dueling_net.py CHANGED
@@ -52,16 +52,16 @@ class DuelingDQNNet(nn.Module):
52
  return F.log_softmax(q_values, dim=1), val
53
 
54
  class PolicyValueNet():
55
- """dueling policy-value network """
56
  def __init__(self, board_width, board_height,
57
- model_file=None, use_gpu=False):
58
  self.use_gpu = use_gpu
59
  self.board_width = board_width
60
  self.board_height = board_height
61
  self.l2_const = 1e-4 # coef of l2 penalty
62
  # the policy value net module
63
  if self.use_gpu:
64
- self.policy_value_net = DuelingDQNNet(board_width, board_height).cuda()
65
  else:
66
  self.policy_value_net = DuelingDQNNet(board_width, board_height)
67
  self.optimizer = optim.Adam(self.policy_value_net.parameters(),
@@ -70,7 +70,6 @@ class PolicyValueNet():
70
  if model_file:
71
  net_params = torch.load(model_file)
72
  self.policy_value_net.load_state_dict(net_params, strict=False)
73
- print('loaded dueling model file')
74
 
75
  def policy_value(self, state_batch):
76
  """
@@ -78,7 +77,7 @@ class PolicyValueNet():
78
  output: a batch of action probabilities and state values
79
  """
80
  if self.use_gpu:
81
- state_batch = Variable(torch.FloatTensor(state_batch).cuda())
82
  log_act_probs, value = self.policy_value_net(state_batch)
83
  act_probs = np.exp(log_act_probs.data.cpu().numpy())
84
  return act_probs, value.data.cpu().numpy()
@@ -97,16 +96,20 @@ class PolicyValueNet():
97
  legal_positions = board.availables
98
  current_state = np.ascontiguousarray(board.current_state().reshape(
99
  -1, 4, self.board_width, self.board_height))
 
100
  if self.use_gpu:
101
  log_act_probs, value = self.policy_value_net(
102
- Variable(torch.from_numpy(current_state)).cuda().float())
103
  act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
104
  else:
105
  log_act_probs, value = self.policy_value_net(
106
  Variable(torch.from_numpy(current_state)).float())
107
  act_probs = np.exp(log_act_probs.data.numpy().flatten())
 
108
  act_probs = zip(legal_positions, act_probs[legal_positions])
109
- value = value.data[0][0]
 
 
110
  return act_probs, value
111
 
112
  def train_step(self, state_batch, mcts_probs, winner_batch, lr):
@@ -115,9 +118,9 @@ class PolicyValueNet():
115
  # self.use_gpu = True
116
  # wrap in Variable
117
  if self.use_gpu:
118
- state_batch = Variable(torch.FloatTensor(state_batch).cuda())
119
- mcts_probs = Variable(torch.FloatTensor(mcts_probs).cuda())
120
- winner_batch = Variable(torch.FloatTensor(winner_batch).cuda())
121
  else:
122
  state_batch = Variable(torch.FloatTensor(state_batch))
123
  mcts_probs = Variable(torch.FloatTensor(mcts_probs))
 
52
  return F.log_softmax(q_values, dim=1), val
53
 
54
  class PolicyValueNet():
55
+ """policy-value network """
56
  def __init__(self, board_width, board_height,
57
+ model_file=None, use_gpu=False, device = None):
58
  self.use_gpu = use_gpu
59
  self.board_width = board_width
60
  self.board_height = board_height
61
  self.l2_const = 1e-4 # coef of l2 penalty
62
  # the policy value net module
63
  if self.use_gpu:
64
+ self.policy_value_net = DuelingDQNNet(board_width, board_height).to(device)
65
  else:
66
  self.policy_value_net = DuelingDQNNet(board_width, board_height)
67
  self.optimizer = optim.Adam(self.policy_value_net.parameters(),
 
70
  if model_file:
71
  net_params = torch.load(model_file)
72
  self.policy_value_net.load_state_dict(net_params, strict=False)
 
73
 
74
  def policy_value(self, state_batch):
75
  """
 
77
  output: a batch of action probabilities and state values
78
  """
79
  if self.use_gpu:
80
+ state_batch = Variable(torch.FloatTensor(state_batch).to(device))
81
  log_act_probs, value = self.policy_value_net(state_batch)
82
  act_probs = np.exp(log_act_probs.data.cpu().numpy())
83
  return act_probs, value.data.cpu().numpy()
 
96
  legal_positions = board.availables
97
  current_state = np.ascontiguousarray(board.current_state().reshape(
98
  -1, 4, self.board_width, self.board_height))
99
+
100
  if self.use_gpu:
101
  log_act_probs, value = self.policy_value_net(
102
+ Variable(torch.from_numpy(current_state)).to(device).float())
103
  act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
104
  else:
105
  log_act_probs, value = self.policy_value_net(
106
  Variable(torch.from_numpy(current_state)).float())
107
  act_probs = np.exp(log_act_probs.data.numpy().flatten())
108
+
109
  act_probs = zip(legal_positions, act_probs[legal_positions])
110
+
111
+
112
+
113
  return act_probs, value
114
 
115
  def train_step(self, state_batch, mcts_probs, winner_batch, lr):
 
118
  # self.use_gpu = True
119
  # wrap in Variable
120
  if self.use_gpu:
121
+ state_batch = Variable(torch.FloatTensor(state_batch).to(device))
122
+ mcts_probs = Variable(torch.FloatTensor(mcts_probs).to(device))
123
+ winner_batch = Variable(torch.FloatTensor(winner_batch).to(device))
124
  else:
125
  state_batch = Variable(torch.FloatTensor(state_batch))
126
  mcts_probs = Variable(torch.FloatTensor(mcts_probs))
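These hunks swap the hard-coded `.cuda()` calls for `.to(device)`, and the constructor now accepts a `device` argument; the methods that later reference `device` need it to stay reachable, for instance by keeping it on the instance. A minimal device-handling sketch under that assumption (it is not the committed class, only the pattern):

```python
import torch

class DeviceAwarePolicy:
    """Sketch of device-agnostic tensor handling (assumed pattern, not the committed code)."""

    def __init__(self, net: torch.nn.Module, use_gpu: bool = False, device=None):
        # Keep the chosen device on the instance so every later method can reuse it.
        if device is None:
            device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")
        self.device = device
        self.net = net.to(self.device)

    def policy_value(self, state_batch):
        # Move inputs onto the same device as the parameters before the forward pass.
        # self.net is expected to return (log_act_probs, value), like the nets in this repository.
        batch = torch.as_tensor(state_batch, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            log_act_probs, value = self.net(batch)
        return log_act_probs.exp().cpu().numpy(), value.cpu().numpy()
```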
Gomoku_MCTS/mcts_Gumbel_Alphazero.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- FileName: main_worker.py
3
  Author: Jiaxin Li
4
  Create Date: 2023/11/21
5
  Description: An implementation of Gumbel MCTS
@@ -9,11 +9,11 @@ Debug: the dim of output: probs
9
 
10
  import numpy as np
11
  import copy
12
- import time
13
 
14
- from config.options import *
15
  import sys
16
- from config.utils import *
17
 
18
 
19
  def softmax(x):
@@ -22,8 +22,8 @@ def softmax(x):
22
  return probs
23
 
24
 
25
- def _sigma_mano(y ,Nb):
26
- return (50 + Nb) * 1.0 * y
27
 
28
 
29
  class TreeNode(object):
@@ -42,8 +42,6 @@ class TreeNode(object):
42
  self._v = 0
43
  self._p = prior_p
44
 
45
-
46
-
47
  def expand(self, action_priors):
48
  """Expand tree by creating new children.
49
  action_priors: a list of tuples of actions and their prior probability
@@ -52,7 +50,6 @@ class TreeNode(object):
52
  for action, prob in action_priors:
53
  if action not in self._children:
54
  self._children[action] = TreeNode(self, prob)
55
-
56
 
57
  def select(self, v_pi):
58
  """Select action among children that gives maximum
@@ -62,29 +59,25 @@ class TreeNode(object):
62
  # if opts.split == "train":
63
  # v_pi = v_pi.detach().numpy()
64
  # print(v_pi)
65
-
66
-
67
-
68
 
69
- max_N_b = np.max(np.array( [act_node[1]._n_visits for act_node in self._children.items()]))
70
 
71
  if opts.split == "train":
72
- pi_ = softmax( np.array( [ act_node[1].get_pi(v_pi,max_N_b) for act_node in self._children.items() ])).reshape(len(list(self._children.items())) ,-1)
 
73
  else:
74
- pi_ = softmax( np.array( [ act_node[1].get_pi(v_pi,max_N_b) for act_node in self._children.items() ])).reshape(len(list(self._children.items())) ,-1)
 
75
  # print(pi_.shape)
76
-
77
 
78
- N_a = np.array( [ act_node[1]._n_visits / (1 + self._n_visits) for act_node in self._children.items() ]).reshape(pi_.shape[0],-1)
 
79
  # print(N_a.shape)
80
 
81
- max_index= np.argmax(pi_ - N_a)
82
  # print((pi_ - N_a).shape)
83
-
84
-
85
-
86
- return list(self._children.items())[max_index]
87
 
 
88
 
89
  def update(self, leaf_value):
90
  """Update node values from leaf evaluation.
@@ -95,11 +88,9 @@ class TreeNode(object):
95
  self._n_visits += 1
96
  # Update Q, a running average of values for all visits.
97
  if opts.split == "train":
98
- self._Q = self._Q + (1.0*(leaf_value - self._Q ) / self._n_visits)
99
-
100
-
101
- else:
102
- self._Q += (1.0*(leaf_value - self._Q) / self._n_visits)
103
 
104
  def update_recursive(self, leaf_value):
105
  """Like a call to update(), but applied recursively for all ancestors.
@@ -109,14 +100,13 @@ class TreeNode(object):
109
  self._parent.update_recursive(-leaf_value)
110
  self.update(leaf_value)
111
 
112
- def get_pi(self,v_pi,max_N_b):
113
  if self._n_visits == 0:
114
  Q_completed = v_pi
115
  else:
116
  Q_completed = self._Q
117
-
118
- return self._p + _sigma_mano(Q_completed,max_N_b)
119
 
 
120
 
121
  def get_value(self, c_puct):
122
  """Calculate and return the value for this node.
@@ -155,9 +145,6 @@ class Gumbel_MCTS(object):
155
  self._c_puct = c_puct
156
  self._n_playout = n_playout
157
 
158
-
159
-
160
-
161
  def Gumbel_playout(self, child_node, child_state):
162
  """Run a single playout from the child of the root to the leaf, getting a value at
163
  the leaf and propagating it back through its parents.
@@ -166,28 +153,26 @@ class Gumbel_MCTS(object):
166
  """
167
  node = child_node
168
  state = child_state
169
-
170
- while(1):
171
  if node.is_leaf():
172
  break
173
  # Greedily select next move.
174
 
175
  action, node = node.select(node._v)
176
-
177
- state.do_move(action)
178
-
179
 
 
180
 
181
  # Evaluate the leaf using a network which outputs a list of
182
  # (action, probability) tuples p and also a score v in [-1, 1]
183
  # for the current player.
 
184
  action_probs, leaf_value = self._policy(state)
185
-
186
- leaf_value = leaf_value.detach().numpy()[0][0]
187
 
188
- node._v = leaf_value
189
-
190
 
 
191
 
192
  # Check for end of game.
193
  end, winner = state.game_end()
@@ -204,24 +189,19 @@ class Gumbel_MCTS(object):
204
 
205
  # Update value and visit count of nodes in this traversal.
206
  node.update_recursive(-leaf_value)
207
-
208
 
209
- def top_k(self,x, k):
210
  # print("x",x.shape)
211
  # print("k ", k)
212
 
213
  return np.argpartition(x, k)[..., -k:]
214
 
215
- def sample_k(self,logits, k):
216
  u = np.random.uniform(size=np.shape(logits))
217
  z = -np.log(-np.log(u))
 
218
 
219
-
220
-
221
- return self.top_k(logits + z, k),z
222
-
223
-
224
- def get_move_probs(self, state, temp=1e-3,m_action = 16):
225
  """Run all playouts sequentially and return the available actions and
226
  their corresponding probabilities.
227
  state: the current game state
@@ -231,92 +211,102 @@ class Gumbel_MCTS(object):
231
  # logits are tentatively taken to be p
232
 
233
  start_time = time.time()
234
-
235
-
236
  # expand the root node
 
237
  act_probs, leaf_value = self._policy(state)
238
- act_probs = list(act_probs)
239
 
240
- leaf_value = leaf_value.detach().numpy()[0][0]
241
-
 
 
242
  # print(list(act_probs))
243
- porbs = [prob for act,prob in (act_probs)]
244
- self._root.expand(act_probs)
245
 
 
246
 
247
  n = self._n_playout
248
- m = min( m_action,int(len( porbs) / 2))
249
-
250
 
251
  # First sample from the Gumbel distribution: draw the top m actions without replacement, following the selection rule logits + g
252
- A_topm ,g = self.sample_k(porbs , m)
253
-
254
  # For each selected action, get the state that results from playing it and store it in a list
255
  root_childs = list(self._root._children.items())
256
-
257
 
258
  child_state_m = []
259
  for i in range(m):
260
  state_copy = copy.deepcopy(state)
261
- action,node = root_childs[A_topm[i]]
262
  state_copy.do_move(action)
263
  child_state_m.append(state_copy)
 
 
 
 
264
 
265
-
266
- # number of simulations performed for each selected action per round
267
- N = int( n /( np.log(m) * m ))
 
 
 
 
268
 
269
  # run sequential halving with Gumbel
270
  while m >= 1:
271
-
272
  # run simulations for each selected action
273
  for i in range(m):
274
  action_state = child_state_m[i]
275
-
276
- action,node = root_childs[A_topm[i]]
277
-
278
  for j in range(N):
279
  action_state_copy = copy.deepcopy(action_state)
280
-
281
  # Simulate the selected action: descend to a leaf of this subtree, predict v with the network, then back the value up the tree
282
  self.Gumbel_playout(node, action_state_copy)
283
 
284
  # halve the number of distinct sampled actions each round
285
- m = m //2
286
 
287
  # if this is not the last round, double the number of simulations per round
288
- if(m != 1):
289
  n = n - N
290
  N *= 2
291
  # in the last round only one action remains, so spend all remaining simulations on it
292
  else:
293
  N = n
294
-
295
  # Start a new round of sampling without replacement, keeping the top half of the previous actions, following the formula g + logits + \sigma( \hat{q} )
296
  # print([action_node[1]._Q for action_node in self._root._children.items() ])
297
-
298
-
299
- q_hat = np.array([action_node[1]._Q for action_node in self._root._children.items() ])
300
-
301
 
302
- assert(np.sum(q_hat[A_topm] == 0) == 0 )
 
 
 
 
303
 
304
- A_index = self.top_k( np.array(porbs)[A_topm] + np.array(g)[A_topm] + q_hat[A_topm] , m)
305
  A_topm = np.array(A_topm)[A_index]
306
  child_state_m = np.array(child_state_m)[A_index]
307
-
308
-
309
  # Finally, return the corresponding policy, i.e. pi' = softmax(logits + sigma(completed Q))
310
 
311
- max_N_b = np.max(np.array( [act_node[1]._n_visits for act_node in self._root._children.items()] ))
312
 
313
- final_act_probs= softmax( np.array( [ act_node[1].get_pi(leaf_value, max_N_b) for act_node in self._root._children.items() ]))
314
- action = ( np.array( [ act_node[0] for act_node in self._root._children.items() ]))
315
 
 
 
 
 
 
316
  need_time = time.time() - start_time
317
- print(f" Gumbel Alphazero sum_time: {need_time }, total_simulation: {self._n_playout}")
318
 
319
- return np.array(list(self._root._children.items()))[A_topm][0][0], action, final_act_probs , need_time
320
 
321
  def update_with_move(self, last_move):
322
  """Step forward in the tree, keeping everything we already know
@@ -336,50 +326,53 @@ class Gumbel_MCTSPlayer(object):
336
  """AI player based on MCTS"""
337
 
338
  def __init__(self, policy_value_function,
339
- c_puct=5, n_playout=2000, is_selfplay=0,m_action = 16):
340
  self.mcts = Gumbel_MCTS(policy_value_function, c_puct, n_playout)
341
  self._is_selfplay = is_selfplay
342
  self.m_action = m_action
343
 
344
-
345
  def set_player_ind(self, p):
346
  self.player = p
347
 
348
  def reset_player(self):
349
  self.mcts.update_with_move(-1)
350
 
351
-
352
- def get_action(self, board, temp=1e-3, return_prob=0,return_time = False):
353
  sensible_moves = board.availables
354
  # the pi vector returned by MCTS as in the alphaGo Zero paper
355
- move_probs = np.zeros(board.width*board.height)
356
-
357
-
358
-
359
  if len(sensible_moves) > 0:
360
 
361
  # Use sequential halving with Gumbel in the search tree to select an action, and return the corresponding policy
362
- move, acts, probs,simul_mean_time = self.mcts.get_move_probs(board, temp,self.m_action)
363
-
364
-
365
 
366
  # reset the search tree
367
  self.mcts.update_with_move(-1)
368
 
369
  move_probs[list(acts)] = probs
 
 
 
 
 
 
 
370
 
 
 
371
 
372
  if return_time:
373
 
374
  if return_prob:
375
-
376
- return move, move_probs,simul_mean_time
377
  else:
378
- return move,simul_mean_time
379
  else:
380
 
381
  if return_prob:
382
-
383
  return move, move_probs
384
  else:
385
  return move
 
1
  """
2
+ FileName: mcts_Gumbel_Alphazero.py
3
  Author: Jiaxin Li
4
  Create Date: 2023/11/21
5
  Description: An implementation of Gumbel MCTS
 
9
 
10
  import numpy as np
11
  import copy
12
+ import time
13
 
14
+ from .config.options import *
15
  import sys
16
+ from .config.utils import *
17
 
18
 
19
  def softmax(x):
 
22
  return probs
23
 
24
 
25
+ def _sigma_mano(y, Nb):
26
+ return (50 + Nb) * 1.0 * y
27
 
28
 
29
  class TreeNode(object):
 
42
  self._v = 0
43
  self._p = prior_p
44
 
 
 
45
  def expand(self, action_priors):
46
  """Expand tree by creating new children.
47
  action_priors: a list of tuples of actions and their prior probability
 
50
  for action, prob in action_priors:
51
  if action not in self._children:
52
  self._children[action] = TreeNode(self, prob)
 
53
 
54
  def select(self, v_pi):
55
  """Select action among children that gives maximum
 
59
  # if opts.split == "train":
60
  # v_pi = v_pi.detach().numpy()
61
  # print(v_pi)
 
 
 
62
 
63
+ max_N_b = np.max(np.array([act_node[1]._n_visits for act_node in self._children.items()]))
64
 
65
  if opts.split == "train":
66
+ pi_ = softmax(np.array([act_node[1].get_pi(v_pi, max_N_b) for act_node in self._children.items()])).reshape(
67
+ len(list(self._children.items())), -1)
68
  else:
69
+ pi_ = softmax(np.array([act_node[1].get_pi(v_pi, max_N_b) for act_node in self._children.items()])).reshape(
70
+ len(list(self._children.items())), -1)
71
  # print(pi_.shape)
 
72
 
73
+ N_a = np.array([act_node[1]._n_visits / (1 + self._n_visits) for act_node in self._children.items()]).reshape(
74
+ pi_.shape[0], -1)
75
  # print(N_a.shape)
76
 
77
+ max_index = np.argmax(pi_ - N_a)
78
  # print((pi_ - N_a).shape)
 
 
 
 
79
 
80
+ return list(self._children.items())[max_index]
81
 
82
  def update(self, leaf_value):
83
  """Update node values from leaf evaluation.
 
88
  self._n_visits += 1
89
  # Update Q, a running average of values for all visits.
90
  if opts.split == "train":
91
+ self._Q = self._Q + (1.0 * (leaf_value - self._Q) / self._n_visits)
92
+ else:
93
+ self._Q += (1.0 * (leaf_value - self._Q) / self._n_visits)
 
 
94
 
95
  def update_recursive(self, leaf_value):
96
  """Like a call to update(), but applied recursively for all ancestors.
 
100
  self._parent.update_recursive(-leaf_value)
101
  self.update(leaf_value)
102
 
103
+ def get_pi(self, v_pi, max_N_b):
104
  if self._n_visits == 0:
105
  Q_completed = v_pi
106
  else:
107
  Q_completed = self._Q
 
 
108
 
109
+ return self._p + _sigma_mano(Q_completed, max_N_b)
110
 
111
  def get_value(self, c_puct):
112
  """Calculate and return the value for this node.
 
145
  self._c_puct = c_puct
146
  self._n_playout = n_playout
147
 
 
 
 
148
  def Gumbel_playout(self, child_node, child_state):
149
  """Run a single playout from the child of the root to the leaf, getting a value at
150
  the leaf and propagating it back through its parents.
 
153
  """
154
  node = child_node
155
  state = child_state
156
+
157
+ while (1):
158
  if node.is_leaf():
159
  break
160
  # Greedily select next move.
161
 
162
  action, node = node.select(node._v)
 
 
 
163
 
164
+ state.do_move(action)
165
 
166
  # Evaluate the leaf using a network which outputs a list of
167
  # (action, probability) tuples p and also a score v in [-1, 1]
168
  # for the current player.
169
+
170
  action_probs, leaf_value = self._policy(state)
 
 
171
 
172
+ # leaf_value = leaf_value.detach().numpy()[0][0]
173
+ leaf_value = leaf_value.detach().numpy()
174
 
175
+ node._v = leaf_value
176
 
177
  # Check for end of game.
178
  end, winner = state.game_end()
 
189
 
190
  # Update value and visit count of nodes in this traversal.
191
  node.update_recursive(-leaf_value)
 
192
 
193
+ def top_k(self, x, k):
194
  # print("x",x.shape)
195
  # print("k ", k)
196
 
197
  return np.argpartition(x, k)[..., -k:]
198
 
199
+ def sample_k(self, logits, k):
200
  u = np.random.uniform(size=np.shape(logits))
201
  z = -np.log(-np.log(u))
202
+ return self.top_k(logits + z, k), z
203
 
204
+ def get_move_probs(self, state, temp=1e-3, m_action=16):
 
 
 
 
 
205
  """Run all playouts sequentially and return the available actions and
206
  their corresponding probabilities.
207
  state: the current game state
 
211
  # logits are tentatively taken to be p
212
 
213
  start_time = time.time()
 
 
214
  # expand the root node
215
+
216
  act_probs, leaf_value = self._policy(state)
 
217
 
218
+ act_probs = list(act_probs)
219
+
220
+ # leaf_value = leaf_value.detach().numpy()[0][0]
221
+ leaf_value = leaf_value.detach().numpy()
222
  # print(list(act_probs))
223
+ porbs = [prob for act, prob in (act_probs)]
 
224
 
225
+ self._root.expand(act_probs)
226
 
227
  n = self._n_playout
228
+ m = min(m_action, int(len(porbs) / 2))
 
229
 
230
  # First sample from the Gumbel distribution: draw the top m actions without replacement, following the selection rule logits + g
231
+ A_topm, g = self.sample_k(porbs, m)
232
+
233
  # For each selected action, get the state that results from playing it and store it in a list
234
  root_childs = list(self._root._children.items())
 
235
 
236
  child_state_m = []
237
  for i in range(m):
238
  state_copy = copy.deepcopy(state)
239
+ action, node = root_childs[A_topm[i]]
240
  state_copy.do_move(action)
241
  child_state_m.append(state_copy)
242
+ print(porbs)
243
+
244
+ print("depend on:", np.array(porbs)[A_topm])
245
+ print(f"A_topm_{m}", A_topm)
246
 
247
+ print("m ", m)
248
+
249
+ if m > 1:
250
+ # number of simulations performed for each selected action per round
251
+ N = int(n / (np.log(m) * m))
252
+ else:
253
+ N = n
254
 
255
  # run sequential halving with Gumbel
256
  while m >= 1:
257
+
258
  # run simulations for each selected action
259
  for i in range(m):
260
  action_state = child_state_m[i]
261
+
262
+ action, node = root_childs[A_topm[i]]
263
+
264
  for j in range(N):
265
  action_state_copy = copy.deepcopy(action_state)
266
+
267
  # Simulate the selected action: descend to a leaf of this subtree, predict v with the network, then back the value up the tree
268
  self.Gumbel_playout(node, action_state_copy)
269
 
270
  # halve the number of distinct sampled actions each round
271
+ m = m // 2
272
 
273
  # if this is not the last round, double the number of simulations per round
274
+ if (m != 1):
275
  n = n - N
276
  N *= 2
277
  # in the last round only one action remains, so spend all remaining simulations on it
278
  else:
279
  N = n
280
+
281
  # Start a new round of sampling without replacement, keeping the top half of the previous actions, following the formula g + logits + \sigma( \hat{q} )
282
  # print([action_node[1]._Q for action_node in self._root._children.items() ])
 
 
 
 
283
 
284
+ q_hat = np.array([action_node[1]._Q for action_node in self._root._children.items()])
285
+ assert (np.sum(q_hat[A_topm] == 0) == 0)
286
+
287
+ print("depend on:", np.array(porbs)[A_topm] + np.array(g)[A_topm] + q_hat[A_topm])
288
+ print(f"A_topm_{m}", A_topm)
289
 
290
+ A_index = self.top_k(np.array(porbs)[A_topm] + np.array(g)[A_topm] + q_hat[A_topm], m)
291
  A_topm = np.array(A_topm)[A_index]
292
  child_state_m = np.array(child_state_m)[A_index]
293
+
 
294
  # Finally, return the corresponding policy, i.e. pi' = softmax(logits + sigma(completed Q))
295
 
296
+ max_N_b = np.max(np.array([act_node[1]._n_visits for act_node in self._root._children.items()]))
297
 
298
+ final_act_probs = softmax(
299
+ np.array([act_node[1].get_pi(leaf_value, max_N_b) for act_node in self._root._children.items()]))
300
 
301
+ action = (np.array([act_node[0] for act_node in self._root._children.items()]))
302
+ print("final_act_prbs", final_act_probs)
303
+ print("move :", action)
304
+ print("final_action", np.array(list(self._root._children.items()))[A_topm][0][0])
305
+ print("argmax_prob", np.argmax(final_act_probs))
306
  need_time = time.time() - start_time
307
+ print(f" Gumbel Alphazero sum_time: {need_time}, total_simulation: {self._n_playout}")
308
 
309
+ return np.array(list(self._root._children.items()))[A_topm][0][0], action, final_act_probs, need_time
310
 
311
  def update_with_move(self, last_move):
312
  """Step forward in the tree, keeping everything we already know
 
326
  """AI player based on MCTS"""
327
 
328
  def __init__(self, policy_value_function,
329
+ c_puct=5, n_playout=2000, is_selfplay=0, m_action=16):
330
  self.mcts = Gumbel_MCTS(policy_value_function, c_puct, n_playout)
331
  self._is_selfplay = is_selfplay
332
  self.m_action = m_action
333
 
 
334
  def set_player_ind(self, p):
335
  self.player = p
336
 
337
  def reset_player(self):
338
  self.mcts.update_with_move(-1)
339
 
340
+ def get_action(self, board, temp=1e-3, return_prob=0, return_time=False):
 
341
  sensible_moves = board.availables
342
  # the pi vector returned by MCTS as in the alphaGo Zero paper
343
+ move_probs = np.zeros(board.width * board.height)
344
+
 
 
345
  if len(sensible_moves) > 0:
346
 
347
  # Use sequential halving with Gumbel in the search tree to select an action, and return the corresponding policy
348
+ move, acts, probs, simul_mean_time = self.mcts.get_move_probs(board, temp, self.m_action)
 
 
349
 
350
  # reset the search tree
351
  self.mcts.update_with_move(-1)
352
 
353
  move_probs[list(acts)] = probs
354
+ move_probs = np.zeros(move_probs.shape[0])
355
+ move_probs[move] = 1
356
+
357
+ print("final prob:", move_probs)
358
+ print("arg_max:", np.argmax(move_probs))
359
+ print("max", np.max(move_probs))
360
+ print("move", move)
361
 
362
+ # Training should drive one entry of move_probs toward 1, i.e. yield a policy
363
+ # The problem is that this policy disagrees with the move returned by MCTS; the policy-distribution computation is suspected to be at fault
364
 
365
  if return_time:
366
 
367
  if return_prob:
368
+
369
+ return move, move_probs, simul_mean_time
370
  else:
371
+ return move, simul_mean_time
372
  else:
373
 
374
  if return_prob:
375
+
376
  return move, move_probs
377
  else:
378
  return move
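The comments in `get_move_probs` (translated above) describe the Sequential Halving with Gumbel procedure used by the Gumbel AlphaZero player: sample the top-m root actions by `logits + g`, spend the simulation budget in halving rounds, rescore survivors with `g + logits + sigma(q̂)`, and finally return `pi' = softmax(logits + sigma(completed Q))`. A compact, self-contained sketch of that action-selection schedule; `simulate` is a placeholder value estimator, not the repository's playout:

```python
import numpy as np

def sequential_halving_with_gumbel(logits, n_simulations, m=16, simulate=None, rng=None):
    """Sketch of Sequential Halving with Gumbel for root-action selection.

    logits: prior action scores; simulate(action) returns a value estimate in [-1, 1].
    """
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(logits, dtype=float)
    g = rng.gumbel(size=logits.shape)                    # one Gumbel sample per action
    m = max(1, min(m, len(logits) // 2))
    candidates = np.argsort(logits + g)[-m:]             # top-m actions by logits + g
    q_sum = np.zeros_like(logits)
    visits = np.zeros_like(logits)
    n_per_action = max(1, int(n_simulations / (max(np.log2(m), 1.0) * m)))

    while m >= 1:
        for a in candidates:                             # spend this round's budget on survivors
            for _ in range(n_per_action):
                q_sum[a] += simulate(a)
                visits[a] += 1
        m //= 2
        if m >= 1:                                       # halve: keep the best-scoring half
            q_hat = q_sum[candidates] / np.maximum(visits[candidates], 1)
            scores = logits[candidates] + g[candidates] + q_hat
            candidates = candidates[np.argsort(scores)[-m:]]
            n_per_action *= 2                            # doubled budget per surviving action
    return int(candidates[0])

# Example with a toy value function (higher-indexed actions are better):
# best = sequential_halving_with_gumbel(np.log([0.1, 0.2, 0.3, 0.4]), n_simulations=64,
#                                       m=2, simulate=lambda a: a / 10.0)
```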
Gomoku_MCTS/policy_value_net_pytorch_new.py ADDED
@@ -0,0 +1,234 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ An implementation of the policyValueNet in PyTorch
4
+ Tested in PyTorch 0.2.0 and 0.3.0
5
+
6
+ @author: Junxiao Song
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+ import torch.nn.functional as F
13
+ from torch.autograd import Variable
14
+ import numpy as np
15
+
16
+
17
+ def set_learning_rate(optimizer, lr):
18
+ """Sets the learning rate to the given value"""
19
+ for param_group in optimizer.param_groups:
20
+ param_group['lr'] = lr
21
+
22
+
23
+
24
+
25
+ class ResidualBlock(nn.Module):
26
+ def __init__(self, channels):
27
+ super(ResidualBlock, self).__init__()
28
+ self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
29
+ self.bn1 = nn.BatchNorm2d(channels)
30
+ self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
31
+ self.bn2 = nn.BatchNorm2d(channels)
32
+
33
+ def forward(self, x):
34
+ residual = x
35
+ out = F.relu(self.bn1(self.conv1(x)))
36
+ out = self.bn2(self.conv2(out))
37
+ out += residual
38
+ return F.relu(out)
39
+
40
+ class Net(nn.Module):
41
+ """Policy-Value network module for AlphaZero Gomoku."""
42
+ def __init__(self, board_width, board_height, num_residual_blocks=5):
43
+ super(Net, self).__init__()
44
+ self.board_width = board_width
45
+ self.board_height = board_height
46
+ self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
47
+ self.bn1 = nn.BatchNorm2d(32)
48
+ self.res_layers = nn.Sequential(*[ResidualBlock(32) for _ in range(num_residual_blocks)])
49
+
50
+ # Action Policy layers
51
+ self.act_conv1 = nn.Conv2d(32, 4, kernel_size=1)
52
+ self.act_fc1 = nn.Linear(4 * board_width * board_height, board_width * board_height)
53
+
54
+ # State Value layers
55
+ self.val_conv1 = nn.Conv2d(32, 2, kernel_size=1)
56
+ self.val_fc1 = nn.Linear(2 * board_width * board_height, 64)
57
+ self.val_fc2 = nn.Linear(64, 1)
58
+
59
+ def forward(self, x):
60
+ x = F.relu(self.bn1(self.conv1(x)))
61
+ x = self.res_layers(x)
62
+
63
+ # Action Policy head
64
+ x_act = F.relu(self.act_conv1(x))
65
+ x_act = x_act.view(-1, 4 * self.board_width * self.board_height)
66
+ x_act = F.log_softmax(self.act_fc1(x_act), dim=1)
67
+
68
+ # State Value head
69
+ x_val = F.relu(self.val_conv1(x))
70
+ x_val = x_val.view(-1, 2 * self.board_width * self.board_height)
71
+ x_val = F.relu(self.val_fc1(x_val))
72
+ x_val = torch.tanh(self.val_fc2(x_val))
73
+
74
+ return x_act, x_val
75
+
76
+
77
+ class PolicyValueNet():
78
+ """policy-value network """
79
+
80
+ def __init__(self, board_width, board_height,
81
+ model_file=None, use_gpu=False, bias = False):
82
+ self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
83
+ self.use_gpu = use_gpu
84
+ self.l2_const = 1e-4 # coef of l2 penalty
85
+ self.board_width = board_width
86
+ self.board_height = board_height
87
+ self.bias = bias
88
+
89
+ if model_file:
90
+ net_params = torch.load(model_file, map_location='cpu' if not use_gpu else None)
91
+
92
+ # Infer board dimensions from the loaded model
93
+ inferred_width, inferred_height = self.infer_board_size_from_model(net_params)
94
+ if inferred_width and inferred_height:
95
+ self.policy_value_net = Net(inferred_width, inferred_height).to(self.device) if use_gpu else Net(
96
+ inferred_width, inferred_height)
97
+ self.policy_value_net.load_state_dict(net_params)
98
+ print("Use model file to initialize the policy value net")
99
+ else:
100
+ raise Exception("The model file does not contain the board dimensions")
101
+
102
+ if inferred_width < board_width:
103
+ self.use_conv = True
104
+ elif inferred_width > board_width:
105
+ raise Exception("The model file has a larger board size than the current board size!!")
106
+ else:
107
+ # the policy value net module
108
+ if self.use_gpu:
109
+ self.policy_value_net = Net(board_width, board_height).to(self.device)
110
+ else:
111
+ self.policy_value_net = Net(board_width, board_height)
112
+
113
+ self.optimizer = optim.Adam(self.policy_value_net.parameters(),
114
+ weight_decay=self.l2_const)
115
+
116
+ def infer_board_size_from_model(self, model):
117
+ # Use the size of the act_fc1 layer to infer board dimensions
118
+ for name in model.keys():
119
+ if name == 'act_fc1.weight':
120
+ # Assuming the weight shape is [board_width * board_height, 4 * board_width * board_height]
121
+ c, _ = model[name].shape
122
+ print(f"act_fc1.weight shape: {model[name].shape}")
123
+ board_size = int(c ** 0.5) # Extracting board_width/height assuming they are the same
124
+ print(f"Board size inferred from model: {board_size}x{board_size}")
125
+ return board_size, board_size
126
+ return None
127
+
128
+ def apply_normal_bias(self, tensor, mean=0, std=1):
129
+ bsize = tensor.shape[0]
130
+ x, y = np.meshgrid(np.linspace(-1, 1, bsize), np.linspace(-1, 1, bsize))
131
+ d = np.sqrt(x * x + y * y)
132
+ sigma, mu = 1.0, 0.0
133
+ gauss = np.exp(-((d - mu) ** 2 / (2.0 * sigma ** 2)))
134
+ # Applying the bias only to non-zero elements
135
+ biased_tensor = tensor - (tensor != 0) * gauss
136
+ return biased_tensor
137
+
138
+ def policy_value(self, state_batch):
139
+ """
140
+ input: a batch of states
141
+ output: a batch of action probabilities and state values
142
+ """
143
+ if self.use_gpu:
144
+ state_batch = Variable(torch.FloatTensor(state_batch).to(self.device))
145
+ log_act_probs, value = self.policy_value_net(state_batch)
146
+ act_probs = np.exp(log_act_probs.data.cpu().numpy())
147
+ return act_probs, value.data.cpu().numpy()
148
+ else:
149
+ state_batch = Variable(torch.FloatTensor(state_batch))
150
+ log_act_probs, value = self.policy_value_net(state_batch)
151
+ act_probs = np.exp(log_act_probs.data.numpy())
152
+ return act_probs, value.data.numpy()
153
+
154
+ def policy_value_fn(self, board):
155
+ """
156
+ input: board
157
+ output: a list of (action, probability) tuples for each available
158
+ action and the score of the board state
159
+ """
160
+ legal_positions = board.availables
161
+ current_state = np.ascontiguousarray(board.current_state().reshape(
162
+ -1, 4, self.board_width, self.board_height))
163
+ if self.bias:
164
+ current_state[0][1] = self.apply_normal_bias(current_state[0][1])
165
+
166
+ if self.use_gpu:
167
+ log_act_probs, value = self.policy_value_net(
168
+ Variable(torch.from_numpy(current_state)).to(self.device).float())
169
+ act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
170
+ else:
171
+ log_act_probs, value = self.policy_value_net(
172
+ Variable(torch.from_numpy(current_state)).float())
173
+ act_probs = np.exp(log_act_probs.data.numpy().flatten())
174
+ act_probs = zip(legal_positions, act_probs[legal_positions])
175
+ value = value.data[0][0]
176
+ return act_probs, value
177
+
178
+ def train_step(self, state_batch, mcts_probs, winner_batch, lr):
179
+ """perform a training step"""
180
+
181
+ # self.use_gpu = True
182
+ # wrap in Variable
183
+ if self.use_gpu:
184
+ state_batch = Variable(torch.FloatTensor(state_batch).to(self.device))
185
+ mcts_probs = Variable(torch.FloatTensor(mcts_probs).to(self.device))
186
+ winner_batch = Variable(torch.FloatTensor(winner_batch).to(self.device))
187
+ else:
188
+ state_batch = Variable(torch.FloatTensor(state_batch))
189
+ mcts_probs = Variable(torch.FloatTensor(mcts_probs))
190
+ winner_batch = Variable(torch.FloatTensor(winner_batch))
191
+
192
+ # zero the parameter gradients
193
+ self.optimizer.zero_grad()
194
+ # set learning rate
195
+ set_learning_rate(self.optimizer, lr)
196
+
197
+ # forward
198
+ log_act_probs, value = self.policy_value_net(state_batch)
199
+ # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
200
+ # Note: the L2 penalty is incorporated in optimizer
201
+ value_loss = F.mse_loss(value.view(-1), winner_batch)
202
+ policy_loss = -torch.mean(torch.sum(mcts_probs * log_act_probs, 1))
203
+ loss = value_loss + policy_loss
204
+ # backward and optimize
205
+ loss.backward()
206
+ self.optimizer.step()
207
+ # calc policy entropy, for monitoring only
208
+ entropy = -torch.mean(
209
+ torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
210
+ )
211
+
212
+ # for pytorch version >= 0.5 please use the following line instead.
213
+ return loss.item(), entropy.item()
214
+
215
+ def get_policy_param(self):
216
+ net_params = self.policy_value_net.state_dict()
217
+ return net_params
218
+
219
+ def save_model(self, model_file):
220
+ """ save model params to file """
221
+ net_params = self.get_policy_param() # get model params
222
+ torch.save(net_params, model_file)
223
+
224
+
225
+ if __name__ == "__main__":
226
+ import torch
227
+ import torch.onnx
228
+
229
+ # assumes your Net model has already been defined
230
+ model = Net(board_width=9, board_height=9) # initialize the model with appropriate parameters
231
+ dummy_input = torch.randn(1, 4, 9, 9) # create an example input
232
+
233
+ # export the model to ONNX format
234
+ torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)
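`infer_board_size_from_model` above reads the shape of `act_fc1.weight` in the loaded checkpoint, whose first dimension is `board_width * board_height`, and assumes a square board. A standalone sketch of that inference, using a fabricated state dict for illustration:

```python
import math
import torch

def infer_board_size(state_dict):
    """Infer a square board size from act_fc1.weight, whose shape is
    [board_width * board_height, 4 * board_width * board_height]."""
    weight = state_dict.get("act_fc1.weight")
    if weight is None:
        return None
    cells = weight.shape[0]              # board_width * board_height
    size = math.isqrt(cells)
    return size if size * size == cells else None

# Illustration with a fabricated 9x9 checkpoint entry (not a real model file):
fake_state_dict = {"act_fc1.weight": torch.zeros(81, 4 * 81)}
assert infer_board_size(fake_state_dict) == 9
```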
const.py CHANGED
@@ -9,7 +9,7 @@ import numpy as np
9
 
10
  _AI_AID_INFO = ["Use AI Aid", "Close AI Aid"]
11
 
12
- _BOARD_SIZE = 8
13
  _BOARD_SIZE_1D = _BOARD_SIZE * _BOARD_SIZE
14
  _BLANK = 0
15
  _BLACK = 1
@@ -68,3 +68,10 @@ _ROOM_COLOR = {
68
  True: _BLACK,
69
  False: _WHITE,
70
  }
 
9
 
10
  _AI_AID_INFO = ["Use AI Aid", "Close AI Aid"]
11
 
12
+ _BOARD_SIZE = 9
13
  _BOARD_SIZE_1D = _BOARD_SIZE * _BOARD_SIZE
14
  _BLANK = 0
15
  _BLACK = 1
 
68
  True: _BLACK,
69
  False: _WHITE,
70
  }
71
+
72
+
73
+ _MODEL_PATH = {
74
+ "AlphaZero": "/Users/husky/GomokuDemo/Gomoku_MCTS/checkpoint/2023-12-14-18-17-07_test_teaching_learning_collect_epochs=1000_size=9_model=normal/best_policy.model",
75
+ "duel": "/Users/husky/GomokuDemo/Gomoku_MCTS/checkpoint/2023-12-14-18-16-09_test_teaching_learning_collect_epochs=1000_size=9_model=duel/best_policy.model",
76
+ "Gumbel AlphaZero": "/Users/husky/GomokuDemo/Gomoku_MCTS/checkpoint/2023-12-14-21-19-40_selfplay_epochs=1000_size=9_model=gumbel/best_policy.model",
77
+ }
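The `_MODEL_PATH` entries above are absolute paths tied to one machine. A hedged alternative (a sketch, not part of this commit) resolves the same checkpoints relative to the repository root, since const.py sits at the top level of the project:

```python
from pathlib import Path

# const.py lives at the repository root, so its parent directory is the project root.
_REPO_ROOT = Path(__file__).resolve().parent

_CHECKPOINT_DIRS = {
    "AlphaZero": "2023-12-14-18-17-07_test_teaching_learning_collect_epochs=1000_size=9_model=normal",
    "duel": "2023-12-14-18-16-09_test_teaching_learning_collect_epochs=1000_size=9_model=duel",
    "Gumbel AlphaZero": "2023-12-14-21-19-40_selfplay_epochs=1000_size=9_model=gumbel",
}

_MODEL_PATH = {
    name: str(_REPO_ROOT / "Gomoku_MCTS" / "checkpoint" / folder / "best_policy.model")
    for name, folder in _CHECKPOINT_DIRS.items()
}
```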
pages/Player_VS_AI.py CHANGED
@@ -15,13 +15,14 @@ import numpy as np
15
  import streamlit as st
16
  from scipy.signal import convolve # this is used to check if any player wins
17
  from streamlit import session_state
 
18
  from streamlit_server_state import server_state, server_state_lock
19
- from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet
 
20
  from Gomoku_Bot import Gomoku_bot
21
  from Gomoku_Bot import Board as Gomoku_bot_board
22
- import matplotlib.pyplot as plt
23
-
24
 
 
25
 
26
  from const import (
27
  _BLACK, # 1, for human
@@ -37,10 +38,10 @@ from const import (
37
  _DIAGONAL_UP_RIGHT,
38
  _BOARD_SIZE,
39
  _BOARD_SIZE_1D,
40
- _AI_AID_INFO
 
41
  )
42
 
43
-
44
  from ai import (
45
  BOS_TOKEN_ID,
46
  generate_gpt2,
@@ -63,14 +64,23 @@ class Room:
63
  self.TIME = time.time()
64
  self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
65
  self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
66
- 'AlphaZero': alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100),
 
 
 
 
 
 
 
 
67
  'Gomoku Bot': Gomoku_bot(self.gomoku_bot_board, -1)}
68
  self.MCTS = self.MCTS_dict['AlphaZero']
69
  self.last_mcts = self.MCTS
70
  self.AID_MCTS = self.MCTS_dict['AlphaZero']
71
  self.COORDINATE_1D = [BOS_TOKEN_ID]
72
  self.current_move = -1
73
- self.simula_time_list = []
 
74
 
75
 
76
  def change_turn(cur):
@@ -90,9 +100,9 @@ if "ROOMS" not in server_state:
90
  with server_state_lock["ROOMS"]:
91
  server_state.ROOMS = {}
92
 
 
93
  def handle_oppo_model_selection():
94
  if st.session_state['selected_oppo_model'] == 'Gomoku Bot':
95
- session_state.ROOM.last_mcts = session_state.ROOM.MCTS # since use different mechanism, store previous mcts first
96
  session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict['Gomoku Bot']
97
  return
98
  else:
@@ -100,20 +110,22 @@ def handle_oppo_model_selection():
100
  new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
101
  new_mct.mcts._root = deepcopy(TreeNode)
102
  session_state.ROOM.MCTS = new_mct
103
- session_state.ROOM.last_mcts = new_mct
104
  return
105
 
 
106
  def handle_aid_model_selection():
107
  if st.session_state['selected_aid_model'] == 'None':
108
  session_state.USE_AIAID = False
109
  return
110
  session_state.USE_AIAID = True
111
- TreeNode = session_state.ROOM.MCTS.mcts._root # use the same tree node
112
  new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_aid_model']]
113
  new_mct.mcts._root = deepcopy(TreeNode)
114
  session_state.ROOM.AID_MCTS = new_mct
115
  return
116
 
 
117
  if 'selected_oppo_model' not in st.session_state:
118
  st.session_state['selected_oppo_model'] = 'AlphaZero'  # default value
119
 
@@ -125,7 +137,9 @@ TITLE = st.empty()
125
  Model_Switch = st.empty()
126
 
127
  TITLE.header("🤖 AI 3603 Gomoku")
128
- selected_oppo_option = Model_Switch.selectbox('Select Opponent Model', ['Pure MCTS', 'AlphaZero','Gomoku Bot'], index=1, key='oppo_model')
 
 
129
 
130
  if st.session_state['selected_oppo_model'] != selected_oppo_option:
131
  st.session_state['selected_oppo_model'] = selected_oppo_option
@@ -149,9 +163,11 @@ MULTIPLAYER_TAG = st.sidebar.empty()
149
  with st.sidebar.container():
150
  ANOTHER_ROUND = st.empty()
151
  RESTART = st.empty()
 
152
  AIAID = st.empty()
153
  EXIT = st.empty()
154
- selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0, key='aid_model')
 
155
  if st.session_state['selected_aid_model'] != selected_aid_option:
156
  st.session_state['selected_aid_model'] = selected_aid_option
157
  handle_aid_model_selection()
@@ -174,7 +190,6 @@ GAME_INFO.markdown(
174
  )
175
 
176
 
177
-
178
  def restart() -> None:
179
  """
180
  Restart the game.
@@ -182,12 +197,56 @@ def restart() -> None:
182
  session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
183
  st.session_state['selected_oppo_model'] = 'AlphaZero'
184
 
 
 
 
185
  RESTART.button(
186
  "Reset",
187
  on_click=restart,
188
  help="Clear the board as well as the scores",
189
  )
190
 
 
 
 
 
 
 
191
 
192
  # Draw the board
193
  def gomoku():
@@ -207,13 +266,24 @@ def gomoku():
207
  session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
208
  session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
209
  session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
210
- 'AlphaZero': alphazero(PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, 'Gomoku_MCTS/checkpoints/best_policy_8_8_5_2torch.pth').policy_value_fn, c_puct=5, n_playout=100),
211
- 'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
 
 
 
 
 
 
 
 
 
212
  session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
213
  session_state.ROOM.last_mcts = session_state.ROOM.MCTS
214
  session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
215
  session_state.ROOM.TURN = session_state.ROOM.PLAYER
216
  session_state.ROOM.WINNER = _BLANK # 0
 
 
217
  session_state.ROOM.COORDINATE_1D = [BOS_TOKEN_ID]
218
 
219
  # Room status sync
@@ -310,7 +380,8 @@ def gomoku():
310
  session_state.ROOM.current_move = move
311
  session_state.ROOM.BOARD.do_move(move)
312
  # Gomoku Bot BOARD
313
- session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(move // _BOARD_SIZE, move % _BOARD_SIZE)
 
314
  session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
315
  session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
316
 
@@ -356,7 +427,7 @@ def gomoku():
356
  _PLAYER_SYMBOL[_NEW],
357
  key=f"{i}:{j}",
358
  args=(i, j),
359
- on_click=handle_click,
360
  )
361
  else:
362
  # disable click for GPT choices
@@ -424,7 +495,7 @@ def gomoku():
424
  move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
425
  else:
426
  move, simul_time = session_state.ROOM.MCTS.get_action(return_time=True)
427
- session_state.ROOM.simula_time_list.append(simul_time)
428
  print("AI takes move: ", move)
429
  session_state.ROOM.current_move = move
430
  gpt_response = move
@@ -436,7 +507,8 @@ def gomoku():
436
  # MCTS BOARD
437
  session_state.ROOM.BOARD.do_move(move)
438
  # Gomoku Bot BOARD
439
- session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(move // _BOARD_SIZE, move % _BOARD_SIZE)
 
440
  # session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
441
  session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
442
 
@@ -475,7 +547,8 @@ def gomoku():
475
  on_click=forbid_click
476
  )
477
  else:
478
- if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not session_state.ROOM.BOARD.game_end()[0]:
 
479
  # enable click for other cells available for human choices
480
  prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
481
  BOARD_PLATE[i][j].button(
@@ -493,7 +566,6 @@ def gomoku():
493
  args=(i, j),
494
  )
495
 
496
-
497
  message.markdown(
498
  'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
499
  simul_time),
@@ -533,6 +605,7 @@ def gomoku():
533
  else:
534
  draw_board(True)
535
  if session_state.ROOM.WINNER != _BLANK or 0 not in session_state.ROOM.BOARD.board_map:
 
536
  ANOTHER_ROUND.button(
537
  "Play Next round!",
538
  on_click=another_round,
@@ -560,13 +633,12 @@ def gomoku():
560
  # draw the plot for simulation time
561
  # create a DataFrame
562
 
563
- # print(session_state.ROOM.simula_time_list)
564
  st.markdown("<br>", unsafe_allow_html=True)
565
  st.markdown("<br>", unsafe_allow_html=True)
566
- chart_data = pd.DataFrame(session_state.ROOM.simula_time_list, columns=["Simulation Time"])
567
  st.line_chart(chart_data)
568
 
569
-
570
  game_control()
571
  update_info()
572
 
 
15
  import streamlit as st
16
  from scipy.signal import convolve # this is used to check if any player wins
17
  from streamlit import session_state
18
+ from streamlit.delta_generator import DeltaGenerator
19
  from streamlit_server_state import server_state, server_state_lock
20
+ from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet_old, PolicyValueNet_new, duel_PolicyValueNet, \
21
+ Gumbel_MCTSPlayer
22
  from Gomoku_Bot import Gomoku_bot
23
  from Gomoku_Bot import Board as Gomoku_bot_board
 
 
24
 
25
+ import matplotlib.pyplot as plt
26
 
27
  from const import (
28
  _BLACK, # 1, for human
 
38
  _DIAGONAL_UP_RIGHT,
39
  _BOARD_SIZE,
40
  _BOARD_SIZE_1D,
41
+ _AI_AID_INFO,
42
+ _MODEL_PATH
43
  )
44
 
 
45
  from ai import (
46
  BOS_TOKEN_ID,
47
  generate_gpt2,
 
64
  self.TIME = time.time()
65
  self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
66
  self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
67
+ 'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
68
+ _MODEL_PATH["AlphaZero"]).policy_value_fn,
69
+ c_puct=5, n_playout=100),
70
+ 'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
71
+ _MODEL_PATH["duel"]).policy_value_fn,
72
+ c_puct=5, n_playout=100),
73
+ 'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
74
+ _MODEL_PATH["Gumbel AlphaZero"]).policy_value_fn,
75
+ c_puct=5, n_playout=100, m_action=8),
76
  'Gomoku Bot': Gomoku_bot(self.gomoku_bot_board, -1)}
77
  self.MCTS = self.MCTS_dict['AlphaZero']
78
  self.last_mcts = self.MCTS
79
  self.AID_MCTS = self.MCTS_dict['AlphaZero']
80
  self.COORDINATE_1D = [BOS_TOKEN_ID]
81
  self.current_move = -1
82
+ self.ai_simula_time_list = []
83
+ self.human_simula_time_list = []
84
 
85
 
86
  def change_turn(cur):
 
100
  with server_state_lock["ROOMS"]:
101
  server_state.ROOMS = {}
102
 
103
+
104
  def handle_oppo_model_selection():
105
  if st.session_state['selected_oppo_model'] == 'Gomoku Bot':
 
106
  session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict['Gomoku Bot']
107
  return
108
  else:
 
110
  new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
111
  new_mct.mcts._root = deepcopy(TreeNode)
112
  session_state.ROOM.MCTS = new_mct
113
+ session_state.ROOM.last_mcts = new_mct
114
  return
115
 
116
+
117
  def handle_aid_model_selection():
118
  if st.session_state['selected_aid_model'] == 'None':
119
  session_state.USE_AIAID = False
120
  return
121
  session_state.USE_AIAID = True
122
+ TreeNode = session_state.ROOM.MCTS.mcts._root # use the same tree node
123
  new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_aid_model']]
124
  new_mct.mcts._root = deepcopy(TreeNode)
125
  session_state.ROOM.AID_MCTS = new_mct
126
  return
127
 
128
+
129
  if 'selected_oppo_model' not in st.session_state:
130
  st.session_state['selected_oppo_model'] = 'AlphaZero' # 默认值
131
 
 
137
  Model_Switch = st.empty()
138
 
139
  TITLE.header("🤖 AI 3603 Gomoku")
140
+ selected_oppo_option = Model_Switch.selectbox('Select Opponent Model',
141
+ ['Pure MCTS', 'AlphaZero', 'Gomoku Bot', 'duel', 'Gumbel AlphaZero'],
142
+ index=1, key='oppo_model')
143
 
144
  if st.session_state['selected_oppo_model'] != selected_oppo_option:
145
  st.session_state['selected_oppo_model'] = selected_oppo_option
 
163
  with st.sidebar.container():
164
  ANOTHER_ROUND = st.empty()
165
  RESTART = st.empty()
166
+ GIVEIN = st.empty()
167
  AIAID = st.empty()
168
  EXIT = st.empty()
169
+ selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0,
170
+ key='aid_model')
171
  if st.session_state['selected_aid_model'] != selected_aid_option:
172
  st.session_state['selected_aid_model'] = selected_aid_option
173
  handle_aid_model_selection()
 
190
  )
191
 
192
 
 
193
  def restart() -> None:
194
  """
195
  Restart the game.
 
197
  session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
198
  st.session_state['selected_oppo_model'] = 'AlphaZero'
199
 
200
+ def givein() -> None:
201
+ """
202
+ Give in to AI.
203
+ """
204
+ session_state.ROOM = deepcopy(session_state.ROOM)
205
+ session_state.ROOM.WINNER = _WHITE
206
+ # add 1 score to AI
207
+ session_state.ROOM.HISTORY = (
208
+ session_state.ROOM.HISTORY[0]
209
+ + int(session_state.ROOM.WINNER == _WHITE),
210
+ session_state.ROOM.HISTORY[1]
211
+ + int(session_state.ROOM.WINNER == _BLACK),
212
+ )
213
+ session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
214
+ session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
215
+ session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
216
+ 'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
217
+ _MODEL_PATH["AlphaZero"]).policy_value_fn,
218
+ c_puct=5, n_playout=100),
219
+ 'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
220
+ _MODEL_PATH["duel"]).policy_value_fn,
221
+ c_puct=5, n_playout=100),
222
+ 'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
223
+ _MODEL_PATH[
224
+ "Gumbel AlphaZero"]).policy_value_fn,
225
+ c_puct=5, n_playout=100, m_action=8),
226
+ 'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
227
+
228
+ session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
229
+ session_state.ROOM.last_mcts = session_state.ROOM.MCTS
230
+ session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
231
+ session_state.ROOM.TURN = session_state.ROOM.PLAYER
232
+ session_state.ROOM.WINNER = _BLANK # 0
233
+ session_state.ROOM.ai_simula_time_list = []
234
+ session_state.ROOM.human_simula_time_list = []
235
+ session_state.ROOM.COORDINATE_1D = [BOS_TOKEN_ID]
236
+
237
+
238
  RESTART.button(
239
  "Reset",
240
  on_click=restart,
241
  help="Clear the board as well as the scores",
242
  )
243
 
244
+ GIVEIN.button(
245
+ "Give in",
246
+ on_click = givein,
247
+ help="Give in to AI",
248
+ )
249
+
250
 
251
  # Draw the board
252
  def gomoku():
 
266
  session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
267
  session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
268
  session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
269
+ 'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
270
+ _MODEL_PATH["AlphaZero"]).policy_value_fn,
271
+ c_puct=5, n_playout=100),
272
+ 'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
273
+ _MODEL_PATH["duel"]).policy_value_fn,
274
+ c_puct=5, n_playout=100),
275
+ 'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
276
+ _MODEL_PATH["Gumbel AlphaZero"]).policy_value_fn,
277
+ c_puct=5, n_playout=100, m_action=8),
278
+ 'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
279
+
280
  session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
281
  session_state.ROOM.last_mcts = session_state.ROOM.MCTS
282
  session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
283
  session_state.ROOM.TURN = session_state.ROOM.PLAYER
284
  session_state.ROOM.WINNER = _BLANK # 0
285
+ session_state.ROOM.ai_simula_time_list = []
286
+ session_state.ROOM.human_simula_time_list = []
287
  session_state.ROOM.COORDINATE_1D = [BOS_TOKEN_ID]
288
 
289
  # Room status sync
 
380
  session_state.ROOM.current_move = move
381
  session_state.ROOM.BOARD.do_move(move)
382
  # Gomoku Bot BOARD
383
+ session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(_BOARD_SIZE - move // _BOARD_SIZE - 1,
384
+ move % _BOARD_SIZE)  # this move uses the top-left corner as (0, 0), while the game's coordinates use the bottom-left corner as (0, 0)
385
  session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
386
  session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
387
 
 
427
  _PLAYER_SYMBOL[_NEW],
428
  key=f"{i}:{j}",
429
  args=(i, j),
430
+ on_click=forbid_click,
431
  )
432
  else:
433
  # disable click for GPT choices
 
495
  move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
496
  else:
497
  move, simul_time = session_state.ROOM.MCTS.get_action(return_time=True)
498
+ session_state.ROOM.ai_simula_time_list.append(simul_time)
499
  print("AI takes move: ", move)
500
  session_state.ROOM.current_move = move
501
  gpt_response = move
 
507
  # MCTS BOARD
508
  session_state.ROOM.BOARD.do_move(move)
509
  # Gomoku Bot BOARD
510
+ session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(_BOARD_SIZE - 1 - move // _BOARD_SIZE,
511
+ move % _BOARD_SIZE)
512
  # session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
513
  session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
514
 
 
547
  on_click=forbid_click
548
  )
549
  else:
550
+ if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not \
551
+ session_state.ROOM.BOARD.game_end()[0]:
552
  # enable click for other cells available for human choices
553
  prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
554
  BOARD_PLATE[i][j].button(
 
566
  args=(i, j),
567
  )
568
 
 
569
  message.markdown(
570
  'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
571
  simul_time),
 
605
  else:
606
  draw_board(True)
607
  if session_state.ROOM.WINNER != _BLANK or 0 not in session_state.ROOM.BOARD.board_map:
608
+ GIVEIN.empty()
609
  ANOTHER_ROUND.button(
610
  "Play Next round!",
611
  on_click=another_round,
 
633
  # draw the plot for simulation time
634
  # create a DataFrame
635
 
636
+ # print(session_state.ROOM.ai_simula_time_list)
637
  st.markdown("<br>", unsafe_allow_html=True)
638
  st.markdown("<br>", unsafe_allow_html=True)
639
+ chart_data = pd.DataFrame(session_state.ROOM.ai_simula_time_list, columns=["Simulation Time"])
640
  st.line_chart(chart_data)
641
 
 
642
  game_control()
643
  update_info()
644