import argparse
import os

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from huggingface_hub import PyTorchModelHubMixin

from data_loader import create_training_datasets
from model import (ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet,
                   InSPyReNet, InSPyReNet_Res2Net50, InSPyReNet_SwinB)

# Names accepted by get_net() / AnimeSegmentation and by the --net CLI option.
net_names = ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet",
             "inspyrnet_res", "inspyrnet_swin"]


def get_net(net_name, img_size):
    """Instantiate the network registered under ``net_name``.

    ``img_size`` is only used by the InSPyReNet variants (as ``base_size``).
    "isnet" and "isnet_is" build the same ISNetDIS network; they differ only
    in how AnimeSegmentation trains it.

    Raises:
        NotImplementedError: if ``net_name`` is not a known network name.
    """
    builders = {
        "isnet": ISNetDIS,
        "isnet_is": ISNetDIS,
        "isnet_gt": ISNetGTEncoder,
        "u2net": U2NET_full2,
        "u2netl": U2NET_lite2,
        "modnet": MODNet,
        "inspyrnet_res": lambda: InSPyReNet_Res2Net50(base_size=img_size),
        "inspyrnet_swin": lambda: InSPyReNet_SwinB(base_size=img_size),
    }
    try:
        builder = builders[net_name]
    except KeyError:
        raise NotImplementedError(f"unknown net_name: {net_name!r}") from None
    return builder()


def f1_torch(pred, gt):
    """Compute per-sample precision, recall and F-score between two masks.

    Each input is flattened to ``(batch, -1)``, so the first dimension is
    treated as the batch dimension. Counts are pooled over the foreground
    class (``pred * gt``) and the background class (``(1 - pred) * (1 - gt)``),
    i.e. a micro average over both classes. ``pred`` is not thresholded, so
    continuous predictions give a soft score.

    Despite the name, the returned score is an F-beta measure with
    beta^2 = 0.3, which weights precision more heavily than recall.

    Returns:
        Tuple ``(precision, recall, f1)`` of tensors with one value per sample.
    """
    pred = pred.float().view(pred.shape[0], -1)
    gt = gt.float().view(gt.shape[0], -1)
    # Foreground statistics.
    tp1 = torch.sum(pred * gt, dim=1)
    tp_fp1 = torch.sum(pred, dim=1)
    tp_fn1 = torch.sum(gt, dim=1)
    # Background statistics (complements of prediction and target).
    pred = 1 - pred
    gt = 1 - gt
    tp2 = torch.sum(pred * gt, dim=1)
    tp_fp2 = torch.sum(pred, dim=1)
    tp_fn2 = torch.sum(gt, dim=1)
    # Small epsilons avoid division by zero on empty masks.
    precision = (tp1 + tp2) / (tp_fp1 + tp_fp2 + 0.0001)
    recall = (tp1 + tp2) / (tp_fn1 + tp_fn2 + 0.0001)
    f1 = (1 + 0.3) * precision * recall / (0.3 * precision + recall + 0.0001)
    return precision, recall, f1


