import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable  # noqa: F401  (unused here; kept for importers of this module)
import numpy as np


def set_learning_rate(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr``."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


class DuelingDQNNet(nn.Module):
    """Dueling network with a shared conv trunk, an advantage head and a value head.

    The input is a batch of 4-plane board encodings of spatial size
    ``board_width x board_height``. ``forward`` returns a log-softmax over
    ``board_width * board_height`` outputs together with the scalar value
    head output for each sample.
    """

    def __init__(self, board_width, board_height):
        super(DuelingDQNNet, self).__init__()

        self.board_width = board_width
        self.board_height = board_height

        # common layers (padding=1 with 3x3 kernels keeps the spatial size)
        self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # advantage layers: 1x1 conv down to 4 planes, then one output per cell
        self.adv_conv1 = nn.Conv2d(128, 4, kernel_size=1)
        self.adv_fc1 = nn.Linear(4 * board_width * board_height,
                                 board_width * board_height)

        # value layers: 1x1 conv down to 2 planes, then 64 hidden units -> 1
        self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
        self.val_fc1 = nn.Linear(2 * board_width * board_height, 64)
        self.val_fc2 = nn.Linear(64, 1)

    def forward(self, state_input):
        """Return ``(log_probs, value)`` for a batch of states.

        ``log_probs`` is the log-softmax (over dim 1) of the dueling
        combination ``value + advantage - mean(advantage)``; ``value`` is the
        raw output of the value head with shape ``(batch, 1)``.
        """
        cells = self.board_width * self.board_height

        # common layers
        x = F.relu(self.conv1(state_input))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # advantage stream
        adv = F.relu(self.adv_conv1(x))
        adv = adv.view(-1, 4 * cells)
        adv = self.adv_fc1(adv)

        # value stream
        val = F.relu(self.val_conv1(x))
        val = val.view(-1, 2 * cells)
        val = F.relu(self.val_fc1(val))
        val = self.val_fc2(val)

        # Subtracting the mean advantage is the dueling aggregation step;
        # val (batch, 1) broadcasts across the per-cell advantages.
        q_values = val + adv - adv.mean(dim=1, keepdim=True)
        return F.log_softmax(q_values, dim=1), val


class PolicyValueNet():
    """Wrapper that owns a ``DuelingDQNNet``, its optimizer and its I/O helpers.

    Args:
        board_width: board width passed to the network.
        board_height: board height passed to the network.
        model_file: optional path of a saved ``state_dict`` to load.
        use_gpu: when true the network and all input batches are placed on
            ``device``; when false everything stays on the CPU and ``device``
            is ignored.
        device: target device used when ``use_gpu`` is true. If omitted,
            CUDA is used when available, otherwise the CPU.
    """

    def __init__(self, board_width, board_height,
                 model_file=None, use_gpu=False, device=None):
        self.use_gpu = use_gpu
        self.board_width = board_width
        self.board_height = board_height
        self.l2_const = 1e-4  # coef of l2 penalty

        # Resolve the device once and keep it on the instance: the other
        # methods need it too and must not rely on a module-level global.
        if self.use_gpu:
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            self.device = torch.device(device)
        else:
            self.device = torch.device("cpu")

        # the policy value net module
        self.policy_value_net = DuelingDQNNet(board_width, board_height)
        if self.use_gpu:
            self.policy_value_net = self.policy_value_net.to(self.device)

        self.optimizer = optim.Adam(self.policy_value_net.parameters(),
                                    weight_decay=self.l2_const)

        if model_file:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            # map_location lets a checkpoint saved on GPU load on a CPU host.
            net_params = torch.load(model_file, map_location=self.device)
            # strict=False tolerates missing/unexpected keys in the checkpoint.
            self.policy_value_net.load_state_dict(net_params, strict=False)

    def _to_tensor(self, batch):
        """Convert ``batch`` to a float32 tensor on this instance's device."""
        if isinstance(batch, torch.Tensor):
            return batch.to(device=self.device, dtype=torch.float32)
        # np.asarray stacks a list of equally-shaped arrays in one step.
        return torch.as_tensor(np.asarray(batch), dtype=torch.float32,
                               device=self.device)

    def policy_value(self, state_batch):
        """
        input: a batch of states
        output: a batch of action probabilities and state values
                (both as numpy arrays)
        """
        states = self._to_tensor(state_batch)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            log_act_probs, value = self.policy_value_net(states)
        act_probs = np.exp(log_act_probs.cpu().numpy())
        return act_probs, value.cpu().numpy()

    def policy_value_fn(self, board):
        """
        input: board (must provide ``availables`` and ``current_state()``)
        output: an iterator of (action, probability) tuples for each
        available action and the score of the board state (returned as the
        network's value tensor, detached from the autograd graph)
        """
        legal_positions = board.availables
        current_state = np.ascontiguousarray(board.current_state().reshape(
            -1, 4, self.board_width, self.board_height))
        states = self._to_tensor(current_state)
        # Inference only: without no_grad the returned value would keep the
        # whole forward graph alive for as long as the caller holds it.
        with torch.no_grad():
            log_act_probs, value = self.policy_value_net(states)
        act_probs = np.exp(log_act_probs.cpu().numpy().flatten())
        act_probs = zip(legal_positions, act_probs[legal_positions])
        return act_probs, value

    def train_step(self, state_batch, mcts_probs, winner_batch, lr):
        """Perform one optimisation step and return ``(loss, entropy)`` as floats."""
        state_batch = self._to_tensor(state_batch)
        mcts_probs = self._to_tensor(mcts_probs)
        winner_batch = self._to_tensor(winner_batch)

        # zero the parameter gradients
        self.optimizer.zero_grad()
        # set learning rate
        set_learning_rate(self.optimizer, lr)

        # forward
        log_act_probs, value = self.policy_value_net(state_batch)
        # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
        # Note: the L2 penalty is incorporated in optimizer
        value_loss = F.mse_loss(value.view(-1), winner_batch)
        policy_loss = -torch.mean(torch.sum(mcts_probs * log_act_probs, 1))
        loss = value_loss + policy_loss
        # backward and optimize
        loss.backward()
        self.optimizer.step()

        # calc policy entropy, for monitoring only (no gradients needed)
        with torch.no_grad():
            entropy = -torch.mean(
                torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
            )
        return loss.item(), entropy.item()

    def get_policy_param(self):
        """Return the network's ``state_dict``."""
        net_params = self.policy_value_net.state_dict()
        return net_params

    def save_model(self, model_file):
        """ save model params to file """
        net_params = self.get_policy_param()  # get model params
        torch.save(net_params, model_file)