not-lain committed
Commit 13e108d
1 Parent(s): ab983e2

Create train.py

Files changed (1)
train.py +283 -0
train.py ADDED
@@ -0,0 +1,283 @@
import argparse
import os

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from huggingface_hub import PyTorchModelHubMixin

from data_loader import create_training_datasets
from model import (ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet,
                   InSPyReNet, InSPyReNet_Res2Net50, InSPyReNet_SwinB)

# warnings.filterwarnings("ignore")

net_names = ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet", "inspyrnet_res", "inspyrnet_swin"]


def get_net(net_name, img_size):
    if net_name == "isnet":
        return ISNetDIS()
    elif net_name == "isnet_is":
        return ISNetDIS()
    elif net_name == "isnet_gt":
        return ISNetGTEncoder()
    elif net_name == "u2net":
        return U2NET_full2()
    elif net_name == "u2netl":
        return U2NET_lite2()
    elif net_name == "modnet":
        return MODNet()
    elif net_name == "inspyrnet_res":
        return InSPyReNet_Res2Net50(base_size=img_size)
    elif net_name == "inspyrnet_swin":
        return InSPyReNet_SwinB(base_size=img_size)
    raise NotImplementedError


def f1_torch(pred, gt):
    # micro-averaged F-measure with beta^2 = 0.3 (the weighting commonly used in
    # salient object detection), computed jointly over foreground and background
    pred = pred.float().view(pred.shape[0], -1)
    gt = gt.float().view(gt.shape[0], -1)
    # foreground counts (soft): true positives, predicted positives, actual positives
    tp1 = torch.sum(pred * gt, dim=1)
    tp_fp1 = torch.sum(pred, dim=1)
    tp_fn1 = torch.sum(gt, dim=1)
    # repeat on the inverted masks to count the background class
    pred = 1 - pred
    gt = 1 - gt
    tp2 = torch.sum(pred * gt, dim=1)
    tp_fp2 = torch.sum(pred, dim=1)
    tp_fn2 = torch.sum(gt, dim=1)
    precision = (tp1 + tp2) / (tp_fp1 + tp_fp2 + 0.0001)
    recall = (tp1 + tp2) / (tp_fn1 + tp_fn2 + 0.0001)
    f1 = (1 + 0.3) * precision * recall / (0.3 * precision + recall + 0.0001)
    return precision, recall, f1
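

# Illustrative check (hypothetical tensors): for perfectly matching masks,
#   p, r, f = f1_torch(torch.ones(2, 1, 4, 4), torch.ones(2, 1, 4, 4))
# each returned tensor has shape (2,), with values close to 1.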


class AnimeSegmentation(pl.LightningModule,
                        PyTorchModelHubMixin,
                        library_name="anime_segmentation",
                        repo_url="https://github.com/SkyTNT/anime-segmentation",
                        tags=["image-segmentation"]
                        ):

    def __init__(self, net_name, img_size=None, lr=1e-3):
        super().__init__()
        assert net_name in net_names
        self.img_size = img_size
        self.lr = lr
        self.net = get_net(net_name, img_size)
        if net_name == "isnet_is":
            # frozen ground-truth encoder used for intermediate feature supervision
            self.gt_encoder = get_net("isnet_gt", img_size)
            self.gt_encoder.requires_grad_(False)
        else:
            self.gt_encoder = None
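
    # try_load accepts three checkpoint flavors: a Lightning checkpoint (contains
    # an "epoch" key), a full-model state dict (keys prefixed with "net."), or a
    # bare network state dict.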
    @classmethod
    def try_load(cls, net_name, ckpt_path, map_location=None, img_size=None):
        state_dict = torch.load(ckpt_path, map_location=map_location)
        if "epoch" in state_dict:
            return cls.load_from_checkpoint(ckpt_path, net_name=net_name, img_size=img_size, map_location=map_location)
        else:
            model = cls(net_name, img_size)
            if any(k.startswith("net.") for k in state_dict):
                model.load_state_dict(state_dict)
            else:
                model.net.load_state_dict(state_dict)
            return model
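
    # Usage sketch (hypothetical path):
    #   model = AnimeSegmentation.try_load("isnet_is", "ckpt/last.ckpt", map_location="cpu", img_size=1024)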

    def configure_optimizers(self):
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
        return optimizer

    def forward(self, x):
        # each backbone returns a different structure; normalize everything to a
        # single soft segmentation mask in [0, 1]
        if isinstance(self.net, ISNetDIS):
            return self.net(x)[0][0].sigmoid()
        elif isinstance(self.net, ISNetGTEncoder):
            return self.net(x)[0][0].sigmoid()
        elif isinstance(self.net, U2NET):
            return self.net(x)[0].sigmoid()
        elif isinstance(self.net, MODNet):
            return self.net(x, True)[2]
        elif isinstance(self.net, InSPyReNet):
            return self.net.forward_inference(x)["pred"]
        raise NotImplementedError
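
    # Inference sketch (assumes a float RGB batch sized to match --img-size):
    #   with torch.no_grad():
    #       mask = model(torch.rand(1, 3, 1024, 1024))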

    def training_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        if isinstance(self.net, ISNetDIS):
            ds, dfs = self.net(images)
            loss_args = [ds, dfs, labels]
            if self.gt_encoder is not None:
                # intermediate feature targets from the frozen GT encoder ("isnet_is")
                fs = self.gt_encoder(labels)[1]
                loss_args.append(fs)
        elif isinstance(self.net, ISNetGTEncoder):
            ds = self.net(labels)[0]
            loss_args = [ds, labels]
        elif isinstance(self.net, U2NET):
            ds = self.net(images)
            loss_args = [ds, labels]
        elif isinstance(self.net, MODNet):
            trimaps = batch["trimap"]
            pred_semantic, pred_detail, pred_matte = self.net(images, False)
            loss_args = [pred_semantic, pred_detail, pred_matte, images, trimaps, labels]
        elif isinstance(self.net, InSPyReNet):
            out = self.net.forward_train(images, labels)
            loss_args = out
        else:
            raise NotImplementedError

        # each network defines its own compute_loss; loss0 is the target-mask loss,
        # loss is the total training loss including any auxiliary terms
        loss0, loss = self.net.compute_loss(loss_args)
        self.log_dict({"train/loss": loss, "train/loss_tar": loss0})
        return loss
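
    # The validation step sanitizes predictions with nan_to_num before scoring,
    # presumably a guard against non-finite outputs (e.g. under fp16 training).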

    def validation_step(self, batch, batch_idx):
        images, labels = batch["image"], batch["label"]
        if isinstance(self.net, ISNetGTEncoder):
            preds = self.forward(labels)
        else:
            preds = self.forward(images)
        pre, rec, f1 = f1_torch(preds.nan_to_num(nan=0, posinf=1, neginf=0), labels)
        mae_m = F.l1_loss(preds, labels, reduction="mean")
        pre_m = pre.mean()
        rec_m = rec.mean()
        f1_m = f1.mean()
        self.log_dict({"val/precision": pre_m, "val/recall": rec_m, "val/f1": f1_m, "val/mae": mae_m}, sync_dist=True)


def get_gt_encoder(train_dataloader, val_dataloader, opt):
    print("---start train ground truth encoder---")
    gt_encoder = AnimeSegmentation("isnet_gt")
    trainer = Trainer(precision=32 if opt.fp32 else 16, accelerator=opt.accelerator,
                      devices=opt.devices, max_epochs=opt.gt_epoch,
                      benchmark=opt.benchmark, accumulate_grad_batches=opt.acc_step,
                      check_val_every_n_epoch=opt.val_epoch, log_every_n_steps=opt.log_step,
                      strategy="ddp_find_unused_parameters_false" if opt.devices > 1 else None,
                      )
    trainer.fit(gt_encoder, train_dataloader, val_dataloader)
    return gt_encoder.net
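

# For "isnet_is", the GT encoder is trained first on labels only; main() then
# copies its weights into AnimeSegmentation.gt_encoder before the main run.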


def main(opt):
    if not os.path.exists("lightning_logs"):
        os.mkdir("lightning_logs")

    train_dataset, val_dataset = create_training_datasets(opt.data_dir, opt.fg_dir, opt.bg_dir, opt.img_dir,
                                                          opt.mask_dir, opt.fg_ext, opt.bg_ext, opt.img_ext,
                                                          opt.mask_ext, opt.data_split, opt.img_size,
                                                          with_trimap=opt.net == "modnet",
                                                          cache_ratio=opt.cache, cache_update_epoch=opt.cache_epoch)

    train_dataloader = DataLoader(train_dataset, batch_size=opt.batch_size_train, shuffle=True, persistent_workers=True,
                                  num_workers=opt.workers_train, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=opt.batch_size_val, shuffle=False, persistent_workers=True,
                                num_workers=opt.workers_val, pin_memory=True)
    print("---define model---")

    if opt.pretrained_ckpt == "":
        anime_seg = AnimeSegmentation(opt.net, opt.img_size)
    else:
        anime_seg = AnimeSegmentation.try_load(opt.net, opt.pretrained_ckpt, "cpu", opt.img_size)
    if not opt.pretrained_ckpt and not opt.resume_ckpt and opt.net == "isnet_is":
        anime_seg.gt_encoder.load_state_dict(get_gt_encoder(train_dataloader, val_dataloader, opt).state_dict())
    anime_seg.lr = opt.lr

    print("---start train---")
    checkpoint_callback = ModelCheckpoint(monitor='val/f1', mode="max", save_top_k=1, save_last=True,
                                          auto_insert_metric_name=False, filename="epoch={epoch},f1={val/f1:.4f}")
    trainer = Trainer(precision=32 if opt.fp32 else 16, accelerator=opt.accelerator,
                      devices=opt.devices, max_epochs=opt.epoch,
                      benchmark=opt.benchmark, accumulate_grad_batches=opt.acc_step,
                      check_val_every_n_epoch=opt.val_epoch, log_every_n_steps=opt.log_step,
                      strategy="ddp_find_unused_parameters_false" if opt.devices > 1 else None,
                      callbacks=[checkpoint_callback])
    trainer.fit(anime_seg, train_dataloader, val_dataloader, ckpt_path=opt.resume_ckpt or None)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # model args
    parser.add_argument('--net', type=str, default='isnet_is',
                        choices=net_names,
                        help='isnet_is: Train ISNet with intermediate feature supervision, '
                             'isnet: Train ISNet, '
                             'u2net: Train U2Net full, '
                             'u2netl: Train U2Net lite, '
                             'modnet: Train MODNet, '
                             'inspyrnet_res: Train InSPyReNet_Res2Net50, '
                             'inspyrnet_swin: Train InSPyReNet_SwinB')
    parser.add_argument('--pretrained-ckpt', type=str, default='',
                        help='load from pretrained ckpt')
    parser.add_argument('--resume-ckpt', type=str, default='',
                        help='resume training from ckpt')
    parser.add_argument('--img-size', type=int, default=1024,
                        help='image size for training and validation, '
                             '1024 recommended for ISNet, '
                             '384 recommended for InSPyReNet, '
                             '640 recommended for others')

    # dataset args
    parser.add_argument('--data-dir', type=str, default='../../dataset/anime-seg',
                        help='root dir of dataset')
    parser.add_argument('--fg-dir', type=str, default='fg',
                        help='relative dir of foreground')
    parser.add_argument('--bg-dir', type=str, default='bg',
                        help='relative dir of background')
    parser.add_argument('--img-dir', type=str, default='imgs',
                        help='relative dir of images')
    parser.add_argument('--mask-dir', type=str, default='masks',
                        help='relative dir of masks')
    parser.add_argument('--fg-ext', type=str, default='.png',
                        help='extension name of foreground')
    parser.add_argument('--bg-ext', type=str, default='.jpg',
                        help='extension name of background')
    parser.add_argument('--img-ext', type=str, default='.jpg',
                        help='extension name of images')
    parser.add_argument('--mask-ext', type=str, default='.jpg',
                        help='extension name of masks')
    parser.add_argument('--data-split', type=float, default=0.95,
                        help='train/validation split ratio')

    # training args
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--epoch', type=int, default=40,
                        help='number of training epochs')
    parser.add_argument('--gt-epoch', type=int, default=4,
                        help='epochs for training the ground truth encoder when net is isnet_is')
    parser.add_argument('--batch-size-train', type=int, default=2,
                        help='batch size for training')
    parser.add_argument('--batch-size-val', type=int, default=2,
                        help='batch size for validation')
    parser.add_argument('--workers-train', type=int, default=4,
                        help='number of workers for the training dataloader')
    parser.add_argument('--workers-val', type=int, default=4,
                        help='number of workers for the validation dataloader')
    parser.add_argument('--acc-step', type=int, default=4,
                        help='gradient accumulation steps')
    parser.add_argument('--accelerator', type=str, default="gpu",
                        choices=["cpu", "gpu", "tpu", "ipu", "hpu", "auto"],
                        help='accelerator')
    parser.add_argument('--devices', type=int, default=1,
                        help='number of devices')
    parser.add_argument('--fp32', action='store_true', default=False,
                        help='disable mixed precision')
    parser.add_argument('--benchmark', action='store_true', default=False,
                        help='enable cudnn benchmark')
    parser.add_argument('--log-step', type=int, default=2,
                        help='log training loss every n steps')
    parser.add_argument('--val-epoch', type=int, default=1,
                        help='validate and save every n epochs')
    parser.add_argument('--cache-epoch', type=int, default=3,
                        help='update cache every n epochs')
    parser.add_argument('--cache', type=float, default=0,
                        help='ratio of the training set to cache in memory; '
                             'higher values require more memory, set 0 to disable cache')

    opt = parser.parse_args()
    print(opt)

    main(opt)
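
# Example invocation (illustrative; adjust paths to your dataset layout):
#   python train.py --net isnet_is --data-dir ../../dataset/anime-seg \
#       --img-size 1024 --batch-size-train 2 --acc-step 4 --devices 1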