skytnt committed
Commit: 5a76c1a
Parent(s): 4659f8d

Upload anime_aesthetic.py

Files changed (1)
  1. anime_aesthetic.py +495 -0
anime_aesthetic.py ADDED
@@ -0,0 +1,495 @@
import argparse
import json
import os
import random
from copy import deepcopy

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from timm.models.layers import DropPath, trunc_normal_
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import functional

# ========= Model =========

# copied from https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py

class LayerNorm(nn.Module):
    """LayerNorm that supports two data formats: channels_last (default) or channels_first.
    channels_last corresponds to inputs of shape (batch_size, height, width, channels),
    while channels_first corresponds to inputs of shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

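# The two data formats above are numerically equivalent up to a permute. A minimal
# self-check sketch (the helper name is illustrative, not part of the uploaded file):
def _check_layernorm_formats(dim=8):
    ln_cf = LayerNorm(dim, data_format="channels_first")
    ln_cl = LayerNorm(dim, data_format="channels_last")
    x = torch.randn(2, dim, 4, 4)
    y_cf = ln_cf(x)  # normalize over the channel dim of an NCHW tensor
    y_cl = ln_cl(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # NHWC path, permuted back
    return torch.allclose(y_cf, y_cl, atol=1e-5)
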
class GRN(nn.Module):
    """GRN (Global Response Normalization) layer"""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # Gx: per-channel L2 norm aggregated over the spatial dims (input is NHWC here)
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Nx: divisive normalization of Gx across channels
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x

class Block(nn.Module):
    """ConvNeXtV2 Block.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.0):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, 4 * dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x

class ConvNeXtV2(nn.Module):
    """ConvNeXt V2

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.0,
        head_init_scale=1.0,
    ):
        super().__init__()
        self.depths = depths
        self.downsample_layers = (
            nn.ModuleList()
        )  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = (
            nn.ModuleList()
        )  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[
                    Block(dim=dims[i], drop_path=dp_rates[cur + j])
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(
            x.mean([-2, -1])
        )  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

model_cfgs = {
    "atto": [[2, 2, 6, 2], [40, 80, 160, 320]],
    "femto": [[2, 2, 6, 2], [48, 96, 192, 384]],
    "pico": [[2, 2, 6, 2], [64, 128, 256, 512]],
    "nano": [[2, 2, 8, 2], [80, 160, 320, 640]],
    "tiny": [[3, 3, 9, 3], [96, 192, 384, 768]],
    "base": [[3, 3, 27, 3], [128, 256, 512, 1024]],
    "large": [[3, 3, 27, 3], [192, 384, 768, 1536]],
    "huge": [[3, 3, 27, 3], [352, 704, 1408, 2816]],
}


def convnextv2(cfg_name, **kwargs):
    cfg = model_cfgs[cfg_name]
    model = ConvNeXtV2(depths=cfg[0], dims=cfg[1], **kwargs)
    return model

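# A minimal sketch of how the factory above is used in this script (the aesthetic
# head is a single regression output, so num_classes=1); the helper name below is
# illustrative and not part of the uploaded file:
def _example_backbone():
    net = convnextv2("tiny", in_chans=3, num_classes=1)
    x = torch.randn(1, 3, 768, 768)  # one RGB image at the default training size
    with torch.no_grad():
        score = net(x)  # shape (1, 1): a single unbounded regression score
    return score
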
# ========= Dataset =========

EXTENSION = [".png", ".jpg", ".jpeg"]


def file_ext(fname):
    return os.path.splitext(fname)[1].lower()

def rescale_pad(image, output_size, random_pad=False):
    h, w = image.shape[-2:]
    if h != output_size or w != output_size:
        r = min(output_size / h, output_size / w)
        new_h, new_w = int(h * r), int(w * r)
        ph = output_size - new_h
        pw = output_size - new_w
        image = transforms.functional.resize(image, [new_h, new_w])
        # pad to a square output_size x output_size canvas; the fill value is
        # random when random_pad is set, otherwise black
        image = transforms.functional.pad(
            image,
            [pw // 2, ph // 2, pw // 2 + pw % 2, ph // 2 + ph % 2],
            random.uniform(0, 1) if random_pad else 0,
        )
    return image


def random_crop(image, min_rate=0.8):
    h, w = image.shape[-2:]
    new_h = int(h * random.uniform(min_rate, 1))
    new_w = int(w * random.uniform(min_rate, 1))
    # np.random.randint excludes the upper bound, so +1 keeps the full offset
    # range reachable and avoids an error when the crop equals the image size
    top = np.random.randint(0, h - new_h + 1)
    left = np.random.randint(0, w - new_w + 1)
    image = image[:, top: top + new_h, left: left + new_w]
    return image

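# A small worked example of the two augmentations above (illustrative only):
# for an 800x600 input, rescale_pad alone would resize to 768x576 and then pad
# 96 px on the left and right, giving a 768x768 letterboxed image.
def _example_augment():
    img = torch.rand(3, 800, 600)      # fake RGB image in [0, 1]
    img = random_crop(img, 0.8)        # random 80-100% crop of each side
    img = rescale_pad(img, 768, True)  # letterbox to 768x768 with random fill
    return img.shape                   # torch.Size([3, 768, 768])
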
class AnimeAestheticDataset(Dataset):
    def __init__(self, path, img_size, xflip=True):
        all_files = {
            os.path.relpath(os.path.join(root, fname), path)
            for root, _dirs, files in os.walk(path)
            for fname in files
        }
        all_images = sorted(
            fname for fname in all_files if file_ext(fname) in EXTENSION
        )
        with open(os.path.join(path, "label.json"), "r", encoding="utf8") as f:
            labels = json.load(f)
        image_list = []
        label_list = []
        for fname in all_images:
            if fname not in labels:
                continue
            image_list.append(fname)
            label_list.append(labels[fname])
        self.path = path
        self.img_size = img_size
        self.xflip = xflip
        self.image_list = image_list
        self.label_list = label_list

    def __len__(self):
        length = len(self.image_list)
        if self.xflip:
            length *= 2
        return length

    def __getitem__(self, index):
        real_len = len(self.image_list)
        fname = self.image_list[index % real_len]
        label = self.label_list[index % real_len]
        image = Image.open(os.path.join(self.path, fname)).convert("RGB")
        image = transforms.functional.to_tensor(image)
        image = random_crop(image, 0.8)
        image = rescale_pad(image, self.img_size, True)
        if index // real_len != 0:
            image = transforms.functional.hflip(image)
        label = torch.tensor([label], dtype=torch.float32)
        return image, label

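# The dataset class above expects an image folder plus a label.json that maps
# image paths (relative to the dataset root) to float aesthetic scores. A
# hypothetical layout consistent with the code (file names and score values are
# assumptions, not taken from the upload):
#
#   ./data/
#       label.json
#       0001.png
#       subdir/0002.jpg
#
#   label.json: {"0001.png": 0.92, "subdir/0002.jpg": 0.31}
#
# Note that os.path.relpath uses the platform's path separator, so on Windows
# the keys in label.json would need backslashes to match.
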
# ========= Train =========


class AnimeAesthetic(pl.LightningModule):
    def __init__(self, cfg: str, drop_path_rate=0.0, ema_decay=0):
        super().__init__()
        self.net = convnextv2(cfg, in_chans=3, num_classes=1, drop_path_rate=drop_path_rate)
        self.ema_decay = ema_decay
        self.ema = None
        if ema_decay > 0:
            self.ema = deepcopy(self.net)
            self.ema.requires_grad_(False)

    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.net.parameters(),
            lr=0.001,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=0,
        )
        return optimizer

    def forward(self, x, use_ema=False):
        x = (x - 0.5) / 0.5  # map [0, 1] inputs to [-1, 1]
        net = self.ema if use_ema else self.net
        return net(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        loss = F.mse_loss(self.forward(images, False), labels)
        self.log_dict({"train/loss": loss})
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        mae = F.l1_loss(self.forward(images, False), labels)
        logs = {"val/mae": mae}
        if self.ema is not None:
            mae_ema = F.l1_loss(self.forward(images, True), labels)
            logs["val/mae_ema"] = mae_ema
        self.log_dict(logs, sync_dist=True)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # exponential moving average of the online weights (and buffers)
        if self.ema is not None:
            with torch.no_grad():
                for ema_v, model_v in zip(
                    self.ema.state_dict().values(), self.net.state_dict().values()
                ):
                    ema_v.copy_(
                        self.ema_decay * ema_v + (1.0 - self.ema_decay) * model_v
                    )

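# A minimal inference sketch under assumed names (the checkpoint and image paths
# are placeholders; only calls defined in this file are used):
def _example_score_image(ckpt_path="last.ckpt", img_path="sample.png", img_size=768):
    model = AnimeAesthetic.load_from_checkpoint(ckpt_path, cfg="tiny", ema_decay=0)
    model.eval()
    img = Image.open(img_path).convert("RGB")
    img = transforms.functional.to_tensor(img)     # float tensor in [0, 1]
    img = rescale_pad(img, img_size).unsqueeze(0)  # letterbox and add batch dim
    with torch.no_grad():
        score = model(img)                         # (1, 1) regression output
    return score.item()
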
def main(opt):
    if not os.path.exists("lightning_logs"):
        os.mkdir("lightning_logs")
    torch.manual_seed(0)
    np.random.seed(0)
    print("---load dataset---")
    full_dataset = AnimeAestheticDataset(opt.data, opt.img_size)
    full_dataset_len = len(full_dataset)
    train_dataset_len = int(full_dataset_len * opt.data_split)
    val_dataset_len = full_dataset_len - train_dataset_len
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_dataset_len, val_dataset_len]
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=opt.batch_size_train,
        shuffle=True,
        persistent_workers=True,
        num_workers=opt.workers_train,
        pin_memory=True,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=opt.batch_size_val,
        shuffle=False,
        persistent_workers=True,
        num_workers=opt.workers_val,
        pin_memory=True,
    )
    print(f"train: {len(train_dataset)}")
    print(f"val: {len(val_dataset)}")
    print("---define model---")
    if opt.resume != "":
        anime_aesthetic = AnimeAesthetic.load_from_checkpoint(
            opt.resume, cfg=opt.cfg, drop_path_rate=opt.drop_path, ema_decay=opt.ema_decay
        )
    else:
        anime_aesthetic = AnimeAesthetic(cfg=opt.cfg, drop_path_rate=opt.drop_path, ema_decay=opt.ema_decay)

    print("---start train---")

    checkpoint_callback = ModelCheckpoint(
        monitor="val/mae",
        mode="min",
        save_top_k=1,
        save_last=True,
        auto_insert_metric_name=False,
        filename="epoch={epoch},mae={val/mae:.4f}",
    )
    callbacks = [checkpoint_callback]
    if opt.ema_decay > 0:
        checkpoint_ema_callback = ModelCheckpoint(
            monitor="val/mae_ema",
            mode="min",
            save_top_k=1,
            save_last=False,
            auto_insert_metric_name=False,
            filename="epoch={epoch},mae-ema={val/mae_ema:.4f}",
        )
        callbacks.append(checkpoint_ema_callback)
    trainer = Trainer(
        precision=32 if opt.fp32 else 16,
        accelerator=opt.accelerator,
        devices=opt.devices,
        max_epochs=opt.epoch,
        benchmark=opt.benchmark,
        accumulate_grad_batches=opt.acc_step,
        val_check_interval=opt.val_epoch,
        log_every_n_steps=opt.log_step,
        strategy="ddp_find_unused_parameters_false" if opt.devices > 1 else None,
        callbacks=callbacks,
    )
    trainer.fit(anime_aesthetic, train_dataloader, val_dataloader)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # model args
    parser.add_argument(
        "--cfg",
        type=str,
        default="tiny",
        choices=list(model_cfgs.keys()),
        help="model configuration",
    )
    parser.add_argument(
        "--resume", type=str, default="", help="resume training from ckpt"
    )
    parser.add_argument(
        "--img-size",
        type=int,
        default=768,
        help="image size for training and validation",
    )

    # dataset args
    parser.add_argument(
        "--data", type=str, default="./data", help="dataset path"
    )
    parser.add_argument(
        "--data-split",
        type=float,
        default=0.9995,
        help="split rate for training and validation",
    )

    # training args
    parser.add_argument("--epoch", type=int, default=100, help="number of training epochs")
    parser.add_argument(
        "--batch-size-train", type=int, default=16, help="batch size for training"
    )
    parser.add_argument(
        "--batch-size-val", type=int, default=2, help="batch size for validation"
    )
    parser.add_argument(
        "--workers-train",
        type=int,
        default=4,
        help="number of workers for the training dataloader",
    )
    parser.add_argument(
        "--workers-val",
        type=int,
        default=4,
        help="number of workers for the validation dataloader",
    )
    parser.add_argument(
        "--acc-step", type=int, default=8, help="gradient accumulation steps"
    )
    parser.add_argument(
        "--drop-path", type=float, default=0.1, help="drop path rate"
    )
    parser.add_argument(
        "--ema-decay", type=float, default=0.9999, help="use EMA if ema-decay > 0"
    )
    parser.add_argument(
        "--accelerator",
        type=str,
        default="gpu",
        choices=["cpu", "gpu", "tpu", "ipu", "hpu", "auto"],
        help="accelerator",
    )
    parser.add_argument("--devices", type=int, default=4, help="number of devices")
    parser.add_argument(
        "--fp32", action="store_true", default=False, help="disable mixed precision"
    )
    parser.add_argument(
        "--benchmark", action="store_true", default=True, help="enable cudnn benchmark"
    )
    parser.add_argument(
        "--log-step", type=int, default=2, help="log training loss every n steps"
    )
    parser.add_argument(
        "--val-epoch",
        type=float,
        default=0.1,
        help="run validation (and save checkpoints) every this fraction of an epoch",
    )

    opt = parser.parse_args()
    print(opt)

    main(opt)
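# Example launch with assumed paths (all flags correspond to the arguments above):
#   python anime_aesthetic.py --cfg tiny --data ./data --img-size 768 \
#       --batch-size-train 16 --acc-step 8 --devices 4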