- Gomoku_Bot/eval.py +0 -39
- Gomoku_Bot/gomoku_bot.py +3 -2
- Gomoku_MCTS/__init__.py +4 -2
- Gomoku_MCTS/checkpoint/2023-12-14-18-16-09_test_teaching_learning_collect_epochs=1000_size=9_model=duel/best_policy.model +3 -0
- Gomoku_MCTS/checkpoint/2023-12-14-18-16-09_test_teaching_learning_collect_epochs=1000_size=9_model=duel/current_policy.model +3 -0
- Gomoku_MCTS/checkpoint/2023-12-14-18-17-07_test_teaching_learning_collect_epochs=1000_size=9_model=normal/best_policy.model +3 -0
- Gomoku_MCTS/checkpoint/2023-12-14-18-17-07_test_teaching_learning_collect_epochs=1000_size=9_model=normal/current_policy.model +3 -0
- Gomoku_MCTS/checkpoint/2023-12-14-21-19-40_selfplay_epochs=1000_size=9_model=gumbel/best_policy.model +3 -0
- Gomoku_MCTS/checkpoint/2023-12-14-21-19-40_selfplay_epochs=1000_size=9_model=gumbel/current_policy.model +3 -0
- Gomoku_MCTS/config/utils.py +12 -12
- Gomoku_MCTS/dueling_net.py +13 -10
- Gomoku_MCTS/mcts_Gumbel_Alphazero.py +96 -103
- Gomoku_MCTS/policy_value_net_pytorch_new.py +234 -0
- const.py +8 -1
- pages/Player_VS_AI.py +96 -24
Gomoku_Bot/eval.py
CHANGED
@@ -460,11 +460,9 @@ class Evaluate:
         for point in fives:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(FIVE, model_train_matrix[x][y])
         for point in block_fives:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(BLOCK_FIVE, model_train_matrix[x][y])

         return set(list(fives) + list(block_fives)), model_train_matrix

@@ -474,12 +472,10 @@ class Evaluate:
         for point in fours:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(FOUR, model_train_matrix[x][y])

         for point in block_fours:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])

         return set(list(fours) + list(block_fours)), model_train_matrix

@@ -488,12 +484,10 @@ class Evaluate:
         for point in four_fours:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(FOUR_FOUR, model_train_matrix[x][y])

         for point in block_fours:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])

         return set(list(four_fours) + list(block_fours)), model_train_matrix

@@ -504,17 +498,14 @@ class Evaluate:
         for point in four_threes:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(FOUR_THREE, model_train_matrix[x][y])

         for point in block_fours:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])

         for point in threes:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])

         return set(list(four_threes) + list(block_fours) + list(threes)), model_train_matrix

@@ -524,17 +515,14 @@ class Evaluate:
         for point in three_threes:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(THREE_THREE, model_train_matrix[x][y])

         for point in block_fours:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])

         for point in threes:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])

         return set(list(three_threes) + list(block_fours) + list(threes)), model_train_matrix

@@ -542,43 +530,16 @@ class Evaluate:
         for point in threes:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])

         for point in block_fours:
             x = point // self.size
             y = point % self.size
-            model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
             return set(list(block_fours) + list(threes)), model_train_matrix

         block_threes = points[shapes['BLOCK_THREE']]
         two_twos = points[shapes['TWO_TWO']]
         twos = points[shapes['TWO']]

-        for point in block_threes:
-            x = point // self.size
-            y = point % self.size
-            model_train_matrix[x][y] = max(BLOCK_THREE, model_train_matrix[x][y])
-
-        for point in two_twos:
-            x = point // self.size
-            y = point % self.size
-            model_train_matrix[x][y] = max(TWO_TWO, model_train_matrix[x][y])
-
-        for point in twos:
-            x = point // self.size
-            y = point % self.size
-            model_train_matrix[x][y] = max(TWO, model_train_matrix[x][y])
-
-        for point in block_fours:
-            x = point // self.size
-            y = point % self.size
-            model_train_matrix[x][y] = max(BLOCK_FOUR, model_train_matrix[x][y])
-
-        for point in threes:
-            x = point // self.size
-            y = point % self.size
-            model_train_matrix[x][y] = max(THREE, model_train_matrix[x][y])
-
         mid = list(block_fours) + list(threes) + list(block_threes) + list(two_twos) + list(twos)
         res = set(mid[:5])
         for i in range(len(model_train_matrix)):
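Throughout these loops a board point is stored as a single flat integer and decoded with x = point // self.size, y = point % self.size; after this change only that decoding and the returned point sets remain. A tiny standalone illustration of the index scheme (the size and sample point below are arbitrary, not taken from the commit):

# Flat-index <-> (row, col) decoding used by Evaluate above; values are illustrative.
size = 15
point = 47
x, y = point // size, point % size      # decode: row 3, column 2
assert point == x * size + y            # encoding round-trips
print(x, y)                             # -> 3 2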
Gomoku_Bot/gomoku_bot.py
CHANGED
@@ -3,7 +3,7 @@ import time


 class Gomoku_bot:
-    def __init__(self, board, role, depth=4, enableVCT=
+    def __init__(self, board, role, depth=4, enableVCT=False):
         self.board = board
         self.role = role
         self.depth = depth
@@ -14,7 +14,8 @@
         score = minmax(self.board, self.role, self.depth, self.enableVCT)
         end = time.time()
         sim_time = end - start
-        move = score[1]
+        move = score[1]  # this move is indexed from the top-left corner (0, 0), while moves in the game are indexed from the bottom-left corner (0, 0)
+        move = (self.board.size - 1 - move[0], move[1])  # convert the move to the game's coordinate system
         # turn tuple into an int
         move = move[0] * self.board.size + move[1]
         if return_time:
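The two added lines flip the row index (bot coordinates use a top-left origin, the game uses a bottom-left origin) before flattening the tuple to a single integer. A small sketch of just that conversion (board size and sample move are illustrative values, not from the commit):

# Row-flip + flatten, mirroring the change above.
size = 9
move = (2, 5)                          # bot coordinates, origin at the top-left
move = (size - 1 - move[0], move[1])   # game coordinates, origin at the bottom-left -> (6, 5)
flat = move[0] * size + move[1]        # single integer used by the game board -> 59
print(move, flat)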
Gomoku_MCTS/__init__.py
CHANGED
@@ -1,7 +1,9 @@
 from .mcts_pure import MCTSPlayer as MCTSpure
 from .mcts_alphaZero import MCTSPlayer as alphazero
-
-from .
+from .policy_value_net_pytorch import PolicyValueNet as PolicyValueNet_old
+from .policy_value_net_pytorch_new import PolicyValueNet as PolicyValueNet_new
+from .dueling_net import PolicyValueNet as duel_PolicyValueNet
+from .mcts_Gumbel_Alphazero import Gumbel_MCTSPlayer
 import numpy as np

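For context, these re-exports are what the UI code later in this commit consumes; a minimal usage sketch (the constructor arguments mirror pages/Player_VS_AI.py below and are otherwise assumptions, not committed code):

# Assumed usage of the re-exports above.
from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet_new, Gumbel_MCTSPlayer

board = Board(width=9, height=9, n_in_row=5)                       # game-state container
net = PolicyValueNet_new(9, 9, model_file=None)                    # untrained net; pass a checkpoint path to load weights
pure_player = MCTSpure(c_puct=5, n_playout=1000)                   # baseline pure MCTS player
az_player = alphazero(net.policy_value_fn, c_puct=5, n_playout=100)
gumbel_player = Gumbel_MCTSPlayer(net.policy_value_fn, c_puct=5, n_playout=100, m_action=8)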
Gomoku_MCTS/checkpoint/2023-12-14-18-16-09_test_teaching_learning_collect_epochs=1000_size=9_model=duel/best_policy.model
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6278cd8f69e66f42a927df96e1fe3952a6e1e5a41e37f99f315b9bc3febd6d7a
+size 529974

Gomoku_MCTS/checkpoint/2023-12-14-18-16-09_test_teaching_learning_collect_epochs=1000_size=9_model=duel/current_policy.model
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab24465e4c52e038fda2bafae71550cb89c3f42f59d2ebf12b7a45c2c353eb33
+size 530034

Gomoku_MCTS/checkpoint/2023-12-14-18-17-07_test_teaching_learning_collect_epochs=1000_size=9_model=normal/best_policy.model
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d43bc56e0bc86c5548d7857b6777b1971d2d14da6f344280b6f81be8595ac710
+size 555837

Gomoku_MCTS/checkpoint/2023-12-14-18-17-07_test_teaching_learning_collect_epochs=1000_size=9_model=normal/current_policy.model
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ab8cee4afa72bf29d73d0707ad1b92e1d0a24a009721c110c99b2ff5d2f866f
+size 556110

Gomoku_MCTS/checkpoint/2023-12-14-21-19-40_selfplay_epochs=1000_size=9_model=gumbel/best_policy.model
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8a225a330990d289278d8fa2cbd8bda0a7ea541c3ffd7aac6327d4553ef8683
+size 555837

Gomoku_MCTS/checkpoint/2023-12-14-21-19-40_selfplay_epochs=1000_size=9_model=gumbel/current_policy.model
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0dfdc692e55ba43f6ae714da93ee6182b9adcf8824c292bb597ef0d003c6d10b
+size 556110
Gomoku_MCTS/config/utils.py
CHANGED
@@ -1,7 +1,7 @@
 import os, shutil
 import torch
-from tensorboardX import SummaryWriter
-from
+# from tensorboardX import SummaryWriter
+from .options import *
 import torch.distributed as dist
 import time

@@ -42,13 +42,13 @@ def makedir(path):
     os.makedirs(path, 0o777)


-def visualizer():
-    if get_rank() == 0:
-        # filewriter_path = config['visual_base']+opts.savepath+'/'
-        save_path = make_path()
-        filewriter_path = os.path.join(config['visual_base'], save_path)
-        if opts.clear_visualizer and os.path.exists(filewriter_path):  # delete the old summaries so they do not overlap
-            shutil.rmtree(filewriter_path)
-        makedir(filewriter_path)
-        writer = SummaryWriter(filewriter_path, comment='visualizer')
-        return writer
+# def visualizer():
+#     if get_rank() == 0:
+#         # filewriter_path = config['visual_base']+opts.savepath+'/'
+#         save_path = make_path()
+#         filewriter_path = os.path.join(config['visual_base'], save_path)
+#         if opts.clear_visualizer and os.path.exists(filewriter_path):  # delete the old summaries so they do not overlap
+#             shutil.rmtree(filewriter_path)
+#         makedir(filewriter_path)
+#         writer = SummaryWriter(filewriter_path, comment='visualizer')
+#         return writer
Gomoku_MCTS/dueling_net.py
CHANGED
@@ -52,16 +52,16 @@ class DuelingDQNNet(nn.Module):
         return F.log_softmax(q_values, dim=1), val

 class PolicyValueNet():
-    """
+    """policy-value network """
     def __init__(self, board_width, board_height,
-                 model_file=None, use_gpu=False):
+                 model_file=None, use_gpu=False, device=None):
         self.use_gpu = use_gpu
         self.board_width = board_width
         self.board_height = board_height
         self.l2_const = 1e-4  # coef of l2 penalty
         # the policy value net module
         if self.use_gpu:
-            self.policy_value_net = DuelingDQNNet(board_width, board_height).
+            self.policy_value_net = DuelingDQNNet(board_width, board_height).to(device)
         else:
             self.policy_value_net = DuelingDQNNet(board_width, board_height)
         self.optimizer = optim.Adam(self.policy_value_net.parameters(),
@@ -70,7 +70,6 @@ class PolicyValueNet():
         if model_file:
             net_params = torch.load(model_file)
             self.policy_value_net.load_state_dict(net_params, strict=False)
-            print('loaded dueling model file')

     def policy_value(self, state_batch):
         """
@@ -78,7 +77,7 @@ class PolicyValueNet():
         output: a batch of action probabilities and state values
         """
         if self.use_gpu:
-            state_batch = Variable(torch.FloatTensor(state_batch).
+            state_batch = Variable(torch.FloatTensor(state_batch).to(device))
             log_act_probs, value = self.policy_value_net(state_batch)
             act_probs = np.exp(log_act_probs.data.cpu().numpy())
             return act_probs, value.data.cpu().numpy()
@@ -97,16 +96,20 @@ class PolicyValueNet():
         legal_positions = board.availables
         current_state = np.ascontiguousarray(board.current_state().reshape(
                 -1, 4, self.board_width, self.board_height))
+
         if self.use_gpu:
             log_act_probs, value = self.policy_value_net(
-                    Variable(torch.from_numpy(current_state)).
+                    Variable(torch.from_numpy(current_state)).to(device).float())
             act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
         else:
             log_act_probs, value = self.policy_value_net(
                     Variable(torch.from_numpy(current_state)).float())
             act_probs = np.exp(log_act_probs.data.numpy().flatten())
+
         act_probs = zip(legal_positions, act_probs[legal_positions])
-
+
+
+
         return act_probs, value

     def train_step(self, state_batch, mcts_probs, winner_batch, lr):
@@ -115,9 +118,9 @@ class PolicyValueNet():
         # self.use_gpu = True
         # wrap in Variable
         if self.use_gpu:
-            state_batch = Variable(torch.FloatTensor(state_batch).
-            mcts_probs = Variable(torch.FloatTensor(mcts_probs).
-            winner_batch = Variable(torch.FloatTensor(winner_batch).
+            state_batch = Variable(torch.FloatTensor(state_batch).to(device))
+            mcts_probs = Variable(torch.FloatTensor(mcts_probs).to(device))
+            winner_batch = Variable(torch.FloatTensor(winner_batch).to(device))
         else:
             state_batch = Variable(torch.FloatTensor(state_batch))
             mcts_probs = Variable(torch.FloatTensor(mcts_probs))
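One caveat when reusing this pattern: `device` is only a constructor argument here, while `.to(device)` is also used inside `policy_value`, `policy_value_fn`, and `train_step`. The new policy_value_net_pytorch_new.py later in this commit stores the device on the instance instead; a minimal sketch of that approach (illustrative only, not the committed code):

import torch

class PolicyValueNetSketch:
    """Minimal sketch: keep the torch.device on the instance so every method can reach it."""

    def __init__(self, board_width, board_height, use_gpu=False, device=None):
        # fall back to CPU when no device is given
        if device is None:
            device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
        self.device = device
        self.board_width = board_width
        self.board_height = board_height

    def to_tensor(self, batch):
        # every tensor goes through self.device, so no free 'device' name is needed
        return torch.as_tensor(batch, dtype=torch.float32, device=self.device)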
Gomoku_MCTS/mcts_Gumbel_Alphazero.py
CHANGED
@@ -1,5 +1,5 @@
 """
-FileName:
+FileName: mcts_Gumbel_Alphazero.py
 Author: Jiaxin Li
 Create Date: 2023/11/21
 Description: The implement of Gumbel MCST
@@ -9,11 +9,11 @@ Debug: the dim of output: probs

 import numpy as np
 import copy
-import time
+import time

-from config.options import *
+from .config.options import *
 import sys
-from config.utils import *
+from .config.utils import *


 def softmax(x):
@@ -22,8 +22,8 @@ def softmax(x):
     return probs


-def _sigma_mano(y
-    return (50 + Nb) * 1.0 * y
+def _sigma_mano(y, Nb):
+    return (50 + Nb) * 1.0 * y


 class TreeNode(object):
@@ -42,8 +42,6 @@ class TreeNode(object):
         self._v = 0
         self._p = prior_p

-
-
     def expand(self, action_priors):
         """Expand tree by creating new children.
        action_priors: a list of tuples of actions and their prior probability
@@ -52,7 +50,6 @@ class TreeNode(object):
         for action, prob in action_priors:
             if action not in self._children:
                 self._children[action] = TreeNode(self, prob)
-

     def select(self, v_pi):
         """Select action among children that gives maximum
@@ -62,29 +59,25 @@ class TreeNode(object):
         # if opts.split == "train":
         #     v_pi = v_pi.detach().numpy()
         # print(v_pi)
-
-
-

-        max_N_b = np.max(np.array(
+        max_N_b = np.max(np.array([act_node[1]._n_visits for act_node in self._children.items()]))

         if opts.split == "train":
-            pi_ = softmax(
+            pi_ = softmax(np.array([act_node[1].get_pi(v_pi, max_N_b) for act_node in self._children.items()])).reshape(
+                len(list(self._children.items())), -1)
         else:
-            pi_ = softmax(
+            pi_ = softmax(np.array([act_node[1].get_pi(v_pi, max_N_b) for act_node in self._children.items()])).reshape(
+                len(list(self._children.items())), -1)
         # print(pi_.shape)
-

-        N_a = np.array(
+        N_a = np.array([act_node[1]._n_visits / (1 + self._n_visits) for act_node in self._children.items()]).reshape(
+            pi_.shape[0], -1)
         # print(N_a.shape)

-        max_index=
+        max_index = np.argmax(pi_ - N_a)
         # print((pi_ - N_a).shape)
-
-
-
-        return list(self._children.items())[max_index]
+
+        return list(self._children.items())[max_index]

     def update(self, leaf_value):
         """Update node values from leaf evaluation.
@@ -95,11 +88,9 @@ class TreeNode(object):
         self._n_visits += 1
         # Update Q, a running average of values for all visits.
         if opts.split == "train":
-            self._Q = self._Q +
-
-
-        else:
-            self._Q += (1.0*(leaf_value - self._Q) / self._n_visits)
+            self._Q = self._Q + (1.0 * (leaf_value - self._Q) / self._n_visits)
+        else:
+            self._Q += (1.0 * (leaf_value - self._Q) / self._n_visits)

     def update_recursive(self, leaf_value):
         """Like a call to update(), but applied recursively for all ancestors.
@@ -109,14 +100,13 @@ class TreeNode(object):
             self._parent.update_recursive(-leaf_value)
         self.update(leaf_value)

-    def get_pi(self,v_pi,max_N_b):
+    def get_pi(self, v_pi, max_N_b):
         if self._n_visits == 0:
             Q_completed = v_pi
         else:
             Q_completed = self._Q
-
-        return self._p + _sigma_mano(Q_completed,max_N_b)
+
+        return self._p + _sigma_mano(Q_completed, max_N_b)

     def get_value(self, c_puct):
         """Calculate and return the value for this node.
@@ -155,9 +145,6 @@ class Gumbel_MCTS(object):
         self._c_puct = c_puct
         self._n_playout = n_playout

-
-
-
     def Gumbel_playout(self, child_node, child_state):
         """Run a single playout from the child of the root to the leaf, getting a value at
         the leaf and propagating it back through its parents.
@@ -166,28 +153,26 @@ class Gumbel_MCTS(object):
         """
         node = child_node
         state = child_state
-
-        while(1):
+
+        while (1):
             if node.is_leaf():
                 break
             # Greedily select next move.

             action, node = node.select(node._v)
-
-            state.do_move(action)
-

+            state.do_move(action)

         # Evaluate the leaf using a network which outputs a list of
         # (action, probability) tuples p and also a score v in [-1, 1]
         # for the current player.
+
         action_probs, leaf_value = self._policy(state)
-
-        leaf_value = leaf_value.detach().numpy()[0][0]

-
-
+        # leaf_value = leaf_value.detach().numpy()[0][0]
+        leaf_value = leaf_value.detach().numpy()

+        node._v = leaf_value

         # Check for end of game.
         end, winner = state.game_end()
@@ -204,24 +189,19 @@ class Gumbel_MCTS(object):

         # Update value and visit count of nodes in this traversal.
         node.update_recursive(-leaf_value)
-

-    def top_k(self,x, k):
+    def top_k(self, x, k):
         # print("x",x.shape)
         # print("k ", k)

         return np.argpartition(x, k)[..., -k:]

-    def sample_k(self,logits, k):
+    def sample_k(self, logits, k):
         u = np.random.uniform(size=np.shape(logits))
         z = -np.log(-np.log(u))
-
-
-        return self.top_k(logits + z, k),z
-
-
-    def get_move_probs(self, state, temp=1e-3,m_action = 16):
+        return self.top_k(logits + z, k), z

+    def get_move_probs(self, state, temp=1e-3, m_action=16):
         """Run all playouts sequentially and return the available actions and
         their corresponding probabilities.
         state: the current game state
@@ -231,92 +211,102 @@ class Gumbel_MCTS(object):
         # logits are tentatively taken to be p

         start_time = time.time()
-
-
         # expand the root node
+
         act_probs, leaf_value = self._policy(state)
-        act_probs = list(act_probs)

-
-
+        act_probs = list(act_probs)
+
+        # leaf_value = leaf_value.detach().numpy()[0][0]
+        leaf_value = leaf_value.detach().numpy()
         # print(list(act_probs))
-        porbs = [prob
-        self._root.expand(act_probs)
+        porbs = [prob for act, prob in (act_probs)]
+
+        self._root.expand(act_probs)

         n = self._n_playout
-        m = min(
-
+        m = min(m_action, int(len(porbs) / 2))

         # first draw Gumbel samples: pick the top-m actions without replacement, following the selection rule logits + g
-        A_topm
-
+        A_topm, g = self.sample_k(porbs, m)
+
         # for each sampled action, store the state reached after playing it in a list
         root_childs = list(self._root._children.items())
-

         child_state_m = []
         for i in range(m):
             state_copy = copy.deepcopy(state)
-            action,node = root_childs[A_topm[i]]
+            action, node = root_childs[A_topm[i]]
             state_copy.do_move(action)
             child_state_m.append(state_copy)
-
-
-
+        print(porbs)
+
+        print("depend on:", np.array(porbs)[A_topm])
+        print(f"A_topm_{m}", A_topm)

+        print("m ", m)
+
+        if m > 1:
+            # number of simulations per selected action in each round
+            N = int(n / (np.log(m) * m))
+        else:
+            N = n

         # run sequential halving with Gumbel
         while m >= 1:
-
+
             # simulate each selected action
             for i in range(m):
                 action_state = child_state_m[i]
-
-                action,node = root_childs[A_topm[i]]
-
+
+                action, node = root_childs[A_topm[i]]
+
                 for j in range(N):
                     action_state_copy = copy.deepcopy(action_state)
-
+
                     # simulate the chosen action: descend to a leaf of this subtree, predict v with the network, then back the value up
                     self.Gumbel_playout(node, action_state_copy)

             # halve the number of sampled actions each round
-            m = m //2
+            m = m // 2

             # if this is not the last round, double the per-action simulation budget
-            if(m != 1):
+            if (m != 1):
                 n = n - N
                 N *= 2
             # in the last round only one action remains, so spend all remaining simulations
             else:
                 N = n
-
+
             # resample without replacement among the better half of the previous actions, following the rule g + logits + \sigma( \hat{q} )
             # print([action_node[1]._Q for action_node in self._root._children.items() ])
-
-
-            q_hat = np.array([action_node[1]._Q for action_node in self._root._children.items() ])
-
+
+            q_hat = np.array([action_node[1]._Q for action_node in self._root._children.items()])
+            assert (np.sum(q_hat[A_topm] == 0) == 0)
+
+            print("depend on:", np.array(porbs)[A_topm] + np.array(g)[A_topm] + q_hat[A_topm])
+            print(f"A_topm_{m}", A_topm)

-            A_index = self.top_k(
+            A_index = self.top_k(np.array(porbs)[A_topm] + np.array(g)[A_topm] + q_hat[A_topm], m)
             A_topm = np.array(A_topm)[A_index]
             child_state_m = np.array(child_state_m)[A_index]
-
-
+
         # finally return the corresponding policy, i.e. pi' = softmax(logits + sigma(completed Q))

-        max_N_b = np.max(np.array(
+        max_N_b = np.max(np.array([act_node[1]._n_visits for act_node in self._root._children.items()]))

-        final_act_probs=
-
+        final_act_probs = softmax(
+            np.array([act_node[1].get_pi(leaf_value, max_N_b) for act_node in self._root._children.items()]))

+        action = (np.array([act_node[0] for act_node in self._root._children.items()]))
+        print("final_act_prbs", final_act_probs)
+        print("move :", action)
+        print("final_action", np.array(list(self._root._children.items()))[A_topm][0][0])
+        print("argmax_prob", np.argmax(final_act_probs))
         need_time = time.time() - start_time
-        print(f" Gumbel Alphazero sum_time: {need_time
+        print(f" Gumbel Alphazero sum_time: {need_time}, total_simulation: {self._n_playout}")

-        return
+        return np.array(list(self._root._children.items()))[A_topm][0][0], action, final_act_probs, need_time

     def update_with_move(self, last_move):
         """Step forward in the tree, keeping everything we already know
@@ -336,50 +326,53 @@ class Gumbel_MCTSPlayer(object):
     """AI player based on MCTS"""

     def __init__(self, policy_value_function,
-                 c_puct=5, n_playout=2000, is_selfplay=0,m_action
+                 c_puct=5, n_playout=2000, is_selfplay=0, m_action=16):
         self.mcts = Gumbel_MCTS(policy_value_function, c_puct, n_playout)
         self._is_selfplay = is_selfplay
         self.m_action = m_action

-
     def set_player_ind(self, p):
         self.player = p

     def reset_player(self):
         self.mcts.update_with_move(-1)

-
-    def get_action(self, board, temp=1e-3, return_prob=0,return_time = False):
+    def get_action(self, board, temp=1e-3, return_prob=0, return_time=False):
         sensible_moves = board.availables
         # the pi vector returned by MCTS as in the alphaGo Zero paper
-        move_probs = np.zeros(board.width*board.height)
-
-
-
+        move_probs = np.zeros(board.width * board.height)
+
         if len(sensible_moves) > 0:

             # select the action in the search tree via sequential halving with Gumbel and return the corresponding policy
-            move, acts, probs,simul_mean_time
-
-
+            move, acts, probs, simul_mean_time = self.mcts.get_move_probs(board, temp, self.m_action)

             # reset the search tree
             self.mcts.update_with_move(-1)

             move_probs[list(acts)] = probs
+            move_probs = np.zeros(move_probs.shape[0])
+            move_probs[move] = 1
+
+            print("final prob:", move_probs)
+            print("arg_max:", np.argmax(move_probs))
+            print("max", np.max(move_probs))
+            print("move", move)

+            # training drives one entry of move_probs towards 1, i.e. it yields a policy
+            # the issue is that this policy can disagree with the move returned by MCTS; likely a problem in how the policy distribution is computed

             if return_time:

                 if return_prob:
-
-                    return move, move_probs,simul_mean_time
+
+                    return move, move_probs, simul_mean_time
                 else:
-                    return move,simul_mean_time
+                    return move, simul_mean_time
             else:

                 if return_prob:
-
+
                     return move, move_probs
                 else:
                     return move
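The loop in get_move_probs implements a Sequential Halving style schedule: sample m actions, give each a per-round budget N, halve the candidate set after every round and grow the per-action budget, then spend whatever remains on the final survivor. A standalone sketch that mirrors the committed schedule so it can be inspected in isolation (the function name and yielded tuple are illustrative placeholders; np.log follows the committed code rather than log2, and no claim is made that the budgets sum exactly to n_playout):

import numpy as np

def halving_schedule(n_playout, m_actions):
    """Yield (number of surviving actions m, per-action budget N) per round, as in get_move_probs above."""
    n = n_playout
    m = m_actions
    N = int(n / (np.log(m) * m)) if m > 1 else n   # first-round per-action budget, as in the diff
    while m >= 1:
        yield m, N
        m = m // 2            # keep the better half of the candidates
        if m != 1:
            n = n - N         # account for what this round used ...
            N *= 2            # ... and double the per-action budget
        else:
            N = n             # last round: one action, spend everything left

# Example: print the schedule for 100 playouts and 8 initially sampled actions.
for m, per_action in halving_schedule(100, 8):
    print(f"round with m={m}: {per_action} simulations per action")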
Gomoku_MCTS/policy_value_net_pytorch_new.py
ADDED
@@ -0,0 +1,234 @@
+# -*- coding: utf-8 -*-
+"""
+An implementation of the policyValueNet in PyTorch
+Tested in PyTorch 0.2.0 and 0.3.0
+
+@author: Junxiao Song
+"""
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+
+
+def set_learning_rate(optimizer, lr):
+    """Sets the learning rate to the given value"""
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, channels):
+        super(ResidualBlock, self).__init__()
+        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+        self.bn1 = nn.BatchNorm2d(channels)
+        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+        self.bn2 = nn.BatchNorm2d(channels)
+
+    def forward(self, x):
+        residual = x
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += residual
+        return F.relu(out)
+
+
+class Net(nn.Module):
+    """Policy-Value network module for AlphaZero Gomoku."""
+    def __init__(self, board_width, board_height, num_residual_blocks=5):
+        super(Net, self).__init__()
+        self.board_width = board_width
+        self.board_height = board_height
+        self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
+        self.bn1 = nn.BatchNorm2d(32)
+        self.res_layers = nn.Sequential(*[ResidualBlock(32) for _ in range(num_residual_blocks)])
+
+        # Action Policy layers
+        self.act_conv1 = nn.Conv2d(32, 4, kernel_size=1)
+        self.act_fc1 = nn.Linear(4 * board_width * board_height, board_width * board_height)
+
+        # State Value layers
+        self.val_conv1 = nn.Conv2d(32, 2, kernel_size=1)
+        self.val_fc1 = nn.Linear(2 * board_width * board_height, 64)
+        self.val_fc2 = nn.Linear(64, 1)
+
+    def forward(self, x):
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = self.res_layers(x)
+
+        # Action Policy head
+        x_act = F.relu(self.act_conv1(x))
+        x_act = x_act.view(-1, 4 * self.board_width * self.board_height)
+        x_act = F.log_softmax(self.act_fc1(x_act), dim=1)
+
+        # State Value head
+        x_val = F.relu(self.val_conv1(x))
+        x_val = x_val.view(-1, 2 * self.board_width * self.board_height)
+        x_val = F.relu(self.val_fc1(x_val))
+        x_val = torch.tanh(self.val_fc2(x_val))
+
+        return x_act, x_val
+
+
+class PolicyValueNet():
+    """policy-value network """
+
+    def __init__(self, board_width, board_height,
+                 model_file=None, use_gpu=False, bias=False):
+        self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+        self.use_gpu = use_gpu
+        self.l2_const = 1e-4  # coef of l2 penalty
+        self.board_width = board_width
+        self.board_height = board_height
+        self.bias = bias
+
+        if model_file:
+            net_params = torch.load(model_file, map_location='cpu' if not use_gpu else None)
+
+            # Infer board dimensions from the loaded model
+            inferred_width, inferred_height = self.infer_board_size_from_model(net_params)
+            if inferred_width and inferred_height:
+                self.policy_value_net = Net(inferred_width, inferred_height).to(self.device) if use_gpu else Net(
+                    inferred_width, inferred_height)
+                self.policy_value_net.load_state_dict(net_params)
+                print("Use model file to initialize the policy value net")
+            else:
+                raise Exception("The model file does not contain the board dimensions")
+
+            if inferred_width < board_width:
+                self.use_conv = True
+            elif inferred_width > board_width:
+                raise Exception("The model file has a larger board size than the current board size!!")
+        else:
+            # the policy value net module
+            if self.use_gpu:
+                self.policy_value_net = Net(board_width, board_height).to(self.device)
+            else:
+                self.policy_value_net = Net(board_width, board_height)
+
+        self.optimizer = optim.Adam(self.policy_value_net.parameters(),
+                                    weight_decay=self.l2_const)
+
+    def infer_board_size_from_model(self, model):
+        # Use the size of the act_fc1 layer to infer board dimensions
+        for name in model.keys():
+            if name == 'act_fc1.weight':
+                # Assuming the weight shape is [board_width * board_height, 4 * board_width * board_height]
+                c, _ = model[name].shape
+                print(f"act_fc1.weight shape: {model[name].shape}")
+                board_size = int(c ** 0.5)  # Extracting board_width/height assuming they are the same
+                print(f"Board size inferred from model: {board_size}x{board_size}")
+                return board_size, board_size
+        return None
+
+    def apply_normal_bias(self, tensor, mean=0, std=1):
+        bsize = tensor.shape[0]
+        x, y = np.meshgrid(np.linspace(-1, 1, bsize), np.linspace(-1, 1, bsize))
+        d = np.sqrt(x * x + y * y)
+        sigma, mu = 1.0, 0.0
+        gauss = np.exp(-((d - mu) ** 2 / (2.0 * sigma ** 2)))
+        # Applying the bias only to non-zero elements
+        biased_tensor = tensor - (tensor != 0) * gauss
+        return biased_tensor
+
+    def policy_value(self, state_batch):
+        """
+        input: a batch of states
+        output: a batch of action probabilities and state values
+        """
+        if self.use_gpu:
+            state_batch = Variable(torch.FloatTensor(state_batch).to(self.device))
+            log_act_probs, value = self.policy_value_net(state_batch)
+            act_probs = np.exp(log_act_probs.data.cpu().numpy())
+            return act_probs, value.data.cpu().numpy()
+        else:
+            state_batch = Variable(torch.FloatTensor(state_batch))
+            log_act_probs, value = self.policy_value_net(state_batch)
+            act_probs = np.exp(log_act_probs.data.numpy())
+            return act_probs, value.data.numpy()
+
+    def policy_value_fn(self, board):
+        """
+        input: board
+        output: a list of (action, probability) tuples for each available
+        action and the score of the board state
+        """
+        legal_positions = board.availables
+        current_state = np.ascontiguousarray(board.current_state().reshape(
+                -1, 4, self.board_width, self.board_height))
+        if self.bias:
+            current_state[0][1] = self.apply_normal_bias(current_state[0][1])
+
+        if self.use_gpu:
+            log_act_probs, value = self.policy_value_net(
+                Variable(torch.from_numpy(current_state)).to(self.device).float())
+            act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
+        else:
+            log_act_probs, value = self.policy_value_net(
+                Variable(torch.from_numpy(current_state)).float())
+            act_probs = np.exp(log_act_probs.data.numpy().flatten())
+        act_probs = zip(legal_positions, act_probs[legal_positions])
+        value = value.data[0][0]
+        return act_probs, value
+
+    def train_step(self, state_batch, mcts_probs, winner_batch, lr):
+        """perform a training step"""
+
+        # self.use_gpu = True
+        # wrap in Variable
+        if self.use_gpu:
+            state_batch = Variable(torch.FloatTensor(state_batch).to(self.device))
+            mcts_probs = Variable(torch.FloatTensor(mcts_probs).to(self.device))
+            winner_batch = Variable(torch.FloatTensor(winner_batch).to(self.device))
+        else:
+            state_batch = Variable(torch.FloatTensor(state_batch))
+            mcts_probs = Variable(torch.FloatTensor(mcts_probs))
+            winner_batch = Variable(torch.FloatTensor(winner_batch))
+
+        # zero the parameter gradients
+        self.optimizer.zero_grad()
+        # set learning rate
+        set_learning_rate(self.optimizer, lr)
+
+        # forward
+        log_act_probs, value = self.policy_value_net(state_batch)
+        # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
+        # Note: the L2 penalty is incorporated in optimizer
+        value_loss = F.mse_loss(value.view(-1), winner_batch)
+        policy_loss = -torch.mean(torch.sum(mcts_probs * log_act_probs, 1))
+        loss = value_loss + policy_loss
+        # backward and optimize
+        loss.backward()
+        self.optimizer.step()
+        # calc policy entropy, for monitoring only
+        entropy = -torch.mean(
+            torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
+        )
+
+        # for pytorch version >= 0.5 please use the following line instead.
+        return loss.item(), entropy.item()
+
+    def get_policy_param(self):
+        net_params = self.policy_value_net.state_dict()
+        return net_params
+
+    def save_model(self, model_file):
+        """ save model params to file """
+        net_params = self.get_policy_param()  # get model params
+        torch.save(net_params, model_file)
+
+
+if __name__ == "__main__":
+    import torch
+    import torch.onnx
+
+    # assumes the Net model above has been defined
+    model = Net(board_width=9, board_height=9)  # initialize the model with appropriate arguments
+    dummy_input = torch.randn(1, 4, 9, 9)  # create an example input
+
+    # export the model to ONNX format
+    torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)
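infer_board_size_from_model works because act_fc1 maps 4 * W * H features to W * H logits, so the first dimension of act_fc1.weight is W * H and its square root recovers a square board size. A small illustrative check (the import path is an assumption about how this commit's package is laid out; everything else follows from the Net definition above):

from Gomoku_MCTS.policy_value_net_pytorch_new import Net  # assumed import path

# For a square board of size s, act_fc1.weight has shape [s * s, 4 * s * s].
state_dict = Net(board_width=9, board_height=9).state_dict()
out_features, in_features = state_dict['act_fc1.weight'].shape
board_size = int(out_features ** 0.5)
assert board_size == 9 and in_features == 4 * board_size * board_size
print(f"inferred board size: {board_size}x{board_size}")  # -> 9x9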
const.py
CHANGED
@@ -9,7 +9,7 @@ import numpy as np

 _AI_AID_INFO = ["Use AI Aid", "Close AI Aid"]

-_BOARD_SIZE =
+_BOARD_SIZE = 9
 _BOARD_SIZE_1D = _BOARD_SIZE * _BOARD_SIZE
 _BLANK = 0
 _BLACK = 1
@@ -68,3 +68,10 @@ _ROOM_COLOR = {
     True: _BLACK,
     False: _WHITE,
 }
+
+
+_MODEL_PATH = {
+    "AlphaZero": "/Users/husky/GomokuDemo/Gomoku_MCTS/checkpoint/2023-12-14-18-17-07_test_teaching_learning_collect_epochs=1000_size=9_model=normal/best_policy.model",
+    "duel": "/Users/husky/GomokuDemo/Gomoku_MCTS/checkpoint/2023-12-14-18-16-09_test_teaching_learning_collect_epochs=1000_size=9_model=duel/best_policy.model",
+    "Gumbel AlphaZero": "/Users/husky/GomokuDemo/Gomoku_MCTS/checkpoint/2023-12-14-21-19-40_selfplay_epochs=1000_size=9_model=gumbel/best_policy.model",
+}
pages/Player_VS_AI.py
CHANGED
@@ -15,13 +15,14 @@ import numpy as np
|
|
15 |
import streamlit as st
|
16 |
from scipy.signal import convolve # this is used to check if any player wins
|
17 |
from streamlit import session_state
|
|
|
18 |
from streamlit_server_state import server_state, server_state_lock
|
19 |
-
from Gomoku_MCTS import MCTSpure, alphazero, Board,
|
|
|
20 |
from Gomoku_Bot import Gomoku_bot
|
21 |
from Gomoku_Bot import Board as Gomoku_bot_board
|
22 |
-
import matplotlib.pyplot as plt
|
23 |
-
|
24 |
|
|
|
25 |
|
26 |
from const import (
|
27 |
_BLACK, # 1, for human
|
@@ -37,10 +38,10 @@ from const import (
|
|
37 |
_DIAGONAL_UP_RIGHT,
|
38 |
_BOARD_SIZE,
|
39 |
_BOARD_SIZE_1D,
|
40 |
-
_AI_AID_INFO
|
|
|
41 |
)
|
42 |
|
43 |
-
|
44 |
from ai import (
|
45 |
BOS_TOKEN_ID,
|
46 |
generate_gpt2,
|
@@ -63,14 +64,23 @@ class Room:
|
|
63 |
self.TIME = time.time()
|
64 |
self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
65 |
self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
66 |
-
'AlphaZero': alphazero(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
'Gomoku Bot': Gomoku_bot(self.gomoku_bot_board, -1)}
|
68 |
self.MCTS = self.MCTS_dict['AlphaZero']
|
69 |
self.last_mcts = self.MCTS
|
70 |
self.AID_MCTS = self.MCTS_dict['AlphaZero']
|
71 |
self.COORDINATE_1D = [BOS_TOKEN_ID]
|
72 |
self.current_move = -1
|
73 |
-
self.
|
|
|
74 |
|
75 |
|
76 |
def change_turn(cur):
|
@@ -90,9 +100,9 @@ if "ROOMS" not in server_state:
|
|
90 |
with server_state_lock["ROOMS"]:
|
91 |
server_state.ROOMS = {}
|
92 |
|
|
|
93 |
def handle_oppo_model_selection():
|
94 |
if st.session_state['selected_oppo_model'] == 'Gomoku Bot':
|
95 |
-
session_state.ROOM.last_mcts = session_state.ROOM.MCTS # since use different mechanism, store previous mcts first
|
96 |
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict['Gomoku Bot']
|
97 |
return
|
98 |
else:
|
@@ -100,20 +110,22 @@ def handle_oppo_model_selection():
|
|
100 |
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
101 |
new_mct.mcts._root = deepcopy(TreeNode)
|
102 |
session_state.ROOM.MCTS = new_mct
|
103 |
-
session_state.ROOM.last_mcts
|
104 |
return
|
105 |
|
|
|
106 |
def handle_aid_model_selection():
|
107 |
if st.session_state['selected_aid_model'] == 'None':
|
108 |
session_state.USE_AIAID = False
|
109 |
return
|
110 |
session_state.USE_AIAID = True
|
111 |
-
TreeNode = session_state.ROOM.MCTS.mcts._root
|
112 |
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_aid_model']]
|
113 |
new_mct.mcts._root = deepcopy(TreeNode)
|
114 |
session_state.ROOM.AID_MCTS = new_mct
|
115 |
return
|
116 |
|
|
|
117 |
if 'selected_oppo_model' not in st.session_state:
|
118 |
st.session_state['selected_oppo_model'] = 'AlphaZero' # 默认值
|
119 |
|
@@ -125,7 +137,9 @@ TITLE = st.empty()
|
|
125 |
Model_Switch = st.empty()
|
126 |
|
127 |
TITLE.header("🤖 AI 3603 Gomoku")
|
128 |
-
selected_oppo_option = Model_Switch.selectbox('Select Opponent Model',
|
|
|
|
|
129 |
|
130 |
if st.session_state['selected_oppo_model'] != selected_oppo_option:
|
131 |
st.session_state['selected_oppo_model'] = selected_oppo_option
|
@@ -149,9 +163,11 @@ MULTIPLAYER_TAG = st.sidebar.empty()
|
|
149 |
with st.sidebar.container():
|
150 |
ANOTHER_ROUND = st.empty()
|
151 |
RESTART = st.empty()
|
|
|
152 |
AIAID = st.empty()
|
153 |
EXIT = st.empty()
|
154 |
-
selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0,
|
|
|
155 |
if st.session_state['selected_aid_model'] != selected_aid_option:
|
156 |
st.session_state['selected_aid_model'] = selected_aid_option
|
157 |
handle_aid_model_selection()
|
@@ -174,7 +190,6 @@ GAME_INFO.markdown(
|
|
174 |
)
|
175 |
|
176 |
|
177 |
-
|
178 |
def restart() -> None:
|
179 |
"""
|
180 |
Restart the game.
|
@@ -182,12 +197,56 @@ def restart() -> None:
|
|
182 |
session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
|
183 |
st.session_state['selected_oppo_model'] = 'AlphaZero'
|
184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
RESTART.button(
|
186 |
"Reset",
|
187 |
on_click=restart,
|
188 |
help="Clear the board as well as the scores",
|
189 |
)
|
190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
# Draw the board
|
193 |
def gomoku():
|
@@ -207,13 +266,24 @@ def gomoku():
|
|
207 |
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
|
208 |
session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
209 |
session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
210 |
-
|
211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
213 |
session_state.ROOM.last_mcts = session_state.ROOM.MCTS
|
214 |
session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
|
215 |
session_state.ROOM.TURN = session_state.ROOM.PLAYER
|
216 |
session_state.ROOM.WINNER = _BLANK # 0
|
|
|
|
|
217 |
session_state.ROOM.COORDINATE_1D = [BOS_TOKEN_ID]
|
218 |
|
219 |
# Room status sync
|
@@ -310,7 +380,8 @@ def gomoku():
|
|
310 |
session_state.ROOM.current_move = move
|
311 |
session_state.ROOM.BOARD.do_move(move)
|
312 |
# Gomoku Bot BOARD
|
313 |
-
session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(move // _BOARD_SIZE
|
|
|
314 |
session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
|
315 |
session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
|
316 |
|
@@ -356,7 +427,7 @@ def gomoku():
|
|
356 |
_PLAYER_SYMBOL[_NEW],
|
357 |
key=f"{i}:{j}",
|
358 |
args=(i, j),
|
359 |
-
on_click=
|
360 |
)
|
361 |
else:
|
362 |
# disable click for GPT choices
|
@@ -424,7 +495,7 @@ def gomoku():
|
|
424 |
move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
|
425 |
else:
|
426 |
move, simul_time = session_state.ROOM.MCTS.get_action(return_time=True)
|
427 |
-
session_state.ROOM.
|
428 |
print("AI takes move: ", move)
|
429 |
session_state.ROOM.current_move = move
|
430 |
gpt_response = move
|
@@ -436,7 +507,8 @@ def gomoku():
|
|
436 |
# MCTS BOARD
|
437 |
session_state.ROOM.BOARD.do_move(move)
|
438 |
# Gomoku Bot BOARD
|
439 |
-
session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(
|
|
|
440 |
# session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
441 |
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
442 |
|
@@ -475,7 +547,8 @@ def gomoku():
|
|
475 |
on_click=forbid_click
|
476 |
)
|
477 |
else:
|
478 |
-
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not
|
|
|
479 |
# enable click for other cells available for human choices
|
480 |
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
481 |
BOARD_PLATE[i][j].button(
|
@@ -493,7 +566,6 @@ def gomoku():
|
|
493 |
args=(i, j),
|
494 |
)
|
495 |
|
496 |
-
|
497 |
message.markdown(
|
498 |
'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
|
499 |
simul_time),
|
@@ -533,6 +605,7 @@ def gomoku():
|
|
533 |
else:
|
534 |
draw_board(True)
|
535 |
if session_state.ROOM.WINNER != _BLANK or 0 not in session_state.ROOM.BOARD.board_map:
|
|
|
536 |
ANOTHER_ROUND.button(
|
537 |
"Play Next round!",
|
538 |
on_click=another_round,
|
@@ -560,13 +633,12 @@ def gomoku():
|
|
560 |
# draw the plot for simulation time
|
561 |
# 创建一个 DataFrame
|
562 |
|
563 |
-
# print(session_state.ROOM.
|
564 |
st.markdown("<br>", unsafe_allow_html=True)
|
565 |
st.markdown("<br>", unsafe_allow_html=True)
|
566 |
-
chart_data = pd.DataFrame(session_state.ROOM.
|
567 |
st.line_chart(chart_data)
|
568 |
|
569 |
-
|
570 |
game_control()
|
571 |
update_info()
|
572 |
|
|
|
15 |
import streamlit as st
|
16 |
from scipy.signal import convolve # this is used to check if any player wins
|
17 |
from streamlit import session_state
|
18 |
+from streamlit.delta_generator import DeltaGenerator
|
19 |
from streamlit_server_state import server_state, server_state_lock
|
20 |
+from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet_old, PolicyValueNet_new, duel_PolicyValueNet, \
|
21 |
+Gumbel_MCTSPlayer
|
22 |
from Gomoku_Bot import Gomoku_bot
|
23 |
from Gomoku_Bot import Board as Gomoku_bot_board
|
|
|
|
|
24 |
|
25 |
+import matplotlib.pyplot as plt
|
26 |
|
27 |
from const import (
|
28 |
_BLACK, # 1, for human
|
|
|
38 |
_DIAGONAL_UP_RIGHT,
|
39 |
_BOARD_SIZE,
|
40 |
_BOARD_SIZE_1D,
|
41 |
+_AI_AID_INFO,
|
42 |
+_MODEL_PATH
|
43 |
)
|
44 |
|
|
|
45 |
from ai import (
|
46 |
BOS_TOKEN_ID,
|
47 |
generate_gpt2,
|
|
|
64 |
self.TIME = time.time()
|
65 |
self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
66 |
self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
67 |
+'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
68 |
+_MODEL_PATH["AlphaZero"]).policy_value_fn,
|
69 |
+c_puct=5, n_playout=100),
|
70 |
+'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
|
71 |
+_MODEL_PATH["duel"]).policy_value_fn,
|
72 |
+c_puct=5, n_playout=100),
|
73 |
+'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
74 |
+_MODEL_PATH["Gumbel AlphaZero"]).policy_value_fn,
|
75 |
+c_puct=5, n_playout=100, m_action=8),
|
76 |
'Gomoku Bot': Gomoku_bot(self.gomoku_bot_board, -1)}
|
77 |
self.MCTS = self.MCTS_dict['AlphaZero']
|
78 |
self.last_mcts = self.MCTS
|
79 |
self.AID_MCTS = self.MCTS_dict['AlphaZero']
|
80 |
self.COORDINATE_1D = [BOS_TOKEN_ID]
|
81 |
self.current_move = -1
|
82 |
+self.ai_simula_time_list = []
|
83 |
+self.human_simula_time_list = []
|
84 |
|
85 |
|
86 |
def change_turn(cur):
|
|
|
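Note on the Room constructor shown above: every selectable opponent is registered once in MCTS_dict, keyed by the exact label offered in the opponent selectbox, so switching models later is a plain dictionary lookup rather than re-building players. A minimal sketch of that registry pattern with hypothetical stand-in players (RandomPlayer and FirstFreePlayer are not part of the app):

import random

class RandomPlayer:
    """Hypothetical stand-in for the MCTS-based players."""
    def get_action(self, availables):
        return random.choice(availables)

class FirstFreePlayer:
    """Hypothetical stand-in: always plays the lowest free cell."""
    def get_action(self, availables):
        return min(availables)

# label -> player object, mirroring session_state.ROOM.MCTS_dict
PLAYER_REGISTRY = {
    'Random': RandomPlayer(),
    'First Free': FirstFreePlayer(),
}

selected_label = 'Random'                 # in the app this comes from st.selectbox
opponent = PLAYER_REGISTRY[selected_label]
print(opponent.get_action([3, 7, 11]))    # e.g. 7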
100 |
with server_state_lock["ROOMS"]:
|
101 |
server_state.ROOMS = {}
|
102 |
|
103 |
+
|
104 |
def handle_oppo_model_selection():
|
105 |
if st.session_state['selected_oppo_model'] == 'Gomoku Bot':
|
|
|
106 |
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict['Gomoku Bot']
|
107 |
return
|
108 |
else:
|
|
|
110 |
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
111 |
new_mct.mcts._root = deepcopy(TreeNode)
|
112 |
session_state.ROOM.MCTS = new_mct
|
113 |
+session_state.ROOM.last_mcts = new_mct
|
114 |
return
|
115 |
|
116 |
+
|
117 |
def handle_aid_model_selection():
|
118 |
if st.session_state['selected_aid_model'] == 'None':
|
119 |
session_state.USE_AIAID = False
|
120 |
return
|
121 |
session_state.USE_AIAID = True
|
122 |
+TreeNode = session_state.ROOM.MCTS.mcts._root # use the same tree node
|
123 |
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_aid_model']]
|
124 |
new_mct.mcts._root = deepcopy(TreeNode)
|
125 |
session_state.ROOM.AID_MCTS = new_mct
|
126 |
return
|
127 |
|
128 |
+
|
129 |
if 'selected_oppo_model' not in st.session_state:
|
130 |
st.session_state['selected_oppo_model'] = 'AlphaZero' # default value
|
131 |
|
|
|
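Note on handle_aid_model_selection above: the assistant player receives a deep copy of the current opponent's MCTS root, so the visit statistics gathered so far are reused rather than discarded. A toy sketch of that hand-off under assumed names (Node and Searcher are illustrative, not the real TreeNode/MCTS classes):

from copy import deepcopy

class Node:
    """Toy stand-in for an MCTS tree node."""
    def __init__(self, n_visits=0):
        self.n_visits = n_visits
        self.children = {}

class Searcher:
    """Toy stand-in for the .mcts object carried by each player."""
    def __init__(self):
        self._root = Node()

old_player = Searcher()
old_player._root.n_visits = 120                    # statistics accumulated so far

new_player = Searcher()
new_player._root = deepcopy(old_player._root)      # same hand-off as new_mct.mcts._root = deepcopy(TreeNode)

assert new_player._root.n_visits == 120
new_player._root.n_visits += 1                     # further search no longer touches old_player's tree
assert old_player._root.n_visits == 120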
137 |
Model_Switch = st.empty()
|
138 |
|
139 |
TITLE.header("🤖 AI 3603 Gomoku")
|
140 |
+selected_oppo_option = Model_Switch.selectbox('Select Opponent Model',
|
141 |
+['Pure MCTS', 'AlphaZero', 'Gomoku Bot', 'duel', 'Gumbel AlphaZero'],
|
142 |
+index=1, key='oppo_model')
|
143 |
|
144 |
if st.session_state['selected_oppo_model'] != selected_oppo_option:
|
145 |
st.session_state['selected_oppo_model'] = selected_oppo_option
|
|
|
163 |
with st.sidebar.container():
|
164 |
ANOTHER_ROUND = st.empty()
|
165 |
RESTART = st.empty()
|
166 |
+GIVEIN = st.empty()
|
167 |
AIAID = st.empty()
|
168 |
EXIT = st.empty()
|
169 |
+selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0,
|
170 |
+key='aid_model')
|
171 |
if st.session_state['selected_aid_model'] != selected_aid_option:
|
172 |
st.session_state['selected_aid_model'] = selected_aid_option
|
173 |
handle_aid_model_selection()
|
|
|
190 |
)
|
191 |
|
192 |
|
|
|
193 |
def restart() -> None:
|
194 |
"""
|
195 |
Restart the game.
|
|
|
197 |
session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
|
198 |
st.session_state['selected_oppo_model'] = 'AlphaZero'
|
199 |
|
200 |
+def givein() -> None:
|
201 |
+"""
|
202 |
+Give in to AI.
|
203 |
+"""
|
204 |
+session_state.ROOM = deepcopy(session_state.ROOM)
|
205 |
+session_state.ROOM.WINNER = _WHITE
|
206 |
+# add 1 score to AI
|
207 |
+session_state.ROOM.HISTORY = (
|
208 |
+session_state.ROOM.HISTORY[0]
|
209 |
++ int(session_state.ROOM.WINNER == _WHITE),
|
210 |
+session_state.ROOM.HISTORY[1]
|
211 |
++ int(session_state.ROOM.WINNER == _BLACK),
|
212 |
+)
|
213 |
+session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
|
214 |
+session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
215 |
+session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
216 |
+'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
217 |
+_MODEL_PATH["AlphaZero"]).policy_value_fn,
|
218 |
+c_puct=5, n_playout=100),
|
219 |
+'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
|
220 |
+_MODEL_PATH["duel"]).policy_value_fn,
|
221 |
+c_puct=5, n_playout=100),
|
222 |
+'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
223 |
+_MODEL_PATH[
|
224 |
+"Gumbel AlphaZero"]).policy_value_fn,
|
225 |
+c_puct=5, n_playout=100, m_action=8),
|
226 |
+'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
|
227 |
+
|
228 |
+session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
229 |
+session_state.ROOM.last_mcts = session_state.ROOM.MCTS
|
230 |
+session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
|
231 |
+session_state.ROOM.TURN = session_state.ROOM.PLAYER
|
232 |
+session_state.ROOM.WINNER = _BLANK # 0
|
233 |
+session_state.ROOM.ai_simula_time_list = []
|
234 |
+session_state.ROOM.human_simula_time_list = []
|
235 |
+session_state.ROOM.COORDINATE_1D = [BOS_TOKEN_ID]
|
236 |
+
|
237 |
+
|
238 |
RESTART.button(
|
239 |
"Reset",
|
240 |
on_click=restart,
|
241 |
help="Clear the board as well as the scores",
|
242 |
)
|
243 |
|
244 |
+GIVEIN.button(
|
245 |
+"Give in",
|
246 |
+on_click = givein,
|
247 |
+help="Give in to AI",
|
248 |
+)
|
249 |
+
|
250 |
|
251 |
# Draw the board
|
252 |
def gomoku():
|
|
|
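Note on the new givein() callback above: it awards the game to the AI (white) and then bumps the running score, where exactly one of the two indicator terms evaluates to 1. A small worked sketch of that tuple update with made-up numbers (the value of _WHITE is assumed here; only _BLACK = 1, the human player, is stated in the imports):

_BLACK, _WHITE = 1, 2          # assumed encoding for this sketch

HISTORY = (3, 5)               # illustrative running score
WINNER = _WHITE                # givein() always awards the game to the AI

HISTORY = (
    HISTORY[0] + int(WINNER == _WHITE),   # 3 + 1 -> 4
    HISTORY[1] + int(WINNER == _BLACK),   # 5 + 0 -> 5
)
print(HISTORY)  # (4, 5)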
266 |
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
|
267 |
session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
268 |
session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
269 |
+'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
270 |
+_MODEL_PATH["AlphaZero"]).policy_value_fn,
|
271 |
+c_puct=5, n_playout=100),
|
272 |
+'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
|
273 |
+_MODEL_PATH["duel"]).policy_value_fn,
|
274 |
+c_puct=5, n_playout=100),
|
275 |
+'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
276 |
+_MODEL_PATH["Gumbel AlphaZero"]).policy_value_fn,
|
277 |
+c_puct=5, n_playout=100, m_action=8),
|
278 |
+'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
|
279 |
+
|
280 |
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
281 |
session_state.ROOM.last_mcts = session_state.ROOM.MCTS
|
282 |
session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
|
283 |
session_state.ROOM.TURN = session_state.ROOM.PLAYER
|
284 |
session_state.ROOM.WINNER = _BLANK # 0
|
285 |
+session_state.ROOM.ai_simula_time_list = []
|
286 |
+session_state.ROOM.human_simula_time_list = []
|
287 |
session_state.ROOM.COORDINATE_1D = [BOS_TOKEN_ID]
|
288 |
|
289 |
# Room status sync
|
|
|
380 |
session_state.ROOM.current_move = move
|
381 |
session_state.ROOM.BOARD.do_move(move)
|
382 |
# Gomoku Bot BOARD
|
383 |
+session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(_BOARD_SIZE - move // _BOARD_SIZE - 1,
|
384 |
+move % _BOARD_SIZE)  # the bot board is indexed from the top-left corner (0,0), while the game's moves are indexed from the bottom-left corner (0,0)
|
385 |
session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
|
386 |
session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
|
387 |
|
|
|
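Note on the two added board.put(...) arguments above: they convert the MCTS board's flat move index into the Gomoku bot board's (row, col), flipping the row because the bot board is indexed from the top-left corner while the game's moves are indexed from the bottom-left. A minimal sketch of that conversion (board size 9 is only an example):

_BOARD_SIZE = 9   # example size; the app reads it from const

def to_bot_coords(move, size=_BOARD_SIZE):
    # flat MCTS move (rows counted from the bottom) -> bot board (row, col) counted from the top
    row_from_bottom, col = divmod(move, size)
    return size - 1 - row_from_bottom, col

assert to_bot_coords(0) == (8, 0)          # bottom-left cell lands on the bot board's last row
assert to_bot_coords(9 * 9 - 1) == (0, 8)  # top-right cell lands on the bot board's first row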
427 |
_PLAYER_SYMBOL[_NEW],
|
428 |
key=f"{i}:{j}",
|
429 |
args=(i, j),
|
430 |
+on_click=forbid_click,
|
431 |
)
|
432 |
else:
|
433 |
# disable click for GPT choices
|
|
|
495 |
move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
|
496 |
else:
|
497 |
move, simul_time = session_state.ROOM.MCTS.get_action(return_time=True)
|
498 |
+session_state.ROOM.ai_simula_time_list.append(simul_time)
|
499 |
print("AI takes move: ", move)
|
500 |
session_state.ROOM.current_move = move
|
501 |
gpt_response = move
|
|
|
507 |
# MCTS BOARD
|
508 |
session_state.ROOM.BOARD.do_move(move)
|
509 |
# Gomoku Bot BOARD
|
510 |
+session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(_BOARD_SIZE - 1 - move // _BOARD_SIZE,
|
511 |
+move % _BOARD_SIZE)
|
512 |
# session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
513 |
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
514 |
|
|
|
547 |
on_click=forbid_click
|
548 |
)
|
549 |
else:
|
550 |
+if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not \
|
551 |
+session_state.ROOM.BOARD.game_end()[0]:
|
552 |
# enable click for other cells available for human choices
|
553 |
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
554 |
BOARD_PLATE[i][j].button(
|
|
|
566 |
args=(i, j),
|
567 |
)
|
568 |
|
|
|
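Note on the AI-aid branch above: a cell is highlighted only if its flat index appears in top_five_acts, and the matching entry of top_five_probs is shown on the button. How those lists are computed is outside this hunk; the following is a hypothetical numpy sketch of picking the five most probable moves from a policy vector (names, shapes, and the random policy are assumptions, not the app's code):

import numpy as np

_BOARD_SIZE = 9                                    # example size
rng = np.random.default_rng(0)

act_probs = rng.random(_BOARD_SIZE * _BOARD_SIZE)  # hypothetical policy output, one value per cell
act_probs /= act_probs.sum()

order = np.argsort(act_probs)[::-1][:5]            # indices of the five largest probabilities
top_five_acts = order.tolist()
top_five_probs = act_probs[order].tolist()

cell = top_five_acts[0]
prob = top_five_probs[top_five_acts.index(cell)]   # same lookup as in the button label above
print(cell, round(prob, 3))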
569 |
message.markdown(
|
570 |
'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
|
571 |
simul_time),
|
|
|
605 |
else:
|
606 |
draw_board(True)
|
607 |
if session_state.ROOM.WINNER != _BLANK or 0 not in session_state.ROOM.BOARD.board_map:
|
608 |
+GIVEIN.empty()
|
609 |
ANOTHER_ROUND.button(
|
610 |
"Play Next round!",
|
611 |
on_click=another_round,
|
|
|
633 |
# draw the plot for simulation time
|
634 |
# create a DataFrame
|
635 |
|
636 |
+# print(session_state.ROOM.ai_simula_time_list)
|
637 |
st.markdown("<br>", unsafe_allow_html=True)
|
638 |
st.markdown("<br>", unsafe_allow_html=True)
|
639 |
+chart_data = pd.DataFrame(session_state.ROOM.ai_simula_time_list, columns=["Simulation Time"])
|
640 |
st.line_chart(chart_data)
|
641 |
|
|
|
642 |
game_control()
|
643 |
update_info()
|
644 |
|
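Note on the closing lines above: the per-move timings collected in ai_simula_time_list are wrapped in a pandas DataFrame and rendered with st.line_chart. A standalone sketch with synthetic timings (the numbers are invented; in the app the list is appended to after each get_action call):

import pandas as pd
import streamlit as st

ai_simula_time_list = [0.012, 0.011, 0.015, 0.013, 0.018]   # made-up per-move timings in seconds

chart_data = pd.DataFrame(ai_simula_time_list, columns=["Simulation Time"])
st.line_chart(chart_data)   # one point per AI move, as on the page above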