# -*- coding: utf-8 -*-
"""
An implementation of the policyValueNet in PyTorch.
Requires a recent PyTorch release (the code uses torch.device and the MPS backend).

@author: Junxiao Song
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

def set_learning_rate(optimizer, lr):
    """Sets the learning rate to the given value."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual  # skip connection
        return F.relu(out)

class Net(nn.Module):
    """Policy-Value network module for AlphaZero Gomoku."""

    def __init__(self, board_width, board_height, num_residual_blocks=5):
        super(Net, self).__init__()
        self.board_width = board_width
        self.board_height = board_height
        # common layers: input is a batch of 4 board feature planes
        self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.res_layers = nn.Sequential(
            *[ResidualBlock(32) for _ in range(num_residual_blocks)])
        # action policy layers
        self.act_conv1 = nn.Conv2d(32, 4, kernel_size=1)
        self.act_fc1 = nn.Linear(4 * board_width * board_height,
                                 board_width * board_height)
        # state value layers
        self.val_conv1 = nn.Conv2d(32, 2, kernel_size=1)
        self.val_fc1 = nn.Linear(2 * board_width * board_height, 64)
        self.val_fc2 = nn.Linear(64, 1)

    def forward(self, x):
        # common layers
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.res_layers(x)
        # action policy head: log-probabilities over board positions
        x_act = F.relu(self.act_conv1(x))
        x_act = x_act.view(-1, 4 * self.board_width * self.board_height)
        x_act = F.log_softmax(self.act_fc1(x_act), dim=1)
        # state value head: scalar in [-1, 1]
        x_val = F.relu(self.val_conv1(x))
        x_val = x_val.view(-1, 2 * self.board_width * self.board_height)
        x_val = F.relu(self.val_fc1(x_val))
        x_val = torch.tanh(self.val_fc2(x_val))
        return x_act, x_val

class PolicyValueNet():
    """policy-value network"""

    def __init__(self, board_width, board_height,
                 model_file=None, use_gpu=False, bias=False):
        # Use Apple's MPS backend when available, otherwise fall back to CPU.
        self.device = torch.device(
            "mps" if torch.backends.mps.is_available() else "cpu")
        self.use_gpu = use_gpu
        self.l2_const = 1e-4  # coef of l2 penalty
        self.board_width = board_width
        self.board_height = board_height
        self.bias = bias
        if model_file:
            net_params = torch.load(
                model_file, map_location='cpu' if not use_gpu else None)
            # Infer board dimensions from the loaded model
            inferred_width, inferred_height = self.infer_board_size_from_model(net_params)
            if inferred_width and inferred_height:
                self.policy_value_net = (
                    Net(inferred_width, inferred_height).to(self.device)
                    if use_gpu else Net(inferred_width, inferred_height))
                self.policy_value_net.load_state_dict(net_params)
                print("Use model file to initialize the policy value net")
            else:
                raise Exception("The model file does not contain the board dimensions")
            if inferred_width < board_width:
                self.use_conv = True
            elif inferred_width > board_width:
                raise Exception(
                    "The model file has a larger board size than the current board size!!")
        else:
            # the policy value net module
            if self.use_gpu:
                self.policy_value_net = Net(board_width, board_height).to(self.device)
            else:
                self.policy_value_net = Net(board_width, board_height)
        self.optimizer = optim.Adam(self.policy_value_net.parameters(),
                                    weight_decay=self.l2_const)

    def infer_board_size_from_model(self, model):
        # Use the size of the act_fc1 layer to infer board dimensions
        for name in model.keys():
            if name == 'act_fc1.weight':
                # weight shape is [board_width * board_height, 4 * board_width * board_height]
                c, _ = model[name].shape
                print(f"act_fc1.weight shape: {model[name].shape}")
                # assume a square board, so width == height
                board_size = int(c ** 0.5)
                print(f"Board size inferred from model: {board_size}x{board_size}")
                return board_size, board_size
        return None, None

    def apply_normal_bias(self, tensor, mean=0, std=1):
        """Subtract a centered 2-D Gaussian from the non-zero entries of a board plane."""
        bsize = tensor.shape[0]
        x, y = np.meshgrid(np.linspace(-1, 1, bsize), np.linspace(-1, 1, bsize))
        d = np.sqrt(x * x + y * y)
        gauss = np.exp(-((d - mean) ** 2 / (2.0 * std ** 2)))
        # Apply the bias only to non-zero elements
        biased_tensor = tensor - (tensor != 0) * gauss
        return biased_tensor

    def policy_value(self, state_batch):
        """
        input: a batch of states
        output: a batch of action probabilities and state values
        """
        if self.use_gpu:
            state_batch = Variable(torch.FloatTensor(state_batch).to(self.device))
            log_act_probs, value = self.policy_value_net(state_batch)
            act_probs = np.exp(log_act_probs.data.cpu().numpy())
            return act_probs, value.data.cpu().numpy()
        else:
            state_batch = Variable(torch.FloatTensor(state_batch))
            log_act_probs, value = self.policy_value_net(state_batch)
            act_probs = np.exp(log_act_probs.data.numpy())
            return act_probs, value.data.numpy()

    def policy_value_fn(self, board):
        """
        input: board
        output: a list of (action, probability) tuples for each available
        action and the score of the board state
        """
        legal_positions = board.availables
        current_state = np.ascontiguousarray(board.current_state().reshape(
            -1, 4, self.board_width, self.board_height))
        if self.bias:
            current_state[0][1] = self.apply_normal_bias(current_state[0][1])
        if self.use_gpu:
            log_act_probs, value = self.policy_value_net(
                Variable(torch.from_numpy(current_state)).to(self.device).float())
            act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
        else:
            log_act_probs, value = self.policy_value_net(
                Variable(torch.from_numpy(current_state)).float())
            act_probs = np.exp(log_act_probs.data.numpy().flatten())
        act_probs = zip(legal_positions, act_probs[legal_positions])
        value = value.data[0][0]
        return act_probs, value

    def train_step(self, state_batch, mcts_probs, winner_batch, lr):
        """perform a training step"""
        # wrap in Variable
        if self.use_gpu:
            state_batch = Variable(torch.FloatTensor(state_batch).to(self.device))
            mcts_probs = Variable(torch.FloatTensor(mcts_probs).to(self.device))
            winner_batch = Variable(torch.FloatTensor(winner_batch).to(self.device))
        else:
            state_batch = Variable(torch.FloatTensor(state_batch))
            mcts_probs = Variable(torch.FloatTensor(mcts_probs))
            winner_batch = Variable(torch.FloatTensor(winner_batch))
        # zero the parameter gradients
        self.optimizer.zero_grad()
        # set learning rate
        set_learning_rate(self.optimizer, lr)
        # forward
        log_act_probs, value = self.policy_value_net(state_batch)
        # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
        # Note: the L2 penalty is incorporated in the optimizer (weight_decay)
        value_loss = F.mse_loss(value.view(-1), winner_batch)
        policy_loss = -torch.mean(torch.sum(mcts_probs * log_act_probs, 1))
        loss = value_loss + policy_loss
        # backward and optimize
        loss.backward()
        self.optimizer.step()
        # calc policy entropy, for monitoring only
        entropy = -torch.mean(
            torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
        )
        return loss.item(), entropy.item()

    def get_policy_param(self):
        net_params = self.policy_value_net.state_dict()
        return net_params

    def save_model(self, model_file):
        """ save model params to file """
        net_params = self.get_policy_param()  # get model params
        torch.save(net_params, model_file)

if __name__ == "__main__":
    import torch.onnx
    # Assume the Net model defined above; initialize it with suitable parameters.
    model = Net(board_width=9, board_height=9)
    # Create a sample input
    dummy_input = torch.randn(1, 4, 9, 9)
    # Export the model to ONNX format
    torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)
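
    # Beyond the ONNX export, a minimal sanity-check sketch of the PolicyValueNet
    # wrapper itself: push a random 9x9 batch through policy_value() and run one
    # train_step() on dummy targets. The batch size, targets, and learning rate
    # below are arbitrary illustrative values, not taken from the training code.
    pv_net = PolicyValueNet(board_width=9, board_height=9)
    sample_batch = np.random.rand(2, 4, 9, 9).astype(np.float32)
    probs, values = pv_net.policy_value(sample_batch)
    print("policy output shape:", probs.shape)  # expected (2, 81)
    print("value output shape:", values.shape)  # expected (2, 1)
    dummy_probs = np.random.dirichlet(np.ones(81), size=2).astype(np.float32)
    dummy_winners = np.array([1.0, -1.0], dtype=np.float32)
    loss, entropy = pv_net.train_step(sample_batch, dummy_probs, dummy_winners, lr=2e-3)
    print("loss: %.4f, entropy: %.4f" % (loss, entropy))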