# -*- coding: utf-8 -*-
"""
An implementation of the policy-value network in PyTorch.
Requires a recent PyTorch (the MPS backend check needs >= 1.12).

@author: Junxiao Song
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np


def set_learning_rate(optimizer, lr):
    """Sets the learning rate to the given value"""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


class ResidualBlock(nn.Module):
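    """A 3x3 conv-BN residual block with an identity skip connection."""
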
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return F.relu(out)


class Net(nn.Module):
    """Policy-Value network module for AlphaZero Gomoku."""

    def __init__(self, board_width, board_height, num_residual_blocks=5):
        super(Net, self).__init__()
        self.board_width = board_width
        self.board_height = board_height
        self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.res_layers = nn.Sequential(*[ResidualBlock(32) for _ in range(num_residual_blocks)])

        # Action Policy layers
        self.act_conv1 = nn.Conv2d(32, 4, kernel_size=1)
        self.act_fc1 = nn.Linear(4 * board_width * board_height, board_width * board_height)

        # State Value layers
        self.val_conv1 = nn.Conv2d(32, 2, kernel_size=1)
        self.val_fc1 = nn.Linear(2 * board_width * board_height, 64)
        self.val_fc2 = nn.Linear(64, 1)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.res_layers(x)

        # Action Policy head
        x_act = F.relu(self.act_conv1(x))
        x_act = x_act.view(-1, 4 * self.board_width * self.board_height)
        x_act = F.log_softmax(self.act_fc1(x_act), dim=1)

        # State Value head
        x_val = F.relu(self.val_conv1(x))
        x_val = x_val.view(-1, 2 * self.board_width * self.board_height)
        x_val = F.relu(self.val_fc1(x_val))
        x_val = torch.tanh(self.val_fc2(x_val))

        return x_act, x_val
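
# A quick shape sanity check for Net (illustrative; the 8x8 board size is an
# arbitrary choice, not part of this module):
#   net = Net(8, 8)
#   log_probs, value = net(torch.zeros(1, 4, 8, 8))
#   log_probs.shape -> torch.Size([1, 64])  # one log-probability per point
#   value.shape -> torch.Size([1, 1])       # scalar value in (-1, 1)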


class PolicyValueNet():
    """Policy-value network: training, inference, and model persistence."""

    def __init__(self, board_width, board_height,
                 model_file=None, use_gpu=False, bias=False):
        # Use the Apple MPS backend only when GPU acceleration is requested
        # and actually available; otherwise fall back to the CPU.
        self.use_gpu = use_gpu
        self.device = torch.device(
            "mps" if use_gpu and torch.backends.mps.is_available() else "cpu")
        self.l2_const = 1e-4  # coefficient of the L2 penalty
        self.board_width = board_width
        self.board_height = board_height
        self.bias = bias

        if model_file:
            net_params = torch.load(model_file, map_location='cpu')

            # Infer board dimensions from the loaded state dict
            inferred_width, inferred_height = self.infer_board_size_from_model(net_params)
            if inferred_width is None or inferred_height is None:
                raise ValueError("The model file does not contain the board dimensions")
            if inferred_width > board_width:
                raise ValueError("The model file was trained on a larger board "
                                 "than the current one")
            # A net trained on a smaller board may be applied convolutionally
            self.use_conv = inferred_width < board_width
            self.policy_value_net = Net(inferred_width, inferred_height).to(self.device)
            self.policy_value_net.load_state_dict(net_params)
            print("Initialized the policy-value net from the model file")
        else:
            self.policy_value_net = Net(board_width, board_height).to(self.device)

        self.optimizer = optim.Adam(self.policy_value_net.parameters(),
                                    weight_decay=self.l2_const)

    def infer_board_size_from_model(self, state_dict):
        """Infer (width, height) from act_fc1 in a saved state dict.

        act_fc1.weight has shape
        [board_width * board_height, 4 * board_width * board_height],
        so the board size is the square root of its first dimension
        (this assumes a square board).
        """
        if 'act_fc1.weight' in state_dict:
            out_features, _ = state_dict['act_fc1.weight'].shape
            board_size = int(out_features ** 0.5)
            print(f"Board size inferred from model: {board_size}x{board_size}")
            return board_size, board_size
        return None, None

    def apply_normal_bias(self, tensor, mean=0.0, std=1.0):
        """Subtract a centered 2-D Gaussian from the non-zero entries of a
        square board plane, down-weighting positions near the center."""
        bsize = tensor.shape[0]  # assumes a square plane
        x, y = np.meshgrid(np.linspace(-1, 1, bsize), np.linspace(-1, 1, bsize))
        d = np.sqrt(x * x + y * y)  # radial distance from the board center
        gauss = np.exp(-((d - mean) ** 2 / (2.0 * std ** 2)))
        # Apply the bias only to non-zero (occupied) entries
        return tensor - (tensor != 0) * gauss
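
    # Illustrative effect (hypothetical 3x3 plane, not from the original code):
    # the Gaussian peaks at the board center, so occupied points there are
    # reduced the most:
    #   plane = np.ones((3, 3), dtype=np.float32)
    #   biased = pvnet.apply_normal_bias(plane)
    #   biased[1, 1] -> 0.0      # center: 1 - exp(0) = 0
    #   biased[0, 0] -> ~0.632   # corner: 1 - exp(-1)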

    def policy_value(self, state_batch):
        """
        input: a batch of states
        output: a batch of action probabilities and state values
        """
        state_batch = torch.FloatTensor(state_batch).to(self.device)
        with torch.no_grad():
            log_act_probs, value = self.policy_value_net(state_batch)
        act_probs = np.exp(log_act_probs.cpu().numpy())
        return act_probs, value.cpu().numpy()

    def policy_value_fn(self, board):
        """
        input: board
        output: a list of (action, probability) tuples for each available
        action and the score of the board state
        """
        legal_positions = board.availables
        current_state = np.ascontiguousarray(board.current_state().reshape(
            -1, 4, self.board_width, self.board_height))
        if self.bias:
            current_state[0][1] = self.apply_normal_bias(current_state[0][1])

        state = torch.from_numpy(current_state).float().to(self.device)
        with torch.no_grad():
            log_act_probs, value = self.policy_value_net(state)
        act_probs = np.exp(log_act_probs.cpu().numpy().flatten())
        act_probs = zip(legal_positions, act_probs[legal_positions])
        return act_probs, value.item()

    def train_step(self, state_batch, mcts_probs, winner_batch, lr):
        """Perform a single training step and return (loss, entropy)."""
        state_batch = torch.FloatTensor(state_batch).to(self.device)
        mcts_probs = torch.FloatTensor(mcts_probs).to(self.device)
        winner_batch = torch.FloatTensor(winner_batch).to(self.device)

        # zero the parameter gradients and set the learning rate
        self.optimizer.zero_grad()
        set_learning_rate(self.optimizer, lr)

        # forward
        log_act_probs, value = self.policy_value_net(state_batch)
        # loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
        # (the L2 penalty is handled by the optimizer's weight_decay)
        value_loss = F.mse_loss(value.view(-1), winner_batch)
        policy_loss = -torch.mean(torch.sum(mcts_probs * log_act_probs, 1))
        loss = value_loss + policy_loss
        # backward and optimize
        loss.backward()
        self.optimizer.step()
        # policy entropy, for monitoring only
        entropy = -torch.mean(
            torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
        )
        return loss.item(), entropy.item()
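
    # A minimal usage sketch for train_step (all shapes and values are
    # illustrative and assume an 8x8 board; they are not part of this module):
    #   pvnet = PolicyValueNet(8, 8)
    #   states = np.random.rand(32, 4, 8, 8).astype(np.float32)
    #   probs = np.random.rand(32, 64).astype(np.float32)
    #   probs /= probs.sum(axis=1, keepdims=True)  # rows become distributions
    #   winners = np.random.choice([-1.0, 1.0], 32).astype(np.float32)
    #   loss, entropy = pvnet.train_step(states, probs, winners, lr=2e-3)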

    def get_policy_param(self):
        net_params = self.policy_value_net.state_dict()
        return net_params

    def save_model(self, model_file):
        """ save model params to file """
        net_params = self.get_policy_param()  # get model params
        torch.save(net_params, model_file)


if __name__ == "__main__":
    import torch.onnx

    # Initialize the model with appropriate board dimensions
    model = Net(board_width=9, board_height=9)
    dummy_input = torch.randn(1, 4, 9, 9)  # an example input

    # Export the model to ONNX format
    torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)
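
    # Optional follow-up: verify the exported graph with onnxruntime. This is
    # a hedged sketch; the onnxruntime package is an assumption and is not a
    # dependency of this project.
    # import onnxruntime as ort
    # sess = ort.InferenceSession("model.onnx")
    # input_name = sess.get_inputs()[0].name
    # log_probs, value = sess.run(None, {input_name: dummy_input.numpy()})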