# ICON / apps / Normal.py
# Yuliang's picture
# done
# 2d5f249
from lib.net import NormalNet
from lib.common.train_util import *
import logging
import torch
import numpy as np
from torch import nn
from skimage.transform import resize
import pytorch_lightning as pl
# Let cuDNN benchmark candidate algorithms and reuse the fastest one it finds.
torch.backends.cudnn.benchmark = True
# Only let messages of level ERROR or above through the "lightning" logger.
logging.getLogger("lightning").setLevel(logging.ERROR)
class Normal(pl.LightningModule):
    """Lightning module that trains the two sub-networks of ``NormalNet``.

    ``netG.netF`` and ``netG.netB`` each get their own Adam optimizer and
    ``MultiStepLR`` scheduler, and both are stepped by hand inside
    ``training_step`` through ``self.optimizers()`` / ``self.manual_backward``.

    NOTE(review): stepping the optimizers by hand presumably requires manual
    optimization to be enabled where the Trainer is built -- confirm.
    """

    def __init__(self, cfg):
        """Store the config and build the network.

        Args:
            cfg: Project config object. This class reads ``batch_size``,
                ``lr_N``, ``weight_decay``, ``schedule``, ``gamma``,
                ``freq_show_train`` and ``net.in_nml`` from it.
        """
        super().__init__()
        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.lr_N = self.cfg.lr_N

        # Filled in by configure_optimizers(); read in training_epoch_end().
        self.schedulers = []

        self.netG = NormalNet(self.cfg, error_term=nn.SmoothL1Loss())

        # First element of every ``cfg.net.in_nml`` entry; used as the batch
        # keys that are forwarded to the network.
        self.in_nml = [item[0] for item in cfg.net.in_nml]

    def get_progress_bar_dict(self):
        """Return the parent's progress-bar dict without the ``v_num`` entry."""
        tqdm_dict = super().get_progress_bar_dict()
        tqdm_dict.pop("v_num", None)
        return tqdm_dict

    # Training related
    def configure_optimizers(self):
        """Create one Adam optimizer and one MultiStepLR scheduler per sub-net.

        Returns:
            A ``(optimizers, schedulers)`` pair of two-element lists, ordered
            ``[netF, netB]``. The scheduler list is also kept on
            ``self.schedulers``.
        """
        weight_decay = self.cfg.weight_decay

        optims = []
        schedulers = []
        # Order matters: training_step unpacks the optimizers as (F, B) and
        # training_epoch_end indexes the schedulers the same way.
        for net in (self.netG.netF, self.netG.netB):
            optimizer = torch.optim.Adam(
                [{"params": net.parameters(), "lr": self.lr_N}],
                lr=self.lr_N,
                weight_decay=weight_decay,
            )
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=self.cfg.schedule, gamma=self.cfg.gamma
            )
            optims.append(optimizer)
            schedulers.append(scheduler)

        self.schedulers = schedulers

        return optims, self.schedulers

    def render_func(self, render_tensor):
        """Tile the first sample of every tensor into one side-by-side image.

        Args:
            render_tensor: Mapping of name -> tensor; must contain ``"image"``.
                Assumes every tensor is laid out (batch, channel, H, W) with
                values in [-1, 1] -- TODO confirm against the data loader.

        Returns:
            A numpy array in channel-last layout with the per-tensor images
            concatenated along axis 1, each resized to ``(height, height)``
            where ``height`` is ``render_tensor["image"].shape[2]``.
        """
        height = render_tensor["image"].shape[2]

        result_list = [
            resize(
                # detach() so tensors that still carry autograd history can be
                # converted to numpy; (x + 1) / 2 shifts the value range and
                # the transpose moves the leading axis to the end.
                ((tensor.detach().cpu().numpy()[0] + 1.0) / 2.0).transpose(
                    1, 2, 0
                ),
                (height, height),
                anti_aliasing=True,
            )
            for tensor in render_tensor.values()
        ]

        return np.concatenate(result_list, axis=1)

    def _split_batch(self, batch):
        """Split a batch into network inputs and the normal_F/normal_B targets."""
        in_tensor = {name: batch[name] for name in self.in_nml}
        FB_tensor = {"normal_F": batch["normal_F"],
                     "normal_B": batch["normal_B"]}
        return in_tensor, FB_tensor

    def _log_normal_images(self, tag_prefix, in_tensor, nmlF, nmlB):
        """Render inputs plus both predictions and send the image to the logger."""
        render_tensor = {**in_tensor, "nmlF": nmlF, "nmlB": nmlB}
        result_array = self.render_func(render_tensor)

        self.logger.experiment.add_image(
            tag=f"{tag_prefix}/{self.global_step}",
            # render_func returns channel-last; move channels to the front.
            img_tensor=result_array.transpose(2, 0, 1),
            global_step=self.global_step,
        )

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one optimization step on both sub-networks.

        Args:
            batch: Mapping holding the keys listed in ``self.in_nml`` plus
                ``"normal_F"`` and ``"normal_B"``.
            batch_idx: Index of the batch within the epoch.
            optimizer_idx: Unused; both optimizers are stepped in every call.

        Returns:
            Dict with the summed loss, the two individual losses and the
            ``"log"`` / ``"progress_bar"`` payloads.
        """
        export_cfg(self.logger, self.cfg)

        # retrieve the data
        in_tensor, FB_tensor = self._split_batch(batch)

        self.netG.train()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        (opt_nf, opt_nb) = self.optimizers()

        opt_nf.zero_grad()
        opt_nb.zero_grad()

        self.manual_backward(error_NF, opt_nf)
        self.manual_backward(error_NB, opt_nb)

        opt_nf.step()
        opt_nb.step()

        if batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0:
            # Preview with the freshly updated weights in inference mode; the
            # next training_step switches back with netG.train().
            self.netG.eval()
            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
            self._log_normal_images("Normal-train", in_tensor, nmlF, nmlB)

        # metrics processing
        metrics_log = {
            "train_loss-NF": error_NF.item(),
            "train_loss-NB": error_NB.item(),
        }

        tf_log = tf_log_convert(metrics_log)
        bar_log = bar_log_convert(metrics_log)

        return {
            "loss": error_NF + error_NB,
            "loss-NF": error_NF,
            "loss-NB": error_NB,
            "log": tf_log,
            "progress_bar": bar_log,
        }

    def training_epoch_end(self, outputs):
        """Average the per-step training losses and report both learning rates.

        Args:
            outputs: Collected ``training_step`` results for the epoch.

        Returns:
            ``{"log": ...}`` with the averaged losses and the last learning
            rate of each scheduler.
        """
        # If the collection contains an empty list, only its first element
        # holds the step results -- presumably one list per optimizer; verify
        # against the Lightning version in use.
        if [] in outputs:
            outputs = outputs[0]

        # metrics processing
        metrics_log = {
            "train_avgloss": batch_mean(outputs, "loss"),
            "train_avgloss-NF": batch_mean(outputs, "loss-NF"),
            "train_avgloss-NB": batch_mean(outputs, "loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        tf_log["lr-NF"] = self.schedulers[0].get_last_lr()[0]
        tf_log["lr-NB"] = self.schedulers[1].get_last_lr()[0]

        return {"log": tf_log}

    def validation_step(self, batch, batch_idx):
        """Compute the validation losses for one batch, with optional preview.

        Args:
            batch: Mapping holding the keys listed in ``self.in_nml`` plus
                ``"normal_F"`` and ``"normal_B"``.
            batch_idx: Index of the batch within the validation run.

        Returns:
            Dict with the summed validation loss and the two individual losses.
        """
        # retrieve the data
        in_tensor, FB_tensor = self._split_batch(batch)

        # Validation is inference: use eval mode, matching the preview path in
        # training_step, so layers that behave differently during training do
        # not affect (or get affected by) validation data.
        self.netG.eval()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        # Always preview the first batch, then every freq_show_train-th one.
        if batch_idx == 0 or batch_idx % int(self.cfg.freq_show_train) == 0:
            # The predictions above were already made in eval mode, so they
            # are reused instead of running a second forward pass
            # (assumes netG is deterministic in eval mode).
            self._log_normal_images("Normal-val", in_tensor, preds_F, preds_B)

        return {
            "val_loss": error_NF + error_NB,
            "val_loss-NF": error_NF,
            "val_loss-NB": error_NB,
        }

    def validation_epoch_end(self, outputs):
        """Average the per-step validation losses.

        Args:
            outputs: Collected ``validation_step`` results.

        Returns:
            ``{"log": ...}`` with the averaged validation losses.
        """
        # metrics processing
        metrics_log = {
            "val_avgloss": batch_mean(outputs, "val_loss"),
            "val_avgloss-NF": batch_mean(outputs, "val_loss-NF"),
            "val_avgloss-NB": batch_mean(outputs, "val_loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        return {"log": tf_log}