# -*- coding: utf-8 -*-
"""
An implementation of the policyValueNet in PyTorch
Tested in PyTorch 0.2.0 and 0.3.0

@author: Junxiao Song
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np


class Net(nn.Module):
    """policy-value network module"""
    def __init__(self, board_width, board_height):
        super(Net, self).__init__()

        self.board_width = board_width
        self.board_height = board_height
        # common layers
        self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # action policy layers
        self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1)
        self.act_fc1 = nn.Linear(4*board_width*board_height,
                                 board_width*board_height)
        # state value layers
        self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
        self.val_fc1 = nn.Linear(2*board_width*board_height, 64)
        self.val_fc2 = nn.Linear(64, 1)

    def forward(self, state_input):
        # common layers
        x = F.relu(self.conv1(state_input))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # action policy layers
        x_act = F.relu(self.act_conv1(x))
        x_act = x_act.view(-1, 4*self.board_width*self.board_height)
        x_act = F.log_softmax(self.act_fc1(x_act))
        # state value layers
        x_val = F.relu(self.val_conv1(x))
        x_val = x_val.view(-1, 2*self.board_width*self.board_height)
        x_val = F.relu(self.val_fc1(x_val))
        x_val = F.tanh(self.val_fc2(x_val))
        return x_act, x_val
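
# Helper assumed by PolicyValueNet.train_step below; its definition is not
# present in this snippet, so this is a minimal reconstruction: it simply
# overwrites the learning rate of every parameter group in the optimizer.
def set_learning_rate(optimizer, lr):
    """Sets the learning rate to the given value"""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
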
class PolicyValueNet():
    """alphazero policy-value network"""
    def __init__(self, board_width, board_height,
                 model_file=None, use_gpu=False):
        self.use_gpu = use_gpu
        self.board_width = board_width
        self.board_height = board_height
        self.l2_const = 1e-4  # coefficient of the L2 penalty
        # the policy value net module
        if self.use_gpu:
            self.policy_value_net = Net(board_width, board_height).cuda()
        else:
            self.policy_value_net = Net(board_width, board_height)
        # the L2 penalty of the loss is applied here via weight decay
        self.optimizer = optim.Adam(self.policy_value_net.parameters(),
                                    weight_decay=self.l2_const)

        if model_file:
            net_params = torch.load(model_file)
            self.policy_value_net.load_state_dict(net_params)
            print('loaded model file')

    def policy_value(self, state_batch):
        """
        input: a batch of states
        output: a batch of action probabilities and state values
        """
        if self.use_gpu:
            state_batch = Variable(torch.FloatTensor(state_batch).cuda())
            log_act_probs, value = self.policy_value_net(state_batch)
            act_probs = np.exp(log_act_probs.data.cpu().numpy())
            return act_probs, value.data.cpu().numpy()
        else:
            state_batch = Variable(torch.FloatTensor(state_batch))
            log_act_probs, value = self.policy_value_net(state_batch)
            act_probs = np.exp(log_act_probs.data.numpy())
            return act_probs, value.data.numpy()

    def policy_value_fn(self, board):
        """
        input: board
        output: a list of (action, probability) tuples for each available
        action and the score of the board state
        """
        legal_positions = board.availables
        current_state = np.ascontiguousarray(board.current_state().reshape(
                -1, 4, self.board_width, self.board_height))
        if self.use_gpu:
            log_act_probs, value = self.policy_value_net(
                    Variable(torch.from_numpy(current_state)).cuda().float())
            act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
        else:
            log_act_probs, value = self.policy_value_net(
                    Variable(torch.from_numpy(current_state)).float())
            act_probs = np.exp(log_act_probs.data.numpy().flatten())
        act_probs = zip(legal_positions, act_probs[legal_positions])
        value = value.data[0][0]
        return act_probs, value

    # to be moved to main_worker
    def train_step(self, state_batch, mcts_probs, winner_batch, lr):
        """perform a training step"""
        # wrap in Variable
        if self.use_gpu:
            state_batch = Variable(torch.FloatTensor(state_batch).cuda())
            mcts_probs = Variable(torch.FloatTensor(mcts_probs).cuda())
            winner_batch = Variable(torch.FloatTensor(winner_batch).cuda())
        else:
            state_batch = Variable(torch.FloatTensor(state_batch))
            mcts_probs = Variable(torch.FloatTensor(mcts_probs))
            winner_batch = Variable(torch.FloatTensor(winner_batch))

        # zero the parameter gradients
        self.optimizer.zero_grad()
        # set learning rate
        set_learning_rate(self.optimizer, lr)

        # forward
        log_act_probs, value = self.policy_value_net(state_batch)
        # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
        # Note: the L2 penalty is incorporated in the optimizer (weight decay)
        value_loss = F.mse_loss(value.view(-1), winner_batch)
        policy_loss = -torch.mean(torch.sum(mcts_probs*log_act_probs, 1))
        loss = value_loss + policy_loss
        # backward and optimize
        loss.backward()
        self.optimizer.step()
        # calc policy entropy, for monitoring only
        entropy = -torch.mean(
                torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
                )
        # return loss.data[0], entropy.data[0]
        # for pytorch version >= 0.5 please use the following line instead.
        return loss.item(), entropy.item()

    # def get_policy_param(self):
    #     net_params = self.policy_value_net.state_dict()
    #     return net_params

    # def save_model(self, model_file):
    #     """ save model params to file """
    #     net_params = self.get_policy_param()  # get model params
    #     torch.save(net_params, model_file)
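
if __name__ == '__main__':
    # Illustrative smoke test only, not part of the original module: it
    # assumes an 8x8 board and feeds random tensors shaped like the 4-plane
    # board encoding expected by Net, just to exercise the inference and
    # training paths end to end on CPU.
    width, height = 8, 8
    policy_value_net = PolicyValueNet(width, height, use_gpu=False)

    # batch of 16 random "states" with the (4, width, height) plane layout
    dummy_states = np.random.random((16, 4, width, height)).astype(np.float32)
    act_probs, values = policy_value_net.policy_value(dummy_states)
    print('action probs shape:', act_probs.shape)   # (16, width*height)
    print('state values shape:', values.shape)      # (16, 1)

    # one training step on random targets, just to check shapes and gradients
    dummy_mcts_probs = np.random.random((16, width*height)).astype(np.float32)
    dummy_winners = np.random.choice([-1.0, 1.0], size=16).astype(np.float32)
    loss, entropy = policy_value_net.train_step(
        dummy_states, dummy_mcts_probs, dummy_winners, lr=2e-3)
    print('loss: %.4f, entropy: %.4f' % (loss, entropy))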