#!/usr/bin/python
# -*- encoding: utf-8 -*-
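# Distributed training of BiSeNet for face parsing on CelebAMask-HQ
# (19 classes), one process per GPU. torch.distributed.launch injects
# --local_rank into each process; a typical single-node launch (assuming
# this file is saved as train.py and N GPUs are visible) is:
#   python -m torch.distributed.launch --nproc_per_node=N train.py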
from logger import setup_logger
from model import BiSeNet
from face_dataset import FaceMask
from loss import OhemCELoss
from evaluate import evaluate
from optimizer import Optimizer
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.distributed as dist
import os
import os.path as osp
import logging
import time
import datetime
import argparse
respth = './res'
if not osp.exists(respth):
    os.makedirs(respth)
logger = logging.getLogger()
def parse_args():
    parse = argparse.ArgumentParser()
    parse.add_argument(
        '--local_rank',
        dest='local_rank',
        type=int,
        default=-1,
    )
    return parse.parse_args()
def train():
    args = parse_args()
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:33241',
        world_size=torch.cuda.device_count(),  # assumes a single node: one process per visible GPU
        rank=args.local_rank
    )
    setup_logger(respth)
    # dataset
    n_classes = 19
    n_img_per_gpu = 16
    n_workers = 8
    cropsize = [448, 448]
    data_root = '/home/zll/data/CelebAMask-HQ/'
    ds = FaceMask(data_root, cropsize=cropsize, mode='train')
    sampler = torch.utils.data.distributed.DistributedSampler(ds)
    dl = DataLoader(ds,
                    batch_size=n_img_per_gpu,
                    shuffle=False,
                    sampler=sampler,
                    num_workers=n_workers,
                    pin_memory=True,
                    drop_last=True)
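    # Note: shuffle=False is intentional -- DistributedSampler both shards the
    # dataset across processes and shuffles it, re-seeded each epoch through
    # sampler.set_epoch() in the training loop below.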
    # model
    ignore_idx = -100
    net = BiSeNet(n_classes=n_classes)
    net.cuda()
    net.train()
    net = nn.parallel.DistributedDataParallel(net,
                                              device_ids=[args.local_rank],
                                              output_device=args.local_rank)
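    # Online hard example mining (OHEM) cross-entropy, one loss per output
    # head: each keeps the hardest pixels (prob below score_thres), but never
    # fewer than n_min = 1/16 of the pixels in a per-GPU batch.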
    score_thres = 0.7
    n_min = n_img_per_gpu * cropsize[0] * cropsize[1] // 16
    LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    ## optimizer
    momentum = 0.9
    weight_decay = 5e-4
    lr_start = 1e-2
    max_iter = 80000
    power = 0.9
    warmup_steps = 1000
    warmup_start_lr = 1e-5
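    # Judging by the hyperparameters above, Optimizer (optimizer.py) wraps SGD
    # with a linear warmup followed by poly decay,
    #   lr = lr0 * (1 - it / max_iter) ** power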
    optim = Optimizer(
        model=net.module,
        lr0=lr_start,
        momentum=momentum,
        wd=weight_decay,
        warmup_steps=warmup_steps,
        warmup_start_lr=warmup_start_lr,
        max_iter=max_iter,
        power=power)
    ## train loop
    msg_iter = 50
    loss_avg = []
    st = glob_st = time.time()
    diter = iter(dl)
    epoch = 0
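    # Iteration-driven loop (no fixed epoch count): when the loader runs dry
    # or yields a short batch, advance the epoch, re-shuffle the sampler, and
    # restart the iterator.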
    for it in range(max_iter):
        try:
            im, lb = next(diter)
            if im.size(0) != n_img_per_gpu:
                raise StopIteration
        except StopIteration:
            epoch += 1
            sampler.set_epoch(epoch)
            diter = iter(dl)
            im, lb = next(diter)
        im = im.cuda()
        lb = lb.cuda()
        H, W = im.size()[2:]
        lb = torch.squeeze(lb, 1)  # drop the channel dim: (N, 1, H, W) -> (N, H, W)
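        # BiSeNet returns the final prediction plus two auxiliary heads fed
        # from the 1/16- and 1/32-scale context features; all three receive
        # OHEM CE supervision.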
        optim.zero_grad()
        out, out16, out32 = net(im)
        lossp = LossP(out, lb)
        loss2 = Loss2(out16, lb)
        loss3 = Loss3(out32, lb)
        loss = lossp + loss2 + loss3
        loss.backward()
        optim.step()
        loss_avg.append(loss.item())
        # print training log message
        if (it + 1) % msg_iter == 0:
            loss_avg = sum(loss_avg) / len(loss_avg)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            eta = int((max_iter - it) * (glob_t_intv / it))
            eta = str(datetime.timedelta(seconds=eta))
            msg = ', '.join([
                'it: {it}/{max_it}',
                'lr: {lr:.4f}',
                'loss: {loss:.4f}',
                'eta: {eta}',
                'time: {time:.4f}',
            ]).format(
                it=it + 1,
                max_it=max_iter,
                lr=lr,
                loss=loss_avg,
                time=t_intv,
                eta=eta)
            logger.info(msg)
            loss_avg = []
            st = ed
        # periodic checkpoint + evaluation, rank 0 only
        if dist.get_rank() == 0 and (it + 1) % 5000 == 0:
            state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
            cp_dir = osp.join(respth, 'cp')
            if not osp.exists(cp_dir):
                os.makedirs(cp_dir)
            torch.save(state, osp.join(cp_dir, '{}_iter.pth'.format(it)))
            evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='{}_iter.pth'.format(it))
    # dump the final model
    save_pth = osp.join(respth, 'model_final_diss.pth')
    # net.cpu()
    state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
    if dist.get_rank() == 0:
        torch.save(state, save_pth)
    logger.info('training done, model saved to: {}'.format(save_pth))
if __name__ == "__main__":
    train()