import os
import datetime
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

from config import Config
from loss import PixLoss, ClsLoss
from dataset import MyData
from models.birefnet import BiRefNet
from utils import Logger, AverageMeter, set_seed, check_state_dict
from evaluation.valid import valid

from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group, get_rank
from torch.cuda import amp

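# Example launches (paths and process counts are illustrative, not repo defaults):
#   single device: python train.py --ckpt_dir ckpt/tmp --epochs 120
#   distributed:   torchrun --nproc_per_node=4 train.py --ckpt_dir ckpt/tmp --dist True
# The distributed path expects a launcher that sets LOCAL_RANK (e.g. torchrun).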
parser = argparse.ArgumentParser(description='')
parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint')
parser.add_argument('--epochs', default=120, type=int)
parser.add_argument('--trainset', default='DIS5K', type=str, help="Options: 'DIS5K'")
parser.add_argument('--ckpt_dir', default=None, help='Temporary folder')
parser.add_argument('--testsets', default='DIS-VD+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4', type=str)
parser.add_argument('--dist', default=False, type=lambda x: x == 'True')
args = parser.parse_args()

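# Global configuration: fix the random seed when requested, and create the AMP
# GradScaler only for fp16 runs (Trainer._train_batch relies on it in that case).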
config = Config()
if config.rand_seed:
    set_seed(config.rand_seed)

if config.use_fp16:
    scaler = amp.GradScaler(enabled=config.use_fp16)

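# Device selection: under DDP each process binds to the GPU given by LOCAL_RANK;
# otherwise the device comes from config.device.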
to_be_distributed = args.dist
if to_be_distributed:
    init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=3600*10))
    device = int(os.environ["LOCAL_RANK"])
else:
    device = config.device

epoch_st = 1

os.makedirs(args.ckpt_dir, exist_ok=True)

logger = Logger(os.path.join(args.ckpt_dir, "log.txt"))
logger_loss_idx = 1

logger.info("datasets: load_all={}, compile={}.".format(config.load_all, config.compile))
logger.info("Other hyperparameters:"); logger.info(args)
print('batch size:', config.batch_size)

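# Keep the requested test sets only if their data directory exists; otherwise
# validation is skipped (args.testsets becomes an empty list).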
if os.path.exists(os.path.join(config.data_root_dir, config.task, args.testsets.strip('+').split('+')[0])):
    args.testsets = args.testsets.strip('+').split('+')
else:
    args.testsets = []

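# Dataloader helpers: in DDP mode a DistributedSampler shards the dataset across
# ranks (so shuffle is left to the sampler); the single-device branch clamps
# num_workers to 0, i.e. data loading runs in the main process.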
def prepare_dataloader(dataset: torch.utils.data.Dataset, batch_size: int, to_be_distributed=False, is_train=True):
    if to_be_distributed:
        return torch.utils.data.DataLoader(
            dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size), pin_memory=True,
            shuffle=False, sampler=DistributedSampler(dataset), drop_last=True
        )
    else:
        return torch.utils.data.DataLoader(
            dataset=dataset, batch_size=batch_size, num_workers=min(config.num_workers, batch_size, 0), pin_memory=True,
            shuffle=is_train, drop_last=True
        )

def init_data_loaders(to_be_distributed):
    train_loader = prepare_dataloader(
        MyData(datasets=config.training_set, image_size=config.size, is_train=True),
        config.batch_size, to_be_distributed=to_be_distributed, is_train=True
    )
    print(len(train_loader), "batches of train dataloader {} have been created.".format(config.training_set))
    test_loaders = {}
    for testset in args.testsets:
        _data_loader_test = prepare_dataloader(
            MyData(datasets=testset, image_size=config.size, is_train=False),
            config.batch_size_valid, is_train=False
        )
        print(len(_data_loader_test), "batches of valid dataloader {} have been created.".format(testset))
        test_loaders[testset] = _data_loader_test
    return train_loader, test_loaders

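# Model, optimizer and LR schedule. Resuming restores the weights and derives the
# starting epoch from the checkpoint filename (epoch_N.pth). Non-positive entries
# in config.lr_decay_epochs are interpreted as offsets counted back from the last epoch.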
def init_models_optimizers(epochs, to_be_distributed):
    model = BiRefNet(bb_pretrained=True)
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            state_dict = torch.load(args.resume, map_location='cpu')
            state_dict = check_state_dict(state_dict)
            model.load_state_dict(state_dict)
            global epoch_st
            epoch_st = int(args.resume.rstrip('.pth').split('epoch_')[-1]) + 1
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))
    if to_be_distributed:
        model = model.to(device)
        model = DDP(model, device_ids=[device])
    else:
        model = model.to(device)
    if config.compile:
        model = torch.compile(model, mode=['default', 'reduce-overhead', 'max-autotune'][0])
    if config.precisionHigh:
        torch.set_float32_matmul_precision('high')

    if config.optimizer == 'AdamW':
        optimizer = optim.AdamW(params=model.parameters(), lr=config.lr, weight_decay=1e-2)
    elif config.optimizer == 'Adam':
        optimizer = optim.Adam(params=model.parameters(), lr=config.lr, weight_decay=0)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[lde if lde > 0 else epochs + lde + 1 for lde in config.lr_decay_epochs],
        gamma=config.lr_decay_rate
    )
    logger.info("Optimizer details:"); logger.info(optimizer)
    logger.info("Scheduler details:"); logger.info(lr_scheduler)

    return model, optimizer, lr_scheduler

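# Trainer bundles the model with the pixel/classification losses and, when
# config.lambda_adv_g is non-zero, an adversarial branch with its own optimizer.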
class Trainer:
    def __init__(
        self, data_loaders, model_opt_lrsch,
    ):
        self.model, self.optimizer, self.lr_scheduler = model_opt_lrsch
        self.train_loader, self.test_loaders = data_loaders
        if config.out_ref:
            self.criterion_gdt = nn.BCELoss() if not config.use_fp16 else nn.BCEWithLogitsLoss()

        self.pix_loss = PixLoss()
        self.cls_loss = ClsLoss()

        self.loss_log = AverageMeter()
        if config.lambda_adv_g:
            self.optimizer_d, self.lr_scheduler_d, self.disc, self.adv_criterion = self._load_adv_components()
            self.disc_update_for_odd = 0

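    # Adversarial components: the Discriminator from loss.py, wrapped for DDP/compile
    # like the main model, with a matching optimizer and MultiStepLR schedule.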
    def _load_adv_components(self):
        from loss import Discriminator
        disc = Discriminator(channels=3, img_size=config.size)
        if to_be_distributed:
            disc = disc.to(device)
            disc = DDP(disc, device_ids=[device], broadcast_buffers=False)
        else:
            disc = disc.to(device)
        if config.compile:
            disc = torch.compile(disc, mode=['default', 'reduce-overhead', 'max-autotune'][0])
        adv_criterion = nn.BCELoss() if not config.use_fp16 else nn.BCEWithLogitsLoss()
        if config.optimizer == 'AdamW':
            optimizer_d = optim.AdamW(params=disc.parameters(), lr=config.lr, weight_decay=1e-2)
        elif config.optimizer == 'Adam':
            optimizer_d = optim.Adam(params=disc.parameters(), lr=config.lr, weight_decay=0)
        lr_scheduler_d = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_d,
            milestones=[lde if lde > 0 else args.epochs + lde + 1 for lde in config.lr_decay_epochs],
            gamma=config.lr_decay_rate
        )
        return optimizer_d, lr_scheduler_d, disc, adv_criterion

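    # One optimization step on a single batch. The fp16 path computes the forward pass
    # and losses under autocast and steps through the shared GradScaler; the fp32 path
    # mirrors it with plain backward()/step(). When the adversarial loss is enabled,
    # the discriminator is updated on every other batch (disc_update_for_odd).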
    def _train_batch(self, batch):
        inputs = batch[0].to(device)
        gts = batch[1].to(device)
        class_labels = batch[2].to(device)
        if config.use_fp16:
            with amp.autocast(enabled=config.use_fp16):
                scaled_preds, class_preds_lst = self.model(inputs)
                if config.out_ref:
                    # Auxiliary gradient-map supervision
                    (outs_gdt_pred, outs_gdt_label), scaled_preds = scaled_preds
                    for _idx, (_gdt_pred, _gdt_label) in enumerate(zip(outs_gdt_pred, outs_gdt_label)):
                        _gdt_pred = nn.functional.interpolate(_gdt_pred, size=_gdt_label.shape[2:], mode='bilinear', align_corners=True)
                        loss_gdt = self.criterion_gdt(_gdt_pred, _gdt_label) if _idx == 0 else self.criterion_gdt(_gdt_pred, _gdt_label) + loss_gdt
                # Classification loss (skipped if the model returns no class predictions)
                if None in class_preds_lst:
                    loss_cls = 0.
                else:
                    loss_cls = self.cls_loss(class_preds_lst, class_labels) * 1.0
                    self.loss_dict['loss_cls'] = loss_cls.item()

                # Pixel-level loss on the multi-scale predictions
                loss_pix = self.pix_loss(scaled_preds, torch.clamp(gts, 0, 1)) * 1.0
                self.loss_dict['loss_pix'] = loss_pix.item()

                loss = loss_pix + loss_cls
                if config.out_ref:
                    loss = loss + loss_gdt * 1.0

                if config.lambda_adv_g:
                    # Generator adversarial loss
                    valid = Variable(torch.cuda.FloatTensor(scaled_preds[-1].shape[0], 1).fill_(1.0), requires_grad=False).to(device)
                    adv_loss_g = self.adv_criterion(self.disc(scaled_preds[-1] * inputs), valid) * config.lambda_adv_g
                    loss += adv_loss_g
                    self.loss_dict['loss_adv'] = adv_loss_g.item()
                    self.disc_update_for_odd += 1

            self.loss_log.update(loss.item(), inputs.size(0))  # track the running average of the total loss
            self.optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(self.optimizer)
            scaler.update()

            if config.lambda_adv_g and self.disc_update_for_odd % 2 == 0:
                # Discriminator update on alternate batches
                fake = Variable(torch.cuda.FloatTensor(scaled_preds[-1].shape[0], 1).fill_(0.0), requires_grad=False).to(device)
                adv_loss_real = self.adv_criterion(self.disc(gts * inputs), valid)
                adv_loss_fake = self.adv_criterion(self.disc(scaled_preds[-1].detach() * inputs.detach()), fake)
                adv_loss_d = (adv_loss_real + adv_loss_fake) / 2 * config.lambda_adv_d
                self.loss_dict['loss_adv_d'] = adv_loss_d.item()

                self.optimizer_d.zero_grad()
                scaler.scale(adv_loss_d).backward()
                scaler.step(self.optimizer_d)
                scaler.update()
        else:
            scaled_preds, class_preds_lst = self.model(inputs)
            if config.out_ref:
                (outs_gdt_pred, outs_gdt_label), scaled_preds = scaled_preds
                for _idx, (_gdt_pred, _gdt_label) in enumerate(zip(outs_gdt_pred, outs_gdt_label)):
                    _gdt_pred = nn.functional.interpolate(_gdt_pred, size=_gdt_label.shape[2:], mode='bilinear', align_corners=True).sigmoid()
                    _gdt_label = _gdt_label.sigmoid()
                    loss_gdt = self.criterion_gdt(_gdt_pred, _gdt_label) if _idx == 0 else self.criterion_gdt(_gdt_pred, _gdt_label) + loss_gdt

            if None in class_preds_lst:
                loss_cls = 0.
            else:
                loss_cls = self.cls_loss(class_preds_lst, class_labels) * 1.0
                self.loss_dict['loss_cls'] = loss_cls.item()

            loss_pix = self.pix_loss(scaled_preds, torch.clamp(gts, 0, 1)) * 1.0
            self.loss_dict['loss_pix'] = loss_pix.item()

            loss = loss_pix + loss_cls
            if config.out_ref:
                loss = loss + loss_gdt * 1.0

            if config.lambda_adv_g:
                valid = Variable(torch.cuda.FloatTensor(scaled_preds[-1].shape[0], 1).fill_(1.0), requires_grad=False).to(device)
                adv_loss_g = self.adv_criterion(self.disc(scaled_preds[-1] * inputs), valid) * config.lambda_adv_g
                loss += adv_loss_g
                self.loss_dict['loss_adv'] = adv_loss_g.item()
                self.disc_update_for_odd += 1
            self.loss_log.update(loss.item(), inputs.size(0))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if config.lambda_adv_g and self.disc_update_for_odd % 2 == 0:
                fake = Variable(torch.cuda.FloatTensor(scaled_preds[-1].shape[0], 1).fill_(0.0), requires_grad=False).to(device)
                adv_loss_real = self.adv_criterion(self.disc(gts * inputs), valid)
                adv_loss_fake = self.adv_criterion(self.disc(scaled_preds[-1].detach() * inputs.detach()), fake)
                adv_loss_d = (adv_loss_real + adv_loss_fake) / 2 * config.lambda_adv_d
                self.loss_dict['loss_adv_d'] = adv_loss_d.item()
                self.optimizer_d.zero_grad()
                adv_loss_d.backward()
                self.optimizer_d.step()

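    # One pass over the training set. Once epoch exceeds
    # args.epochs + config.IoU_finetune_last_epochs (the latter is expected to be
    # non-positive), the BCE weight is zeroed and the IoU weight halved so the last
    # epochs fine-tune with IoU-dominated supervision. Losses are logged every 20 iterations.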
    def train_epoch(self, epoch):
        global logger_loss_idx
        self.model.train()
        self.loss_dict = {}
        if epoch > args.epochs + config.IoU_finetune_last_epochs:
            self.pix_loss.lambdas_pix_last['bce'] *= 0
            self.pix_loss.lambdas_pix_last['ssim'] *= 1
            self.pix_loss.lambdas_pix_last['iou'] *= 0.5

        for batch_idx, batch in enumerate(self.train_loader):
            self._train_batch(batch)
            if batch_idx % 20 == 0:
                info_progress = 'Epoch[{0}/{1}] Iter[{2}/{3}].'.format(epoch, args.epochs, batch_idx, len(self.train_loader))
                info_loss = 'Training Losses'
                for loss_name, loss_value in self.loss_dict.items():
                    info_loss += ', {}: {:.3f}'.format(loss_name, loss_value)
                logger.info(' '.join((info_progress, info_loss)))
        info_loss = '@==Final== Epoch[{0}/{1}] Training Loss: {loss.avg:.3f} '.format(epoch, args.epochs, loss=self.loss_log)
        logger.info(info_loss)

        self.lr_scheduler.step()
        if config.lambda_adv_g:
            self.lr_scheduler_d.step()
        return self.loss_log.avg

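    # Runs evaluation.valid on every test loader, prints per-set metrics, and reports
    # an average over the DIS-TE splits weighted by the number of validation batches.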
    def validate_model(self, epoch):
        num_image_testset_all = {'DIS-VD': 470, 'DIS-TE1': 500, 'DIS-TE2': 500, 'DIS-TE3': 500, 'DIS-TE4': 500}
        num_image_testset = {}
        for testset in args.testsets:
            if 'DIS-TE' in testset:
                num_image_testset[testset] = num_image_testset_all[testset]
        weighted_scores = {'f_max': 0, 'f_mean': 0, 'f_wfm': 0, 'sm': 0, 'e_max': 0, 'e_mean': 0, 'mae': 0}
        len_all_data_loaders = 0
        self.model.epoch = epoch
        for testset, data_loader_test in self.test_loaders.items():
            print('Validating {}...'.format(testset))
            performance_dict = valid(
                self.model,
                data_loader_test,
                pred_dir='.',
                method=args.ckpt_dir.split('/')[-1] if args.ckpt_dir.split('/')[-1].strip('.').strip('/') else 'tmp_val',
                testset=testset,
                only_S_MAE=config.only_S_MAE,
                device=device
            )
            print('Test set: {}:'.format(testset))
            if config.only_S_MAE:
                print('Smeasure: {:.4f}, MAE: {:.4f}'.format(
                    performance_dict['sm'], performance_dict['mae']
                ))
            else:
                print('Fmax: {:.4f}, Fwfm: {:.4f}, Smeasure: {:.4f}, Emean: {:.4f}, MAE: {:.4f}'.format(
                    performance_dict['f_max'], performance_dict['f_wfm'], performance_dict['sm'], performance_dict['e_mean'], performance_dict['mae']
                ))
            if '-TE' in testset:
                for metric in ['sm', 'mae'] if config.only_S_MAE else ['f_max', 'f_mean', 'f_wfm', 'sm', 'e_max', 'e_mean', 'mae']:
                    weighted_scores[metric] += performance_dict[metric] * len(data_loader_test)
                len_all_data_loaders += len(data_loader_test)
        print('Weighted Scores:')
        for metric, score in weighted_scores.items():
            if score:
                print('\t{}: {:.4f}.'.format(metric, score / len_all_data_loaders))

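# Entry point: build dataloaders, model and Trainer, then train for args.epochs epochs.
# Checkpoints are written during the last config.save_last epochs (every config.save_step
# epochs); validation runs on rank 0 only when training is distributed.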
def main():
    trainer = Trainer(
        data_loaders=init_data_loaders(to_be_distributed),
        model_opt_lrsch=init_models_optimizers(args.epochs, to_be_distributed)
    )

    for epoch in range(epoch_st, args.epochs+1):
        train_loss = trainer.train_epoch(epoch)
        if epoch >= args.epochs - config.save_last and epoch % config.save_step == 0:
            torch.save(
                trainer.model.module.state_dict() if to_be_distributed else trainer.model.state_dict(),
                os.path.join(args.ckpt_dir, 'epoch_{}.pth'.format(epoch))
            )
        if config.val_step and epoch >= args.epochs - config.save_last and (args.epochs - epoch) % config.val_step == 0:
            if to_be_distributed:
                if get_rank() == 0:
                    print('Validating at rank-{}...'.format(get_rank()))
                    trainer.validate_model(epoch)
            else:
                trainer.validate_model(epoch)
    if to_be_distributed:
        destroy_process_group()


if __name__ == '__main__':
    main()