# Demo/Gomoku_MCTS/main_worker.py
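"""Training worker for an AlphaZero-style Gomoku agent.

Collects self-play games with an MCTS player guided by a (dueling) policy-value
net, stores the symmetry-augmented positions in a replay buffer, updates the
net, and periodically evaluates it against a pure-MCTS baseline.
"""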
from __future__ import print_function
import random
import numpy as np
from collections import defaultdict, deque
from game import Board, Game
from mcts_pure import MCTSPlayer as MCTS_Pure
from mcts_alphaZero import MCTSPlayer
import torch.optim as optim
# from policy_value_net import PolicyValueNet # Theano and Lasagne
# from policy_value_net_pytorch import PolicyValueNet # Pytorch
from dueling_net import PolicyValueNet
# from policy_value_net_tensorflow import PolicyValueNet # Tensorflow
# from policy_value_net_keras import PolicyValueNet # Keras
# import joblib
import torch.nn.functional as F
from config.options import *
import sys
import os  # used by std_log for the log file path
from config.utils import *
from torch.backends import cudnn
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from multiprocessing import Pool
def set_learning_rate(optimizer, lr):
"""Sets the learning rate to the given value"""
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def std_log():
if get_rank() == 0:
save_path = make_path()
makedir(config['log_base'])
sys.stdout = open(os.path.join(config['log_base'], "{}.txt".format(save_path)), "w")
def init_seeds(seed, cuda_deterministic=True):
    # seed python's random (used for mini-batch sampling), numpy and torch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
if cuda_deterministic: # slower, more reproducible
cudnn.deterministic = True
cudnn.benchmark = False
else: # faster, less reproducible
cudnn.deterministic = False
cudnn.benchmark = True
class MainWorker():
    def __init__(self, device):
        # --- training pipeline settings ---
self.board_width = opts.board_width
self.board_height = opts.board_height
self.n_in_row = opts.n_in_row
self.learn_rate = opts.learn_rate
self.lr_multiplier = opts.lr_multiplier
self.temp = opts.temp
self.n_playout = opts.n_playout
self.c_puct = opts.c_puct
self.buffer_size = opts.buffer_size
self.batch_size = opts.batch_size
self.play_batch_size = opts.play_batch_size
self.epochs = opts.epochs
self.kl_targ = opts.kl_targ
self.check_freq = opts.check_freq
self.game_batch_num = opts.game_batch_num
self.pure_mcts_playout_num = opts.pure_mcts_playout_num
self.device = device
        # compare the device type only: in distributed mode the device carries
        # an index (e.g. cuda:0), which would not equal torch.device("cuda")
        self.use_gpu = (self.device.type == "cuda")
self.board = Board(width=self.board_width,
height=self.board_height,
n_in_row=self.n_in_row)
self.game = Game(self.board)
        # replay buffer holding the most recent self-play positions
self.data_buffer = deque(maxlen=self.buffer_size)
# The best win ratio of the training agent
self.best_win_ratio = 0.0
        if opts.preload_model:
            # resume training from an existing policy-value net checkpoint
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=opts.preload_model,
                                                   use_gpu=self.use_gpu)
        else:
            # start training from a freshly initialized policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   use_gpu=self.use_gpu)
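        # self-play player; in the accompanying mcts_alphaZero implementation,
        # is_selfplay=1 typically adds Dirichlet noise at the root for exploration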
self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
c_puct=self.c_puct,
n_playout=self.n_playout,
is_selfplay=1)
        # optimizer: Adam with L2 weight decay (the c||theta||^2 term of the loss)
        self.optimizer = optim.Adam(self.policy_value_net.policy_value_net.parameters(),
                                    weight_decay=opts.l2_const)
        # set the initial learning rate
        set_learning_rate(self.optimizer, self.learn_rate * self.lr_multiplier)
def get_equi_data(self, play_data):
"""augment the data set by rotation and flipping
play_data: [(state, mcts_prob, winner_z), ..., ...]
"""
extend_data = []
        for state, mcts_prob, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(np.flipud(
                    mcts_prob.reshape(self.board_height, self.board_width)), i)
extend_data.append((equi_state,
np.flipud(equi_mcts_prob).flatten(),
winner))
# flip horizontally
equi_state = np.array([np.fliplr(s) for s in equi_state])
equi_mcts_prob = np.fliplr(equi_mcts_prob)
extend_data.append((equi_state,
np.flipud(equi_mcts_prob).flatten(),
winner))
return extend_data
def job(self, i):
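        """Play one self-play game with the current MCTS player and return its
        positions, augmented by get_equi_data (runs inside a Pool worker)."""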
game = self.game
player = self.mcts_player
winner, play_data = game.start_self_play(player,
temp=self.temp)
play_data = list(play_data)[:]
play_data = self.get_equi_data(play_data)
return play_data
def collect_selfplay_data(self, n_games=1):
"""collect self-play data for training"""
# print("[STAGE] Collecting self-play data for training")
# collection_bar = tqdm( range(n_games))
        collection_bar = range(n_games)
        # run the self-play games in parallel worker processes
        with Pool(4) as p:
            games_data = p.map(self.job, collection_bar, chunksize=1)
        # p.map returns one list of (state, mcts_prob, winner) samples per game;
        # flatten so the buffer stores individual positions rather than whole games
        for game_data in games_data:
            self.data_buffer.extend(game_data)
# print('\n', 'data buffer size:', len(self.data_buffer))
def policy_update(self):
"""update the policy-value net"""
mini_batch = random.sample(self.data_buffer, self.batch_size)
state_batch = [data[0] for data in mini_batch]
mcts_probs_batch = [data[1] for data in mini_batch]
winner_batch = [data[2] for data in mini_batch]
old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        # convert the mini-batch to tensors once, outside the epoch loop
        # (list -> np.array -> tensor avoids the slow element-wise conversion,
        # and the deprecated Variable wrapper is unnecessary on modern PyTorch)
        state_batch_t = torch.FloatTensor(np.array(state_batch)).to(self.device)
        mcts_probs_t = torch.FloatTensor(np.array(mcts_probs_batch)).to(self.device)
        winner_batch_t = torch.FloatTensor(np.array(winner_batch)).to(self.device)
        epoch_bar = tqdm(range(self.epochs))
        for i in epoch_bar:
            # perform one training step on the same mini-batch
# zero the parameter gradients
self.optimizer.zero_grad()
# forward
            log_act_probs, value = self.policy_value_net.policy_value_net(state_batch_t)
# define the loss = (z - v)^2 - pi^T * log(p) + c||theta||^2
# Note: the L2 penalty is incorporated in optimizer
            value_loss = F.mse_loss(value.view(-1), winner_batch_t)
            policy_loss = -torch.mean(torch.sum(mcts_probs_t * log_act_probs, 1))
loss = value_loss + policy_loss
# backward and optimize
loss.backward()
self.optimizer.step()
# calc policy entropy, for monitoring only
entropy = -torch.mean(
torch.sum(torch.exp(log_act_probs) * log_act_probs, 1)
)
loss = loss.item()
entropy = entropy.item()
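            # re-evaluate the updated net on the same states to measure how far
            # the policy has drifted from the pre-update distribution (KL below)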
new_probs, new_v = self.policy_value_net.policy_value(state_batch)
kl = np.mean(np.sum(old_probs * (
np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
axis=1)
)
if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly
break
epoch_bar.set_description(f"training epoch {i}")
            epoch_bar.set_postfix(new_v=float(np.mean(new_v)), kl=kl)
        # adaptively adjust the learning rate based on the observed KL divergence
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
        # apply the adjusted multiplier; otherwise the adaptation above has no
        # effect, since the optimizer lr was only set once in __init__
        set_learning_rate(self.optimizer, self.learn_rate * self.lr_multiplier)
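        # explained variance of the value head: 1 means the value prediction
        # fully accounts for the variance of the observed game outcomes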
explained_var_old = (1 -
np.var(np.array(winner_batch) - old_v.flatten()) /
np.var(np.array(winner_batch)))
explained_var_new = (1 -
np.var(np.array(winner_batch) - new_v.flatten()) /
np.var(np.array(winner_batch)))
        return kl, loss, entropy, explained_var_old, explained_var_new
def policy_evaluate(self, n_games=10):
"""
Evaluate the trained policy by playing against the pure MCTS player
Note: this is only for monitoring the progress of training
"""
current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
c_puct=self.c_puct,
n_playout=self.n_playout)
pure_mcts_player = MCTS_Pure(c_puct=5,
n_playout=self.pure_mcts_playout_num)
win_cnt = defaultdict(int)
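        # assumption from the accompanying game.py: start_play returns the
        # winner's player index (1 for the first player argument, 2 for the
        # second) or -1 when the game ends in a tie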
        for i in range(n_games):
            # pass the trained player first so that win_cnt[1] counts its wins
            winner = self.game.start_play(
                current_mcts_player, pure_mcts_player,
                start_player=i % 2,
                is_shown=0)
            win_cnt[winner] += 1
            print(f" game {i} winner:", winner)
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num,
            win_cnt[1], win_cnt[2], win_cnt[-1]))
return win_ratio
def run(self):
"""run the training pipeline"""
try:
batch_bar = tqdm(range(self.game_batch_num))
for i in batch_bar:
self.collect_selfplay_data(self.play_batch_size)
if len(self.data_buffer) > self.batch_size:
                    kl, loss, entropy, explained_var_old, explained_var_new = self.policy_update()
                    writer.add_scalar("policy_update/kl", kl, i)
                    writer.add_scalar("policy_update/loss", loss, i)
                    writer.add_scalar("policy_update/entropy", entropy, i)
                    writer.add_scalar("policy_update/explained_var_old", explained_var_old, i)
                    writer.add_scalar("policy_update/explained_var_new", explained_var_new, i)
batch_bar.set_description(f"game batch num {i}")
# check the performance of the current model,
# and save the model params
if (i+1) % self.check_freq == 0:
win_ratio = self.policy_evaluate()
batch_bar.set_description(f"game batch num {i+1}")
writer.add_scalar("evaluate/explained_var_new ", win_ratio ,i)
batch_bar.set_postfix(loss= loss, entropy= entropy,win_ratio =win_ratio)
save_model(self.policy_value_net,"current_policy.model")
if win_ratio > self.best_win_ratio:
print("New best policy!!!!!!!!")
self.best_win_ratio = win_ratio
# update the best_policy
save_model(self.policy_value_net,"best_policy.model")
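                        # once the agent beats the pure-MCTS baseline in every
                        # evaluation game, strengthen the baseline and reset the record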
if (self.best_win_ratio == 1.0 and
self.pure_mcts_playout_num < 5000):
self.pure_mcts_playout_num += 1000
self.best_win_ratio = 0.0
except KeyboardInterrupt:
print('\n\rquit')
if __name__ == "__main__":
print("START train....")
# ------init set-----------
if opts.std_log:
std_log()
writer = visualizer()
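    # when opts.distributed is set, each process is pinned to its own GPU via
    # the NCCL backend and seeded with a rank offset so self-play games differ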
if opts.distributed:
torch.distributed.init_process_group(backend="nccl")
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
init_seeds(opts.seed + local_rank)
else:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
init_seeds(opts.seed)
print("seed: ",opts.seed )
print("device:" , device)
if opts.split == "train":
training_pipeline = MainWorker(device)
training_pipeline.run()
    if get_rank() == 0 and opts.split == "test":
        training_pipeline = MainWorker(device)
        # the PolicyValueNet instance is not callable; evaluate the preloaded
        # policy against the pure-MCTS baseline instead
        training_pipeline.policy_evaluate()