Create train.py
train.py
ADDED
import argparse
import os

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from huggingface_hub import PyTorchModelHubMixin

from data_loader import create_training_datasets
from model import (ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet,
                   InSPyReNet, InSPyReNet_Res2Net50, InSPyReNet_SwinB)

# warnings.filterwarnings("ignore")

net_names = ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet", "inspyrnet_res", "inspyrnet_swin"]


def get_net(net_name, img_size):
    # factory for all supported segmentation architectures
    if net_name == "isnet":
        return ISNetDIS()
    elif net_name == "isnet_is":
        return ISNetDIS()
    elif net_name == "isnet_gt":
        return ISNetGTEncoder()
    elif net_name == "u2net":
        return U2NET_full2()
    elif net_name == "u2netl":
        return U2NET_lite2()
    elif net_name == "modnet":
        return MODNet()
    elif net_name == "inspyrnet_res":
        return InSPyReNet_Res2Net50(base_size=img_size)
    elif net_name == "inspyrnet_swin":
        return InSPyReNet_SwinB(base_size=img_size)
    raise NotImplementedError(net_name)

def f1_torch(pred, gt):
    # F-beta measure with beta^2 = 0.3, pooled over foreground and background:
    # F = (1 + b^2) * P * R / (b^2 * P + R), smoothed by 1e-4 to avoid division by zero
    pred = pred.float().view(pred.shape[0], -1)
    gt = gt.float().view(gt.shape[0], -1)
    tp1 = torch.sum(pred * gt, dim=1)
    tp_fp1 = torch.sum(pred, dim=1)
    tp_fn1 = torch.sum(gt, dim=1)
    pred = 1 - pred
    gt = 1 - gt
    tp2 = torch.sum(pred * gt, dim=1)
    tp_fp2 = torch.sum(pred, dim=1)
    tp_fn2 = torch.sum(gt, dim=1)
    precision = (tp1 + tp2) / (tp_fp1 + tp_fp2 + 0.0001)
    recall = (tp1 + tp2) / (tp_fn1 + tp_fn2 + 0.0001)
    f1 = (1 + 0.3) * precision * recall / (0.3 * precision + recall + 0.0001)
    return precision, recall, f1

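# Illustrative sanity check (not part of the original file): with a perfect
# prediction the smoothing constants keep the scores marginally below 1.
#
#   pred = (torch.rand(2, 1, 64, 64) > 0.5).float()
#   p, r, f = f1_torch(pred, pred.clone())
#   # p and r come out ~1.0; f ~ 1.3 / 1.3001 ~ 0.9999
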
class AnimeSegmentation(pl.LightningModule,
                        PyTorchModelHubMixin,
                        library_name="anime_segmentation",
                        repo_url="https://github.com/SkyTNT/anime-segmentation",
                        tags=["image-segmentation"]):

    def __init__(self, net_name, img_size=None, lr=1e-3):
        super().__init__()
        assert net_name in net_names
        self.img_size = img_size
        self.lr = lr
        self.net = get_net(net_name, img_size)
        if net_name == "isnet_is":
            # frozen ground truth encoder providing intermediate feature supervision
            self.gt_encoder = get_net("isnet_gt", img_size)
            self.gt_encoder.requires_grad_(False)
        else:
            self.gt_encoder = None

    @classmethod
    def try_load(cls, net_name, ckpt_path, map_location=None, img_size=None):
        state_dict = torch.load(ckpt_path, map_location=map_location)
        if "epoch" in state_dict:
            # full Lightning checkpoint
            return cls.load_from_checkpoint(ckpt_path, net_name=net_name, img_size=img_size,
                                            map_location=map_location)
        else:
            # bare state dict, either of this module or of the inner net
            model = cls(net_name, img_size)
            if any(k.startswith("net.") for k in state_dict):
                model.load_state_dict(state_dict)
            else:
                model.net.load_state_dict(state_dict)
            return model

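    # Usage sketch (hypothetical path): try_load covers both checkpoint formats,
    # a Lightning .ckpt with training state or an exported bare state dict:
    #
    #   model = AnimeSegmentation.try_load("isnet_is", "lightning_logs/last.ckpt",
    #                                      map_location="cpu", img_size=1024)
    #   model.eval()
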
    def configure_optimizers(self):
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
        return optimizer

    def forward(self, x):
        # normalize each architecture's raw outputs to a single [0, 1] mask
        if isinstance(self.net, ISNetDIS):
            return self.net(x)[0][0].sigmoid()
        elif isinstance(self.net, ISNetGTEncoder):
            return self.net(x)[0][0].sigmoid()
        elif isinstance(self.net, U2NET):
            return self.net(x)[0].sigmoid()
        elif isinstance(self.net, MODNet):
            return self.net(x, True)[2]
        elif isinstance(self.net, InSPyReNet):
            return self.net.forward_inference(x)["pred"]
        raise NotImplementedError

    def training_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        if isinstance(self.net, ISNetDIS):
            ds, dfs = self.net(images)
            loss_args = [ds, dfs, labels]
            if self.gt_encoder is not None:
                # intermediate features of the frozen GT encoder serve as extra targets
                fs = self.gt_encoder(labels)[1]
                loss_args.append(fs)
        elif isinstance(self.net, ISNetGTEncoder):
            ds = self.net(labels)[0]
            loss_args = [ds, labels]
        elif isinstance(self.net, U2NET):
            ds = self.net(images)
            loss_args = [ds, labels]
        elif isinstance(self.net, MODNet):
            trimaps = batch["trimap"]
            pred_semantic, pred_detail, pred_matte = self.net(images, False)
            loss_args = [pred_semantic, pred_detail, pred_matte, images, trimaps, labels]
        elif isinstance(self.net, InSPyReNet):
            out = self.net.forward_train(images, labels)
            loss_args = out
        else:
            raise NotImplementedError

        loss0, loss = self.net.compute_loss(loss_args)
        self.log_dict({"train/loss": loss, "train/loss_tar": loss0})
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        if isinstance(self.net, ISNetGTEncoder):
            # the GT encoder autoencodes labels, so validate it on labels as well
            preds = self.forward(labels)
        else:
            preds = self.forward(images)
        pre, rec, f1 = f1_torch(preds.nan_to_num(nan=0, posinf=1, neginf=0), labels)
        mae_m = F.l1_loss(preds, labels, reduction="mean")
        pre_m = pre.mean()
        rec_m = rec.mean()
        f1_m = f1.mean()
        self.log_dict({"val/precision": pre_m, "val/recall": rec_m, "val/f1": f1_m, "val/mae": mae_m}, sync_dist=True)

def get_gt_encoder(train_dataloader, val_dataloader, opt):
    print("---start train ground truth encoder---")
    gt_encoder = AnimeSegmentation("isnet_gt")
    trainer = Trainer(precision=32 if opt.fp32 else 16, accelerator=opt.accelerator,
                      devices=opt.devices, max_epochs=opt.gt_epoch,
                      benchmark=opt.benchmark, accumulate_grad_batches=opt.acc_step,
                      check_val_every_n_epoch=opt.val_epoch, log_every_n_steps=opt.log_step,
                      strategy="ddp_find_unused_parameters_false" if opt.devices > 1 else None)
    trainer.fit(gt_encoder, train_dataloader, val_dataloader)
    return gt_encoder.net

def main(opt):
    if not os.path.exists("lightning_logs"):
        os.mkdir("lightning_logs")

    train_dataset, val_dataset = create_training_datasets(opt.data_dir, opt.fg_dir, opt.bg_dir, opt.img_dir,
                                                          opt.mask_dir, opt.fg_ext, opt.bg_ext, opt.img_ext,
                                                          opt.mask_ext, opt.data_split, opt.img_size,
                                                          with_trimap=opt.net == "modnet",
                                                          cache_ratio=opt.cache, cache_update_epoch=opt.cache_epoch)

    train_dataloader = DataLoader(train_dataset, batch_size=opt.batch_size_train, shuffle=True,
                                  persistent_workers=True, num_workers=opt.workers_train, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=opt.batch_size_val, shuffle=False,
                                persistent_workers=True, num_workers=opt.workers_val, pin_memory=True)
    print("---define model---")

    if opt.pretrained_ckpt == "":
        anime_seg = AnimeSegmentation(opt.net, opt.img_size)
    else:
        anime_seg = AnimeSegmentation.try_load(opt.net, opt.pretrained_ckpt, "cpu", opt.img_size)
    if not opt.pretrained_ckpt and not opt.resume_ckpt and opt.net == "isnet_is":
        # pretrain the ground truth encoder first, then copy its weights into the frozen copy
        anime_seg.gt_encoder.load_state_dict(get_gt_encoder(train_dataloader, val_dataloader, opt).state_dict())
    anime_seg.lr = opt.lr

    print("---start train---")
    checkpoint_callback = ModelCheckpoint(monitor='val/f1', mode="max", save_top_k=1, save_last=True,
                                          auto_insert_metric_name=False, filename="epoch={epoch},f1={val/f1:.4f}")
    trainer = Trainer(precision=32 if opt.fp32 else 16, accelerator=opt.accelerator,
                      devices=opt.devices, max_epochs=opt.epoch,
                      benchmark=opt.benchmark, accumulate_grad_batches=opt.acc_step,
                      check_val_every_n_epoch=opt.val_epoch, log_every_n_steps=opt.log_step,
                      strategy="ddp_find_unused_parameters_false" if opt.devices > 1 else None,
                      callbacks=[checkpoint_callback])
    trainer.fit(anime_seg, train_dataloader, val_dataloader, ckpt_path=opt.resume_ckpt or None)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # model args
    parser.add_argument('--net', type=str, default='isnet_is',
                        choices=net_names,
                        help='isnet_is: Train ISNet with intermediate feature supervision, '
                             'isnet: Train ISNet, '
                             'u2net: Train U2Net full, '
                             'u2netl: Train U2Net lite, '
                             'modnet: Train MODNet, '
                             'inspyrnet_res: Train InSPyReNet_Res2Net50, '
                             'inspyrnet_swin: Train InSPyReNet_SwinB')
    parser.add_argument('--pretrained-ckpt', type=str, default='',
                        help='load from pretrained ckpt')
    parser.add_argument('--resume-ckpt', type=str, default='',
                        help='resume training from ckpt')
    parser.add_argument('--img-size', type=int, default=1024,
                        help='image size for training and validation, '
                             '1024 recommended for ISNet, '
                             '384 recommended for InSPyReNet, '
                             '640 recommended for others')

    # dataset args
    parser.add_argument('--data-dir', type=str, default='../../dataset/anime-seg',
                        help='root dir of the dataset')
    parser.add_argument('--fg-dir', type=str, default='fg',
                        help='relative dir of foregrounds')
    parser.add_argument('--bg-dir', type=str, default='bg',
                        help='relative dir of backgrounds')
    parser.add_argument('--img-dir', type=str, default='imgs',
                        help='relative dir of images')
    parser.add_argument('--mask-dir', type=str, default='masks',
                        help='relative dir of masks')
    parser.add_argument('--fg-ext', type=str, default='.png',
                        help='extension of foreground files')
    parser.add_argument('--bg-ext', type=str, default='.jpg',
                        help='extension of background files')
    parser.add_argument('--img-ext', type=str, default='.jpg',
                        help='extension of image files')
    parser.add_argument('--mask-ext', type=str, default='.jpg',
                        help='extension of mask files')
    parser.add_argument('--data-split', type=float, default=0.95,
                        help='train/validation split ratio')

    # training args
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--epoch', type=int, default=40,
                        help='number of training epochs')
    parser.add_argument('--gt-epoch', type=int, default=4,
                        help='epochs for training the ground truth encoder when net is isnet_is')
    parser.add_argument('--batch-size-train', type=int, default=2,
                        help='batch size for training')
    parser.add_argument('--batch-size-val', type=int, default=2,
                        help='batch size for validation')
    parser.add_argument('--workers-train', type=int, default=4,
                        help='number of workers for the training dataloader')
    parser.add_argument('--workers-val', type=int, default=4,
                        help='number of workers for the validation dataloader')
    parser.add_argument('--acc-step', type=int, default=4,
                        help='gradient accumulation steps')
    parser.add_argument('--accelerator', type=str, default="gpu",
                        choices=["cpu", "gpu", "tpu", "ipu", "hpu", "auto"],
                        help='accelerator')
    parser.add_argument('--devices', type=int, default=1,
                        help='number of devices')
    parser.add_argument('--fp32', action='store_true', default=False,
                        help='disable mixed precision')
    parser.add_argument('--benchmark', action='store_true', default=False,
                        help='enable cudnn benchmark')
    parser.add_argument('--log-step', type=int, default=2,
                        help='log training loss every n steps')
    parser.add_argument('--val-epoch', type=int, default=1,
                        help='validate and save every n epochs')
    parser.add_argument('--cache-epoch', type=int, default=3,
                        help='update the cache every n epochs')
    parser.add_argument('--cache', type=float, default=0,
                        help='ratio of the training dataset to cache; '
                             'higher values require more memory, set 0 to disable cache')

    opt = parser.parse_args()
    print(opt)

    main(opt)
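
For reference, a typical launch with the defaults above might look like this (the dataset path is just the --data-dir default, adjust to your layout):

    python train.py --net isnet_is --img-size 1024 \
        --data-dir ../../dataset/anime-seg \
        --batch-size-train 2 --acc-step 4 --accelerator gpu --devices 1

Passing --resume-ckpt continues a run from a Lightning checkpoint (optimizer state included), while --pretrained-ckpt only initializes weights via try_load; with --net isnet_is and neither flag set, the script first trains the ground truth encoder for --gt-epoch epochs before fitting the main model.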