class AnimeSegmentation(pl.LightningModule,
                        PyTorchModelHubMixin,
                        library_name="anime_segmentation",
                        repo_url="https://github.com/SkyTNT/anime-segmentation",
                        tags=["image-segmentation"]
                        ):
    """Lightning module wrapping one of the segmentation networks in ``net_names``.

    For ``net_name == "isnet_is"`` an additional ISNet ground-truth encoder is
    built and frozen; its features (computed from the labels) are appended to
    the loss arguments for intermediate feature supervision.
    """

    def __init__(self, net_name, img_size=None, lr=1e-3):
        super().__init__()
        assert net_name in net_names, f"net_name must be one of {net_names}, got {net_name!r}"
        self.img_size = img_size
        self.lr = lr
        self.net = get_net(net_name, img_size)
        if net_name == "isnet_is":
            # Frozen: it only provides target features; its weights are loaded
            # separately (e.g. from a trained "isnet_gt" model in main()).
            self.gt_encoder = get_net("isnet_gt", img_size)
            self.gt_encoder.requires_grad_(False)
        else:
            self.gt_encoder = None

    @classmethod
    def try_load(cls, net_name, ckpt_path, map_location=None, img_size=None):
        """Load a model from a Lightning checkpoint or from a bare state dict.

        Accepted formats:
          * a Lightning checkpoint (detected by its ``"epoch"`` key), restored
            with ``load_from_checkpoint``;
          * a state dict of this whole module (keys prefixed with ``"net."``);
          * a state dict of ``self.net`` alone.
        """
        # NOTE(review): the file is fully unpickled here; only load trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location=map_location)
        if "epoch" in state_dict:
            return cls.load_from_checkpoint(ckpt_path, net_name=net_name, img_size=img_size,
                                            map_location=map_location)
        model = cls(net_name, img_size)
        if any(k.startswith("net.") for k in state_dict):
            model.load_state_dict(state_dict)
        else:
            model.net.load_state_dict(state_dict)
        return model

    def configure_optimizers(self):
        # Only the main network is optimised; the frozen gt_encoder is excluded.
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr, betas=(0.9, 0.999),
                               eps=1e-08, weight_decay=0)
        return optimizer

    def forward(self, x):
        """Run inference and return the predicted mask.

        ISNet and U2Net outputs are passed through a sigmoid here; MODNet and
        InSPyReNet outputs are returned as produced by the network.
        """
        if isinstance(self.net, (ISNetDIS, ISNetGTEncoder)):
            return self.net(x)[0][0].sigmoid()
        elif isinstance(self.net, U2NET):
            return self.net(x)[0].sigmoid()
        elif isinstance(self.net, MODNet):
            return self.net(x, True)[2]
        elif isinstance(self.net, InSPyReNet):
            return self.net.forward_inference(x)["pred"]
        raise NotImplementedError

    def training_step(self, batch, batch_idx):
        """Compute the network-specific loss for one batch and log it."""
        images, labels = batch["image"], batch["label"]
        if isinstance(self.net, ISNetDIS):
            ds, dfs = self.net(images)
            loss_args = [ds, dfs, labels]
            if self.gt_encoder is not None:
                # Target features for intermediate supervision ("isnet_is").
                fs = self.gt_encoder(labels)[1]
                loss_args.append(fs)
        elif isinstance(self.net, ISNetGTEncoder):
            # The GT encoder is trained on the labels themselves.
            ds = self.net(labels)[0]
            loss_args = [ds, labels]
        elif isinstance(self.net, U2NET):
            ds = self.net(images)
            loss_args = [ds, labels]
        elif isinstance(self.net, MODNet):
            trimaps = batch["trimap"]
            pred_semantic, pred_detail, pred_matte = self.net(images, False)
            loss_args = [pred_semantic, pred_detail, pred_matte, images, trimaps, labels]
        elif isinstance(self.net, InSPyReNet):
            out = self.net.forward_train(images, labels)
            loss_args = out
        else:
            raise NotImplementedError

        loss0, loss = self.net.compute_loss(loss_args)
        self.log_dict({"train/loss": loss, "train/loss_tar": loss0})
        return loss

    def validation_step(self, batch, batch_idx):
        """Log precision, recall, F-score and MAE of the predictions for one batch."""
        images, labels = batch["image"], batch["label"]
        if isinstance(self.net, ISNetGTEncoder):
            preds = self.forward(labels)
        else:
            preds = self.forward(images)
        pre, rec, f1 = f1_torch(preds.nan_to_num(nan=0, posinf=1, neginf=0), labels)
        # NOTE(review): MAE uses the raw predictions, so NaNs propagate into val/mae
        # (unlike the F-score metrics above, which use the sanitised tensor).
        mae_m = F.l1_loss(preds, labels, reduction="mean")
        pre_m = pre.mean()
        rec_m = rec.mean()
        f1_m = f1.mean()
        self.log_dict({"val/precision": pre_m, "val/recall": rec_m, "val/f1": f1_m, "val/mae": mae_m},
                      sync_dist=True)


def _build_trainer(opt, max_epochs, callbacks=None):
    """Create a Trainer from the shared command-line options in ``opt``."""
    return Trainer(precision=32 if opt.fp32 else 16,
                   accelerator=opt.accelerator,
                   devices=opt.devices,
                   max_epochs=max_epochs,
                   benchmark=opt.benchmark,
                   accumulate_grad_batches=opt.acc_step,
                   check_val_every_n_epoch=opt.val_epoch,
                   log_every_n_steps=opt.log_step,
                   strategy="ddp_find_unused_parameters_false" if opt.devices > 1 else None,
                   callbacks=callbacks)


def _make_loader(dataset, batch_size, shuffle, num_workers):
    """Create a DataLoader; persistent workers are only requested when workers exist."""
    # DataLoader rejects persistent_workers=True when num_workers == 0.
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      persistent_workers=num_workers > 0, num_workers=num_workers,
                      pin_memory=True)


def get_gt_encoder(train_dataloader, val_dataloader, opt):
    """Train an ISNet ground-truth encoder for ``opt.gt_epoch`` epochs and return its network.

    The encoder uses AnimeSegmentation's default learning rate; ``opt.lr`` is
    not applied here.
    """
    print("---start train ground truth encoder---")
    gt_encoder = AnimeSegmentation("isnet_gt")
    trainer = _build_trainer(opt, opt.gt_epoch)
    trainer.fit(gt_encoder, train_dataloader, val_dataloader)
    return gt_encoder.net


