import argparse
import json
import os
import random
from copy import deepcopy

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from timm.models.layers import DropPath, trunc_normal_
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import functional


class LayerNorm(nn.Module):
    """LayerNorm that supports two data formats: channels_last (default) or channels_first,
    referring to the ordering of the dimensions in the inputs: channels_last corresponds to
    inputs with shape (batch_size, height, width, channels), while channels_first corresponds
    to inputs with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

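
# Minimal sanity check (illustrative only, not used by the training pipeline): the
# channels_first branch normalizes an NCHW feature map over its channel axis and
# preserves the input shape.
# >>> ln = LayerNorm(64, data_format="channels_first")
# >>> ln(torch.randn(2, 64, 56, 56)).shape
# torch.Size([2, 64, 56, 56])

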
class GRN(nn.Module):
    """GRN (Global Response Normalization) layer."""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x

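
# Illustrative note (not part of the model definition): GRN expects channels_last
# (N, H, W, C) input. It computes a per-channel L2 norm over the spatial dimensions,
# divides each channel's norm by the mean norm across channels, and rescales the
# features by that ratio, with a residual connection and learnable gamma/beta.
# Shapes are preserved:
# >>> GRN(128)(torch.randn(2, 7, 7, 128)).shape
# torch.Size([2, 7, 7, 128])

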
class Block(nn.Module):
    """ConvNeXtV2 Block.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.0):
        super().__init__()
        # depthwise conv
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x

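
# Shape check (illustrative): a Block maps NCHW -> NCHW without changing the shape,
# since the depthwise conv is padded and the pointwise convs act only on channels.
# >>> Block(dim=96)(torch.randn(1, 96, 32, 32)).shape
# torch.Size([1, 96, 32, 32])

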
class ConvNeXtV2(nn.Module):
    """ConvNeXt V2

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.0,
        head_init_scale=1.0,
    ):
        super().__init__()
        self.depths = depths
        # stem and 3 intermediate downsampling conv layers
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        # 4 feature resolution stages, each consisting of multiple residual blocks
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[
                    Block(dim=dims[i], drop_path=dp_rates[cur + j])
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        # global average pooling, (N, C, H, W) -> (N, C)
        return self.norm(x.mean([-2, -1]))

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


model_cfgs = {
    "atto": [[2, 2, 6, 2], [40, 80, 160, 320]],
    "femto": [[2, 2, 6, 2], [48, 96, 192, 384]],
    "pico": [[2, 2, 6, 2], [64, 128, 256, 512]],
    "nano": [[2, 2, 8, 2], [80, 160, 320, 640]],
    "tiny": [[3, 3, 9, 3], [96, 192, 384, 768]],
    "base": [[3, 3, 27, 3], [128, 256, 512, 1024]],
    "large": [[3, 3, 27, 3], [192, 384, 768, 1536]],
    "huge": [[3, 3, 27, 3], [352, 704, 1408, 2816]],
}


def convnextv2(cfg_name, **kwargs):
    cfg = model_cfgs[cfg_name]
    model = ConvNeXtV2(depths=cfg[0], dims=cfg[1], **kwargs)
    return model

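
# Example use of the factory (a sketch; the training script below builds the model
# through the AnimeAesthetic LightningModule instead):
# >>> net = convnextv2("tiny", in_chans=3, num_classes=1)
# >>> net(torch.randn(1, 3, 224, 224)).shape
# torch.Size([1, 1])

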
EXTENSION = [".png", ".jpg", ".jpeg"]


def file_ext(fname):
    return os.path.splitext(fname)[1].lower()


def rescale_pad(image, output_size, random_pad=False):
    h, w = image.shape[-2:]
    if h != output_size or w != output_size:
        # resize so the longer side equals output_size, keeping the aspect ratio
        r = min(output_size / h, output_size / w)
        new_h, new_w = int(h * r), int(w * r)
        if random_pad:
            r2 = random.uniform(0.9, 1)
            new_h, new_w = int(new_h * r2), int(new_w * r2)
        ph = output_size - new_h
        pw = output_size - new_w
        left = random.randint(0, pw) if random_pad else pw // 2
        right = pw - left
        top = random.randint(0, ph) if random_pad else ph // 2
        bottom = ph - top
        image = transforms.functional.resize(image, [new_h, new_w])
        # torchvision padding order is [left, top, right, bottom]
        image = transforms.functional.pad(
            image, [left, top, right, bottom], random.uniform(0, 1) if random_pad else 0
        )
    return image

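
# Example (illustrative): rescale_pad keeps the aspect ratio and pads to a square.
# A 3x300x600 tensor comes back as 3x224x224, with the image resized to 112x224 and
# the remaining rows filled with the pad value.
# >>> rescale_pad(torch.zeros(3, 300, 600), 224).shape
# torch.Size([3, 224, 224])

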
def random_crop(image, min_rate=0.8):
    h, w = image.shape[-2:]
    new_h, new_w = int(h * random.uniform(min_rate, 1)), int(w * random.uniform(min_rate, 1))
    # upper bound includes h - new_h so the crop can reach the image border
    # (and so the call cannot fail in the rare case where new_h == h)
    top = np.random.randint(0, h - new_h + 1)
    left = np.random.randint(0, w - new_w + 1)
    image = image[:, top: top + new_h, left: left + new_w]
    return image


class AnimeAestheticDataset(Dataset):
    def __init__(self, path, img_size, xflip=True):
        all_files = {
            os.path.relpath(os.path.join(root, fname), path)
            for root, _dirs, files in os.walk(path)
            for fname in files
        }
        all_images = sorted(
            fname for fname in all_files if file_ext(fname) in EXTENSION
        )
        with open(os.path.join(path, "label.json"), "r", encoding="utf8") as f:
            labels = json.load(f)
        image_list = []
        label_list = []
        for fname in all_images:
            if fname not in labels:
                continue
            image_list.append(fname)
            label_list.append(labels[fname])
        self.path = path
        self.img_size = img_size
        self.xflip = xflip
        self.image_list = image_list
        self.label_list = label_list

    def __len__(self):
        # with xflip the dataset is doubled: the second half yields the horizontally
        # flipped copies of the first half
        length = len(self.image_list)
        if self.xflip:
            length *= 2
        return length

    def __getitem__(self, index):
        real_len = len(self.image_list)
        fname = self.image_list[index % real_len]
        label = self.label_list[index % real_len]
        image = Image.open(os.path.join(self.path, fname)).convert("RGB")
        image = transforms.functional.to_tensor(image)
        image = random_crop(image, 0.8)
        image = rescale_pad(image, self.img_size, True)
        if index // real_len != 0:
            image = transforms.functional.hflip(image)
        label = torch.tensor([label], dtype=torch.float32)
        return image, label

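
# The dataset above expects the following on-disk layout (inferred from the loading
# code; adjust to your data): images anywhere under the dataset root, plus a
# label.json at the root mapping each image's relative path to a scalar score, e.g.
# {
#     "foo/0001.jpg": 0.85,
#     "bar/0002.png": 0.10
# }

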
class AnimeAesthetic(pl.LightningModule):
    def __init__(self, cfg: str, drop_path_rate=0.0, ema_decay=0):
        super().__init__()
        self.net = convnextv2(cfg, in_chans=3, num_classes=1, drop_path_rate=drop_path_rate)
        self.ema_decay = ema_decay
        self.ema = None
        if ema_decay > 0:
            self.ema = deepcopy(self.net)
            self.ema.requires_grad_(False)

    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.net.parameters(),
            lr=0.001,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=0,
        )
        return optimizer

    def forward(self, x, use_ema=False):
        # inputs arrive in [0, 1] from to_tensor; rescale to [-1, 1]
        x = (x - 0.5) / 0.5
        net = self.ema if use_ema else self.net
        return net(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        loss = F.mse_loss(self.forward(images, False), labels)
        self.log_dict({"train/loss": loss})
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        mae = F.l1_loss(self.forward(images, False), labels)
        logs = {"val/mae": mae}
        if self.ema is not None:
            mae_ema = F.l1_loss(self.forward(images, True), labels)
            logs["val/mae_ema"] = mae_ema
        self.log_dict(logs, sync_dist=True)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        if self.ema is not None:
            with torch.no_grad():
                for ema_v, model_v in zip(
                    self.ema.state_dict().values(), self.net.state_dict().values()
                ):
                    ema_v.copy_(
                        self.ema_decay * ema_v + (1.0 - self.ema_decay) * model_v
                    )

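
# Note on the EMA update above: after every training batch the shadow weights follow
# ema = decay * ema + (1 - decay) * model, so with the default decay of 0.9999 the EMA
# effectively averages over roughly the last 10,000 steps. Inference against the
# smoothed weights is a matter of calling forward(x, use_ema=True).

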
def main(opt):
    if not os.path.exists("lightning_logs"):
        os.mkdir("lightning_logs")
    torch.manual_seed(0)
    np.random.seed(0)
    print("---load dataset---")
    full_dataset = AnimeAestheticDataset(opt.data, opt.img_size)
    full_dataset_len = len(full_dataset)
    train_dataset_len = int(full_dataset_len * opt.data_split)
    val_dataset_len = full_dataset_len - train_dataset_len
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_dataset_len, val_dataset_len]
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=opt.batch_size_train,
        shuffle=True,
        persistent_workers=True,
        num_workers=opt.workers_train,
        pin_memory=True,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=opt.batch_size_val,
        shuffle=False,
        persistent_workers=True,
        num_workers=opt.workers_val,
        pin_memory=True,
    )
    print(f"train: {len(train_dataset)}")
    print(f"val: {len(val_dataset)}")
    print("---define model---")
    if opt.resume != "":
        anime_aesthetic = AnimeAesthetic.load_from_checkpoint(
            opt.resume, cfg=opt.cfg, drop_path_rate=opt.drop_path, ema_decay=opt.ema_decay
        )
    else:
        anime_aesthetic = AnimeAesthetic(cfg=opt.cfg, drop_path_rate=opt.drop_path, ema_decay=opt.ema_decay)

    print("---start train---")

    checkpoint_callback = ModelCheckpoint(
        monitor="val/mae",
        mode="min",
        save_top_k=1,
        save_last=True,
        auto_insert_metric_name=False,
        filename="epoch={epoch},mae={val/mae:.4f}",
    )
    callbacks = [checkpoint_callback]
    if opt.ema_decay > 0:
        checkpoint_ema_callback = ModelCheckpoint(
            monitor="val/mae_ema",
            mode="min",
            save_top_k=1,
            save_last=False,
            auto_insert_metric_name=False,
            filename="epoch={epoch},mae-ema={val/mae_ema:.4f}",
        )
        callbacks.append(checkpoint_ema_callback)
    trainer = Trainer(
        precision=32 if opt.fp32 else 16,
        accelerator=opt.accelerator,
        devices=opt.devices,
        max_epochs=opt.epoch,
        benchmark=opt.benchmark,
        accumulate_grad_batches=opt.acc_step,
        val_check_interval=opt.val_epoch,
        log_every_n_steps=opt.log_step,
        strategy="ddp_find_unused_parameters_false" if opt.devices > 1 else None,
        callbacks=callbacks,
    )
    trainer.fit(anime_aesthetic, train_dataloader, val_dataloader)

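
# Example launch (illustrative; substitute the actual script name and your own paths
# and hyper-parameters):
#   python train.py --data ./data --cfg tiny --img-size 768 --batch-size-train 16 \
#       --devices 1 --accelerator gpu

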
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--cfg",
        type=str,
        default="tiny",
        choices=list(model_cfgs.keys()),
        help="model configuration",
    )
    parser.add_argument(
        "--resume", type=str, default="", help="resume training from ckpt"
    )
    parser.add_argument(
        "--img-size",
        type=int,
        default=768,
        help="image size for training and validation",
    )

    parser.add_argument(
        "--data", type=str, default="./data", help="dataset path"
    )
    parser.add_argument(
        "--data-split",
        type=float,
        default=0.9999,
        help="split rate for training and validation",
    )

    parser.add_argument("--epoch", type=int, default=100, help="epoch num")
    parser.add_argument(
        "--batch-size-train", type=int, default=16, help="batch size for training"
    )
    parser.add_argument(
        "--batch-size-val", type=int, default=2, help="batch size for validation"
    )
    parser.add_argument(
        "--workers-train",
        type=int,
        default=4,
        help="number of workers for the training dataloader",
    )
    parser.add_argument(
        "--workers-val",
        type=int,
        default=4,
        help="number of workers for the validation dataloader",
    )
    parser.add_argument(
        "--acc-step", type=int, default=8, help="gradient accumulation steps"
    )
    parser.add_argument(
        "--drop-path", type=float, default=0.1, help="drop path rate"
    )
    parser.add_argument(
        "--ema-decay", type=float, default=0.9999, help="use EMA if ema-decay > 0"
    )
    parser.add_argument(
        "--accelerator",
        type=str,
        default="gpu",
        choices=["cpu", "gpu", "tpu", "ipu", "hpu", "auto"],
        help="accelerator",
    )
    parser.add_argument("--devices", type=int, default=4, help="number of devices")
    parser.add_argument(
        "--fp32", action="store_true", default=False, help="disable mixed precision"
    )
    parser.add_argument(
        "--benchmark", action="store_true", default=True, help="enable cudnn benchmark"
    )
    parser.add_argument(
        "--log-step", type=int, default=2, help="log training loss every n steps"
    )
    parser.add_argument(
        "--val-epoch",
        type=float,
        default=0.025,
        help="validate and save every this fraction of a training epoch",
    )

    opt = parser.parse_args()
    print(opt)

    main(opt)