import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np


def set_learning_rate(optimizer, lr):
    """Sets the learning rate to the given value."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


class DuelingDQNNet(nn.Module):
    """Dueling DQN network module."""

    def __init__(self, board_width, board_height):
        super(DuelingDQNNet, self).__init__()
        self.board_width = board_width
        self.board_height = board_height
        # common layers
        self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # advantage layers
        self.adv_conv1 = nn.Conv2d(128, 4, kernel_size=1)
        self.adv_fc1 = nn.Linear(4*board_width*board_height,
                                 board_width*board_height)
        # value layers
        self.val_conv1 = nn.Conv2d(128, 2, kernel_size=1)
        self.val_fc1 = nn.Linear(2*board_width*board_height, 64)
        self.val_fc2 = nn.Linear(64, 1)

    def forward(self, state_input):
        # common layers
        x = F.relu(self.conv1(state_input))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # advantage stream: one advantage per board position
        adv = F.relu(self.adv_conv1(x))
        adv = adv.view(-1, 4*self.board_width*self.board_height)
        adv = self.adv_fc1(adv)
        # value stream: a single scalar state value
        val = F.relu(self.val_conv1(x))
        val = val.view(-1, 2*self.board_width*self.board_height)
        val = F.relu(self.val_fc1(val))
        val = self.val_fc2(val)
        # dueling aggregation: Q = V + A - mean(A); the log-softmax over Q
        # is returned as the (log) move policy, alongside the value head
        q_values = val + adv - adv.mean(dim=1, keepdim=True)
        return F.log_softmax(q_values, dim=1), val
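

# Quick shape check for DuelingDQNNet (an illustrative sketch, not part of the
# original module): on a hypothetical 8x8 board with a batch of 5 four-plane
# states, the first output is a (5, 64) log-policy over board positions and
# the second is a (5, 1) state value.
#
#     net = DuelingDQNNet(8, 8)
#     dummy = torch.zeros(5, 4, 8, 8)
#     log_probs, value = net(dummy)
#     assert log_probs.shape == (5, 64) and value.shape == (5, 1)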


class PolicyValueNet():
    """Policy-value network built on the dueling DQN module."""

    def __init__(self, board_width, board_height,
                 model_file=None, use_gpu=False, device=None):
        self.use_gpu = use_gpu
        self.board_width = board_width
        self.board_height = board_height
        self.l2_const = 1e-4  # coef of l2 penalty
        # resolve and store the device so the other methods can reuse it
        if device is None:
            device = torch.device(
                'cuda' if use_gpu and torch.cuda.is_available() else 'cpu')
        self.device = device
        # the policy value net module
        if self.use_gpu:
            self.policy_value_net = DuelingDQNNet(
                board_width, board_height).to(self.device)
        else:
            self.policy_value_net = DuelingDQNNet(board_width, board_height)
        self.optimizer = optim.Adam(self.policy_value_net.parameters(),
                                    weight_decay=self.l2_const)
        if model_file:
            net_params = torch.load(model_file)
            self.policy_value_net.load_state_dict(net_params, strict=False)

    def policy_value(self, state_batch):
        """
        input: a batch of states
        output: a batch of action probabilities and state values
        """
        if self.use_gpu:
            state_batch = Variable(torch.FloatTensor(state_batch).to(self.device))
            log_act_probs, value = self.policy_value_net(state_batch)
            act_probs = np.exp(log_act_probs.data.cpu().numpy())
            return act_probs, value.data.cpu().numpy()
        else:
            state_batch = Variable(torch.FloatTensor(state_batch))
            log_act_probs, value = self.policy_value_net(state_batch)
            act_probs = np.exp(log_act_probs.data.numpy())
            return act_probs, value.data.numpy()

    def policy_value_fn(self, board):
        """
        input: board
        output: a list of (action, probability) tuples for each available
        action and the score of the board state
        """
        legal_positions = board.availables
        current_state = np.ascontiguousarray(board.current_state().reshape(
            -1, 4, self.board_width, self.board_height))
        if self.use_gpu:
            log_act_probs, value = self.policy_value_net(
                Variable(torch.from_numpy(current_state)).to(self.device).float())
            act_probs = np.exp(log_act_probs.data.cpu().numpy().flatten())
        else:
            log_act_probs, value = self.policy_value_net(
                Variable(torch.from_numpy(current_state)).float())
            act_probs = np.exp(log_act_probs.data.numpy().flatten())
        act_probs = zip(legal_positions, act_probs[legal_positions])
        return act_probs, value

    def train_step(self, state_batch, mcts_probs, winner_batch, lr):
        """perform a training step"""
        # wrap in Variable
        if self.use_gpu:
            state_batch = Variable(torch.FloatTensor(state_batch).to(self.device))
            mcts_probs = Variable(torch.FloatTensor(mcts_probs).to(self.device))
            winner_batch = Variable(torch.FloatTensor(winner_batch).to(self.device))
        else:
            state_batch = Variable(torch.FloatTensor(state_batch))
            mcts_probs = Variable(torch.FloatTensor(mcts_probs))
            winner_batch = Variable(torch.FloatTensor(winner_batch))
        # zero the parameter gradients
        self.optimizer.zero_grad()
        # set learning rate
        set_learning_rate(self.optimizer, lr)
        # forward
        log_act_probs, value = self.policy_value_net(state_batch)
        # define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
        # Note: the L2 penalty is incorporated in the optimizer (weight_decay)
        value_loss = F.mse_loss(value.view(-1), winner_batch)
        policy_loss = -torch.mean(torch.sum(mcts_probs*log_act_probs, 1))
        loss = value_loss + policy_loss
        # backward and optimize
        loss.backward()
        self.optimizer.step()
        # calc policy entropy, for monitoring only
        entropy = -torch.mean(
            torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
        )
        # for PyTorch < 0.5, use loss.data[0] and entropy.data[0] instead
        return loss.item(), entropy.item()

    def get_policy_param(self):
        net_params = self.policy_value_net.state_dict()
        return net_params

    def save_model(self, model_file):
        """save model params to file"""
        net_params = self.get_policy_param()  # get model params
        torch.save(net_params, model_file)
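

# Minimal usage sketch (an illustration, not part of the original training
# pipeline): build the net for a hypothetical 6x6 board, score a random batch
# of states, and run one training step with uniform MCTS probabilities and
# random winners. All sizes and values below are made up for demonstration.
if __name__ == '__main__':
    width, height = 6, 6
    net = PolicyValueNet(width, height, use_gpu=False)

    # a batch of 8 fake states with the expected 4 input planes
    states = np.random.rand(8, 4, width, height).astype(np.float32)
    probs, values = net.policy_value(states)
    print('policy batch:', probs.shape, 'value batch:', values.shape)

    # one training step with uniform targets and random game outcomes
    mcts_probs = np.full((8, width * height), 1.0 / (width * height),
                         dtype=np.float32)
    winners = np.random.choice([-1.0, 1.0], size=8).astype(np.float32)
    loss, entropy = net.train_step(states, mcts_probs, winners, lr=2e-3)
    print('loss: %.4f, entropy: %.4f' % (loss, entropy))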