#!/usr/bin/python
# -*- encoding: utf-8 -*-

from logger import setup_logger
from model import BiSeNet
from face_dataset import FaceMask
from loss import OhemCELoss
from evaluate import evaluate
from optimizer import Optimizer

import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.distributed as dist

import os
import os.path as osp
import logging
import time
import datetime
import argparse

respth = './res'
if not osp.exists(respth):
    os.makedirs(respth)
# intermediate checkpoints are saved under ./res/cp during training,
# so create that directory up front as well
cpth = osp.join(respth, 'cp')
if not osp.exists(cpth):
    os.makedirs(cpth)
logger = logging.getLogger()

def parse_args():
    parse = argparse.ArgumentParser()
    parse.add_argument(
        '--local_rank',
        dest='local_rank',
        type=int,
        default=-1,
    )
    return parse.parse_args()
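
# --local_rank follows the torch.distributed.launch convention: the launcher
# spawns one process per GPU and passes each process its GPU index. A typical
# single-node launch (set nproc_per_node to the number of GPUs on your
# machine) is:
#
#   python -m torch.distributed.launch --nproc_per_node=2 train.py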

def train():
    args = parse_args()
    torch.cuda.set_device(args.local_rank)
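    # one process per GPU, synchronized over NCCL; world_size assumes every
    # visible GPU on this single node takes part, matching the launch command
    # above (the TCP port is arbitrary, any free port works)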
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:33241',
        world_size=torch.cuda.device_count(),
        rank=args.local_rank
    )
    setup_logger(respth)

    # dataset
    n_classes = 19
    n_img_per_gpu = 16
    n_workers = 8
    cropsize = [448, 448]
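    # point data_root at your local copy of the CelebAMask-HQ dataset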
    data_root = '/home/zll/data/CelebAMask-HQ/'

    ds = FaceMask(data_root, cropsize=cropsize, mode='train')
    sampler = torch.utils.data.distributed.DistributedSampler(ds)
    dl = DataLoader(ds,
                    batch_size=n_img_per_gpu,
                    shuffle=False,
                    sampler=sampler,
                    num_workers=n_workers,
                    pin_memory=True,
                    drop_last=True)
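    # DistributedSampler gives each rank a disjoint shard of the dataset and
    # does its own shuffling, which is why shuffle=False is passed to the
    # DataLoader (the two options are mutually exclusive)
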
    # model
    ignore_idx = -100
    net = BiSeNet(n_classes=n_classes)
    net.cuda()
    net.train()
    net = nn.parallel.DistributedDataParallel(net,
                                              device_ids=[args.local_rank, ],
                                              output_device=args.local_rank
                                              )

    score_thres = 0.7
    n_min = n_img_per_gpu * cropsize[0] * cropsize[1] // 16
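    # online hard example mining: OhemCELoss (see loss.py) backpropagates only
    # the hardest pixels, roughly those predicted with confidence below
    # score_thres, but always keeps at least n_min pixels (1/16 of the pixels
    # in a batch) so the loss never runs empty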
    LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)

    ## optimizer
    momentum = 0.9
    weight_decay = 5e-4
    lr_start = 1e-2
    max_iter = 80000
    power = 0.9
    warmup_steps = 1000
    warmup_start_lr = 1e-5
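    # judging by these parameters, Optimizer (see optimizer.py) wraps SGD with
    # a linear warmup over the first warmup_steps iterations followed by the
    # usual polynomial ('poly') decay, lr0 * (1 - it / max_iter) ** power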
    optim = Optimizer(
        model=net.module,
        lr0=lr_start,
        momentum=momentum,
        wd=weight_decay,
        warmup_steps=warmup_steps,
        warmup_start_lr=warmup_start_lr,
        max_iter=max_iter,
        power=power)

    ## train loop
    msg_iter = 50
    loss_avg = []
    st = glob_st = time.time()
    diter = iter(dl)
    epoch = 0
    for it in range(max_iter):
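        # the loop is iteration-based rather than epoch-based: when the loader
        # runs dry, bump the epoch counter (set_epoch reseeds the sampler's
        # shuffle) and restart the iterator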
        try:
            im, lb = next(diter)
            if not im.size()[0] == n_img_per_gpu:
                raise StopIteration
        except StopIteration:
            epoch += 1
            sampler.set_epoch(epoch)
            diter = iter(dl)
            im, lb = next(diter)
        im = im.cuda()
        lb = lb.cuda()
        H, W = im.size()[2:]
        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
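        # BiSeNet returns the main prediction plus two auxiliary outputs taken
        # from the 1/16- and 1/32-scale feature maps; all three are supervised
        # during training (deep supervision), but only `out` is used at
        # inference time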
        out, out16, out32 = net(im)
        lossp = LossP(out, lb)
        loss2 = Loss2(out16, lb)
        loss3 = Loss3(out32, lb)
        loss = lossp + loss2 + loss3
        loss.backward()
        optim.step()

        loss_avg.append(loss.item())
        # print training log message
        if (it + 1) % msg_iter == 0:
            loss_avg = sum(loss_avg) / len(loss_avg)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            # it + 1 iterations have completed in glob_t_intv seconds
            eta = int((max_iter - it - 1) * (glob_t_intv / (it + 1)))
            eta = str(datetime.timedelta(seconds=eta))
            msg = ', '.join([
                'it: {it}/{max_it}',
                'lr: {lr:.4f}',
                'loss: {loss:.4f}',
                'eta: {eta}',
                'time: {time:.4f}',
            ]).format(
                it=it + 1,
                max_it=max_iter,
                lr=lr,
                loss=loss_avg,
                time=t_intv,
                eta=eta
            )
            logger.info(msg)
            loss_avg = []
            st = ed
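
        # checkpoint (and run evaluation) every 5000 iterations, from rank 0
        # only so the processes do not race on the same file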
        if dist.get_rank() == 0 and (it + 1) % 5000 == 0:
            state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
            torch.save(state, './res/cp/{}_iter.pth'.format(it))
            evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='{}_iter.pth'.format(it))

    # dump the final model
    save_pth = osp.join(respth, 'model_final_diss.pth')
    # net.cpu()
    state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
    if dist.get_rank() == 0:
        torch.save(state, save_pth)
    logger.info('training done, model saved to: {}'.format(save_pth))

if __name__ == "__main__":
    train()