def main(opt):
    """Build the datasets and model described by ``opt`` and run training."""
    os.makedirs("lightning_logs", exist_ok=True)

    train_dataset, val_dataset = create_training_datasets(opt.data_dir, opt.fg_dir, opt.bg_dir,
                                                          opt.img_dir, opt.mask_dir,
                                                          opt.fg_ext, opt.bg_ext,
                                                          opt.img_ext, opt.mask_ext,
                                                          opt.data_split, opt.img_size,
                                                          with_trimap=opt.net == "modnet",
                                                          cache_ratio=opt.cache,
                                                          cache_update_epoch=opt.cache_epoch)
    train_dataloader = _make_loader(train_dataset, opt.batch_size_train, True, opt.workers_train)
    val_dataloader = _make_loader(val_dataset, opt.batch_size_val, False, opt.workers_val)

    print("---define model---")
    if opt.pretrained_ckpt == "":
        anime_seg = AnimeSegmentation(opt.net, opt.img_size)
    else:
        anime_seg = AnimeSegmentation.try_load(opt.net, opt.pretrained_ckpt, "cpu", opt.img_size)
    if not opt.pretrained_ckpt and not opt.resume_ckpt and opt.net == "isnet_is":
        # A freshly initialised isnet_is model needs a trained GT encoder first.
        gt_net = get_gt_encoder(train_dataloader, val_dataloader, opt)
        anime_seg.gt_encoder.load_state_dict(gt_net.state_dict())
    anime_seg.lr = opt.lr

    print("---start train---")
    checkpoint_callback = ModelCheckpoint(monitor='val/f1', mode="max", save_top_k=1, save_last=True,
                                          auto_insert_metric_name=False,
                                          filename="epoch={epoch},f1={val/f1:.4f}")
    trainer = _build_trainer(opt, opt.epoch, callbacks=[checkpoint_callback])
    trainer.fit(anime_seg, train_dataloader, val_dataloader, ckpt_path=opt.resume_ckpt or None)


def _build_parser():
    """Return the command-line parser for this training script."""
    parser = argparse.ArgumentParser()
    # model args
    parser.add_argument('--net', type=str, default='isnet_is',
                        choices=net_names,
                        help='isnet_is: Train ISNet with intermediate feature supervision, '
                             'isnet: Train ISNet, '
                             'isnet_gt: Train the ISNet ground truth encoder only, '
                             'u2net: Train U2Net full, '
                             'u2netl: Train U2Net lite, '
                             'modnet: Train MODNet, '
                             'inspyrnet_res: Train InSPyReNet_Res2Net50, '
                             'inspyrnet_swin: Train InSPyReNet_SwinB')
    parser.add_argument('--pretrained-ckpt', type=str, default='',
                        help='load from pretrained ckpt')
    parser.add_argument('--resume-ckpt', type=str, default='',
                        help='resume training from ckpt')
    parser.add_argument('--img-size', type=int, default=1024,
                        help='image size for training and validation, '
                             '1024 recommended for ISNet, '
                             '384 recommended for InSPyReNet, '
                             '640 recommended for others')

    # dataset args
    parser.add_argument('--data-dir', type=str, default='../../dataset/anime-seg',
                        help='root dir of dataset')
    parser.add_argument('--fg-dir', type=str, default='fg',
                        help='relative dir of foreground')
    parser.add_argument('--bg-dir', type=str, default='bg',
                        help='relative dir of background')
    parser.add_argument('--img-dir', type=str, default='imgs',
                        help='relative dir of images')
    parser.add_argument('--mask-dir', type=str, default='masks',
                        help='relative dir of masks')
    parser.add_argument('--fg-ext', type=str, default='.png',
                        help='extension name of foreground')
    parser.add_argument('--bg-ext', type=str, default='.jpg',
                        help='extension name of background')
    parser.add_argument('--img-ext', type=str, default='.jpg',
                        help='extension name of images')
    parser.add_argument('--mask-ext', type=str, default='.jpg',
                        help='extension name of masks')
    parser.add_argument('--data-split', type=float, default=0.95,
                        help='split rate for training and validation')

    # training args
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--epoch', type=int, default=40,
                        help='epoch num')
    parser.add_argument('--gt-epoch', type=int, default=4,
                        help='epoch for training ground truth encoder when net is isnet_is')
    parser.add_argument('--batch-size-train', type=int, default=2,
                        help='batch size for training')
    parser.add_argument('--batch-size-val', type=int, default=2,
                        help='batch size for val')
    parser.add_argument('--workers-train', type=int, default=4,
                        help='workers num for training dataloader')
    parser.add_argument('--workers-val', type=int, default=4,
                        help='workers num for validation dataloader')
    parser.add_argument('--acc-step', type=int, default=4,
                        help='gradient accumulation step')
    parser.add_argument('--accelerator', type=str, default="gpu",
                        choices=["cpu", "gpu", "tpu", "ipu", "hpu", "auto"],
                        help='accelerator')
    parser.add_argument('--devices', type=int, default=1,
                        help='devices num')
    parser.add_argument('--fp32', action='store_true', default=False,
                        help='disable mix precision')
    parser.add_argument('--benchmark', action='store_true', default=False,
                        help='enable cudnn benchmark')
    parser.add_argument('--log-step', type=int, default=2,
                        help='log training loss every n steps')
    parser.add_argument('--val-epoch', type=int, default=1,
                        help='valid and save every n epoch')
    parser.add_argument('--cache-epoch', type=int, default=3,
                        help='update cache every n epoch')
    parser.add_argument('--cache', type=float, default=0,
                        help='ratio (cache to entire training dataset), '
                             'higher values require more memory, set 0 to disable cache')
    return parser


if __name__ == "__main__":
    opt = _build_parser().parse_args()
    print(opt)
    main(opt)