|
""" |
|
Modified from: |
|
https://github.com/NVlabs/LSGM/blob/main/training_obj_joint.py |
|
""" |
|
import copy |
|
import functools |
|
import json |
|
import os |
|
from pathlib import Path |
|
from pdb import set_trace as st |
|
from typing import Any |
|
|
|
import blobfile as bf |
|
import imageio |
|
import numpy as np |
|
import torch as th |
|
import torch.distributed as dist |
|
import torchvision |
|
from PIL import Image |
|
from torch.nn.parallel.distributed import DistributedDataParallel as DDP |
|
from torch.optim import AdamW |
|
from torch.utils.tensorboard.writer import SummaryWriter |
|
from tqdm import tqdm |
|
|
|
from guided_diffusion import dist_util, logger |
|
from guided_diffusion.fp16_util import MixedPrecisionTrainer |
|
from guided_diffusion.nn import update_ema |
|
from guided_diffusion.resample import LossAwareSampler, UniformSampler |
|
|
|
from guided_diffusion.train_util import (TrainLoop, calc_average_loss, |
|
find_ema_checkpoint, |
|
find_resume_checkpoint, |
|
get_blob_logdir, log_loss_dict, |
|
log_rec3d_loss_dict, |
|
parse_resume_step_from_filename) |
|
from guided_diffusion.gaussian_diffusion import ModelMeanType |
|
|
|
import dnnlib |
|
from dnnlib.util import calculate_adaptive_weight |
|
|
|
from ..train_util_diffusion import TrainLoop3DDiffusion |
|
from ..cvD.nvsD_canoD import TrainLoop3DcvD_nvsD_canoD |
|
|
|
|
|
class TrainLoop3DDiffusionLSGM(TrainLoop3DDiffusion,TrainLoop3DcvD_nvsD_canoD): |
|
def __init__(self,
             *,
             rec_model,
             denoise_model,
             diffusion,
             loss_class,
             data,
             eval_data,
             batch_size,
             microbatch,
             lr,
             ema_rate,
             log_interval,
             eval_interval,
             save_interval,
             resume_checkpoint,
             use_fp16=False,
             fp16_scale_growth=0.001,
             schedule_sampler=None,
             weight_decay=0,
             lr_anneal_steps=0,
             iterations=10001,
             ignore_resume_opt=False,
             freeze_ae=False,
             denoised_ae=True,
             triplane_scaling_divider=10,
             use_amp=False,
             diffusion_input_size=224,
             **kwargs):
    """Construct the joint AE + latent-diffusion trainer.

    Every argument is keyword-only and is forwarded unchanged, together with
    any extra ``**kwargs``, to the next ``__init__`` in the MRO (the first
    base, ``TrainLoop3DDiffusion``). This method sets no attributes itself.
    """
    # Named parameters cannot also appear in ``kwargs``, so unpacking both
    # mappings into one call can never produce a duplicate keyword.
    forwarded = dict(
        rec_model=rec_model,
        denoise_model=denoise_model,
        diffusion=diffusion,
        loss_class=loss_class,
        data=data,
        eval_data=eval_data,
        batch_size=batch_size,
        microbatch=microbatch,
        lr=lr,
        ema_rate=ema_rate,
        log_interval=log_interval,
        eval_interval=eval_interval,
        save_interval=save_interval,
        resume_checkpoint=resume_checkpoint,
        use_fp16=use_fp16,
        fp16_scale_growth=fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=weight_decay,
        lr_anneal_steps=lr_anneal_steps,
        iterations=iterations,
        ignore_resume_opt=ignore_resume_opt,
        freeze_ae=freeze_ae,
        denoised_ae=denoised_ae,
        triplane_scaling_divider=triplane_scaling_divider,
        use_amp=use_amp,
        diffusion_input_size=diffusion_input_size,
    )
    super().__init__(**forwarded, **kwargs)
|
|
|
def run_step(self, batch, step='g_step'):
    """Run one named sub-step of the alternating training schedule.

    Args:
        batch: a batch drawn from ``self.data``; passed unchanged to
            ``forward_diffusion`` / ``forward_D``.
        step: which sub-step to run:

            * ``'diffusion_step_rec'`` / ``'diffusion_step_nvs'`` - joint
              autoencoder + denoiser update (``forward_diffusion`` with the
              same behaviour name). The EMA is updated when the denoiser's
              optimizer reports that it took a step.
            * ``'d_step_rec'`` - canonical-view discriminator update.
            * ``'d_step_nvs'`` - novel-view discriminator update.

    After the sub-step the learning rate is annealed and the step is logged.

    Raises:
        NotImplementedError: if ``step`` is not one of the names above. The
            default ``'g_step'`` is kept only for signature compatibility; it
            is not a valid step name. Before this change it was silently
            ignored, without any optimisation.
    """
    if step in ('diffusion_step_rec', 'diffusion_step_nvs'):
        # The step name is exactly the behaviour forward_diffusion expects.
        self.forward_diffusion(batch, behaviour=step)
        # The autoencoder optimizer is stepped first, then the denoiser's.
        # EMA follows the denoiser only when its optimizer reports a step.
        self.mp_trainer_rec.optimize(self.opt_rec)
        if self.mp_trainer.optimize(self.opt):
            self._update_ema()
    elif step == 'd_step_rec':
        self.forward_D(batch, behaviour='rec')
        self.mp_trainer_canonical_cvD.optimize(self.opt_cano_cvD)
    elif step == 'd_step_nvs':
        self.forward_D(batch, behaviour='nvs')
        self.mp_trainer_cvD.optimize(self.opt_cvD)
    else:
        # Mirror forward_diffusion: fail loudly instead of silently doing
        # nothing while still annealing the LR and logging a step.
        raise NotImplementedError(step)

    self._anneal_lr()
    self.log_step()
|
|
|
def run_loop(self):
    """Alternate generator and discriminator sub-steps until training ends.

    Each iteration draws four fresh batches from ``self.data`` and runs, in
    this order: ``'diffusion_step_rec'``, ``'d_step_rec'``,
    ``'diffusion_step_nvs'``, ``'d_step_nvs'``. It then handles logging,
    evaluation (rank 0 only) and checkpointing at their intervals.

    The loop stops when:

    * ``lr_anneal_steps`` is set and reached. The last checkpoint is written
      if it was not already saved.
    * ``self.iterations`` is exceeded. The last checkpoint is written if
      needed, then the process exits via ``SystemExit``.
    * the ``DIFFUSION_TRAINING_TEST`` environment variable is non-empty. The
      method returns right after the first periodic checkpoint with
      ``self.step > 0``.
    """
    import sys

    def save_all_checkpoints():
        # Save every trainer under the same name at every save site, so a
        # resumed run loads each discriminator into the right module.
        self.save(self.mp_trainer, self.mp_trainer.model_name)
        self.save(self.mp_trainer_rec, self.mp_trainer_rec.model_name)
        self.save(self.mp_trainer_cvD, 'cvD')
        self.save(self.mp_trainer_canonical_cvD, 'cano_cvD')

    # Sub-steps run once per iteration, each on its own fresh batch.
    step_schedule = ('diffusion_step_rec', 'd_step_rec',
                     'diffusion_step_nvs', 'd_step_nvs')

    while (not self.lr_anneal_steps
           or self.step + self.resume_step < self.lr_anneal_steps):

        # Keep all ranks in lock-step before starting the next iteration.
        dist_util.synchronize()

        for step_name in step_schedule:
            batch = next(self.data)
            self.run_step(batch, step=step_name)

        if self.step % self.log_interval == 0 and dist_util.get_rank() == 0:
            out = logger.dumpkvs()
            # Mirror the dumped scalars to tensorboard.
            for k, v in out.items():
                self.writer.add_scalar(f'Loss/{k}', v,
                                       self.step + self.resume_step)

        if self.step % self.eval_interval == 0:
            if dist_util.get_rank() == 0:
                self.eval_loop()
            th.cuda.empty_cache()
            dist_util.synchronize()

        if self.step % self.save_interval == 0:
            save_all_checkpoints()
            dist_util.synchronize()

            # Run for a finite amount of time in integration tests.
            if os.environ.get("DIFFUSION_TRAINING_TEST",
                              "") and self.step > 0:
                return

        self.step += 1

        if self.step > self.iterations:
            print('reached maximum iterations, exiting')

            # Save the last checkpoint if it wasn't already saved.
            if (self.step - 1) % self.save_interval != 0:
                save_all_checkpoints()

            # exit() is a site-module convenience for the interactive shell;
            # sys.exit() is the programmatic equivalent.
            sys.exit()

    # Save the last checkpoint if it wasn't already saved. This previously
    # called self.save() with no arguments and stored the canonical cvD
    # under the novel-view cvD's name ('cvD').
    if (self.step - 1) % self.save_interval != 0:
        save_all_checkpoints()
|
|
|
def forward_diffusion(self, batch, behaviour='rec', *args, **kwargs):
    """Joint autoencoder + latent-diffusion generator pass (gradients only).

    Original note: add sds grad to all ae predicted x_0.

    For each micro-batch of ``batch`` this method:

    * encodes the images with ``self.ddp_rec_model``;
    * computes a reconstruction generator loss (``'diffusion_step_rec'``) or
      a novel-view generator loss (``'diffusion_step_nvs'``) against the
      frozen discriminators;
    * computes a denoising loss with ``self.diffusion.training_losses``;
    * back-propagates the summed loss through ``self.mp_trainer_rec`` and
      ``self.mp_trainer``.

    It only zeroes and accumulates gradients. The optimizer steps are taken
    by the caller (``run_step``).

    Args:
        batch: dict of tensors. The code reads the keys 'img',
            'img_to_encoder', 'c' and 'depth', and slices every value along
            dim 0 into micro-batches of ``self.microbatch``.
        behaviour: ``'diffusion_step_rec'`` or ``'diffusion_step_nvs'``. Any
            other value, including the default ``'rec'``, raises
            ``NotImplementedError``.
        *args, **kwargs: not used by the working code path.
            NOTE(review): ``*args`` shadows the LSGM ``args`` namespace that
            the ported fragment below reads (``args.batch_size`` etc.).

    Raises:
        NotImplementedError: for an unsupported ``behaviour``.

    NOTE(review): this is an unfinished port of LSGM's training_obj_joint.py.
    The LSGM fragment between the timestep sampling and ``compute_losses``
    references names that are not defined in this method or in the visible
    imports: ``diffusion``, ``dae``, ``utils``, ``vae``, ``pred_params``,
    ``grad_scalar``, ``vae_optimizer``, ``dae_optimizer``, ``l2_term``,
    ``nelbo_loss``, ``all_log_q``, and others. As written the method cannot
    complete a step. TODO: port it against ``self.diffusion`` or remove it.
    """

    # Discriminators are frozen for the generator pass; both the denoiser and
    # the autoencoder receive gradients.
    self.ddp_cano_cvD.requires_grad_(False)
    self.ddp_nvs_cvD.requires_grad_(False)

    self.ddp_model.requires_grad_(True)
    self.ddp_rec_model.requires_grad_(True)

    # The triplane decoder is re-frozen on every call, overriding the
    # requires_grad_(True) above for those parameters only.
    for param in self.ddp_rec_model.module.decoder.triplane_decoder.parameters(
    ):
        param.requires_grad_(False)

    self.mp_trainer_rec.zero_grad()
    self.mp_trainer.zero_grad()

    batch_size = batch['img'].shape[0]

    for i in range(0, batch_size, self.microbatch):

        micro = {
            k: v[i:i + self.microbatch].to(dist_util.dev())
            for k, v in batch.items()
        }

        # Under DDP, gradient sync is only allowed on the final micro-batch.
        last_batch = (i + self.microbatch) >= batch_size

        # Defaults so every loss term exists even when a branch does not
        # set it (e.g. the nvs branch never sets vae_nelbo_loss).
        vae_nelbo_loss = th.tensor(0.0).to(dist_util.dev())
        vision_aided_loss = th.tensor(0.0).to(dist_util.dev())
        denoise_loss = th.tensor(0.0).to(dist_util.dev())
        d_weight = th.tensor(0.0).to(dist_util.dev())

        # ============================ ae part ============================
        with th.cuda.amp.autocast(dtype=th.float16,
                                  enabled=self.mp_trainer.use_amp
                                  and not self.freeze_ae):

            # Encode to the latent without running the triplane decoder.
            vae_out = self.ddp_rec_model(
                img=micro['img_to_encoder'],
                c=micro['c'],
                behaviour='enc_dec_wo_triplane')

            if behaviour == 'diffusion_step_rec':
                # Render from the input cameras and compare with the inputs.
                target = micro
                pred = self.ddp_rec_model(latent=vae_out,
                                          c=micro['c'],
                                          behaviour='triplane_dec')

                # NOTE(review): loss_class does not visibly run ddp_model,
                # and backward happens later outside this context, so this
                # no_sync() likely has no effect. Confirm it is intended.
                if last_batch or not self.use_ddp:
                    vae_nelbo_loss, loss_dict = self.loss_class(pred,
                                                                target,
                                                                test_mode=False)
                else:
                    with self.ddp_model.no_sync():
                        vae_nelbo_loss, loss_dict = self.loss_class(
                            pred, target, test_mode=False)

                # Last decoder layer, used to balance the adversarial
                # gradient against the reconstruction gradient.
                last_layer = self.ddp_rec_model.module.decoder.triplane_decoder.decoder.net[
                    -1].weight

                # With a super-resolution output, the discriminator sees the
                # average of the SR image and the upsampled raw render.
                if 'image_sr' in pred:
                    vision_aided_loss = self.ddp_cano_cvD(
                        0.5 * pred['image_sr'] +
                        0.5 * th.nn.functional.interpolate(
                            pred['image_raw'],
                            size=pred['image_sr'].shape[2:],
                            mode='bilinear'),
                        for_G=True).mean()
                else:
                    vision_aided_loss = self.ddp_cano_cvD(
                        pred['image_raw'], for_G=True
                    ).mean(
                    )

                # Adaptive weight (capped at 1) scaled by rec_cvD_lambda.
                d_weight = calculate_adaptive_weight(
                    vae_nelbo_loss,
                    vision_aided_loss,
                    last_layer,
                    disc_weight_max=1) * self.loss_class.opt.rec_cvD_lambda

                vision_aided_loss *= d_weight

                loss_dict.update({
                    'vision_aided_loss/G_rec':
                    vision_aided_loss,
                    'd_weight_G_rec':
                    d_weight,
                })

                log_rec3d_loss_dict(loss_dict)

            elif behaviour == 'diffusion_step_nvs':

                # Novel views: roll the camera batch by one so each latent
                # is rendered with another sample's camera.
                novel_view_c = th.cat([micro['c'][1:], micro['c'][:1]])

                pred = self.ddp_rec_model(latent=vae_out,
                                          c=novel_view_c,
                                          behaviour='triplane_dec')

                if 'image_sr' in pred:
                    vision_aided_loss = self.ddp_nvs_cvD(
                        0.5 * pred['image_sr'] +
                        0.5 * th.nn.functional.interpolate(
                            pred['image_raw'],
                            size=pred['image_sr'].shape[2:],
                            mode='bilinear'),
                        for_G=True).mean()
                else:
                    vision_aided_loss = self.ddp_nvs_cvD(
                        pred['image_raw'], for_G=True
                    ).mean(
                    )

                # Fixed weight here: no adaptive balancing, since there is
                # no reconstruction loss in this branch.
                d_weight = self.loss_class.opt.nvs_cvD_lambda
                vision_aided_loss *= d_weight

                log_rec3d_loss_dict({
                    'vision_aided_loss/G_nvs':
                    vision_aided_loss,
                })

            else:
                raise NotImplementedError(behaviour)

            # =========================== ddpm step ===========================

            # Latent fed to the diffusion model (single-stage: gradients
            # flow back into the encoder through it).
            eps = vae_out[self.latent_name]

            eps.requires_grad_(True)

            t, weights = self.schedule_sampler.sample(
                eps.shape[0], dist_util.dev())
            # NOTE(review): vae_out is indexed by self.latent_name above, so
            # it is presumably a dict, and .size() would then raise. Also
            # device='cuda' may differ from dist_util.dev(); confirm.
            noise = th.randn(size=vae_out.size(), device='cuda')

            model_kwargs = {}

            # ---------------- LSGM fragment (unported, see docstring) ----------------
            # NOTE(review): `diffusion` is not defined in scope (presumably
            # self.diffusion was meant), and `args` is the *args tuple here,
            # so `args.batch_size` raises AttributeError.
            # Importance-weighted time sampling for the p (prior) objective.
            t_p, var_t_p, m_t_p, obj_weight_t_p, obj_weight_t_q, g2_t_p = \
                diffusion.iw_quantities(args.batch_size, args.time_eps, args.iw_sample_p, args.iw_subvp_like_vp_sde)
            eps_t_p = diffusion.sample_q(vae_out, noise, var_t_p, m_t_p)

            # Optionally draw a second set of times for the q (vae) objective
            # and stack both along the batch dim.
            if args.iw_sample_q in ['ll_uniform', 'll_iw']:
                t_q, var_t_q, m_t_q, obj_weight_t_q, _, g2_t_q = \
                    diffusion.iw_quantities(args.batch_size, args.time_eps, args.iw_sample_q, args.iw_subvp_like_vp_sde)
                eps_t_q = diffusion.sample_q(vae_out, noise, var_t_q, m_t_q)

                eps_t_p = eps_t_p.detach().requires_grad_(True)
                eps_t = th.cat([eps_t_p, eps_t_q], dim=0)
                var_t = th.cat([var_t_p, var_t_q], dim=0)
                # NOTE(review): this overwrites the `t` from schedule_sampler
                # that compute_losses below uses.
                t = th.cat([t_p, t_q], dim=0)
                noise = th.cat([noise, noise], dim=0)
            else:
                eps_t, m_t, var_t, t, g2_t = eps_t_p, m_t_p, var_t_p, t_p, g2_t_p

            # NOTE(review): `dae`, `utils` and `pred_params` are undefined here.
            mixing_component = diffusion.mixing_component(eps_t, var_t, t, enabled=dae.mixed_prediction)
            params = utils.get_mixed_prediction(dae.mixed_prediction, pred_params, dae.mixing_logit, mixing_component)

            # NOTE(review): cross_entropy_per_var is never assigned before
            # this `+=` (UnboundLocalError). `vae`, `all_log_q` and the
            # remaining_neg_log_p_* names are undefined.
            cross_entropy_per_var += diffusion.cross_entropy_const(args.time_eps)
            cross_entropy = th.sum(cross_entropy_per_var, dim=[1, 2, 3])
            cross_entropy += remaining_neg_log_p_total
            all_neg_log_p = vae.decompose_eps(cross_entropy_per_var)
            all_neg_log_p.extend(remaining_neg_log_p_per_ver)
            kl_all_list, kl_vals_per_group, kl_diag_list = utils.kl_per_group_vada(all_log_q, all_neg_log_p)

            kl_coeff = 1.0

            # NOTE(review): nelbo_loss is undefined, and p_objective is only
            # assigned further down, so this read raises UnboundLocalError.
            q_loss = th.mean(nelbo_loss)
            p_loss = th.mean(p_objective)

            # LSGM's separate backward/step for the vae (q) and the
            # denoiser (p). grad_scalar and the *_optimizer names are
            # undefined here. These backward/step calls would also duplicate
            # the mp_trainer backward/optimize done elsewhere.
            if args.train_vae:
                grad_scalar.scale(q_loss).backward(retain_graph=utils.different_p_q_objectives(args.iw_sample_p, args.iw_sample_q))
                utils.average_gradients(vae.parameters(), args.distributed)
                if args.grad_clip_max_norm > 0.:
                    grad_scalar.unscale_(vae_optimizer)
                    th.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=args.grad_clip_max_norm)
                grad_scalar.step(vae_optimizer)

            if utils.different_p_q_objectives(args.iw_sample_p, args.iw_sample_q) or not args.train_vae:
                if args.train_vae:
                    # Discard the gradients computed for the vae objective.
                    dae_optimizer.zero_grad()

                grad_scalar.scale(p_loss).backward()

            utils.average_gradients(dae.parameters(), args.distributed)
            if args.grad_clip_max_norm > 0.:
                grad_scalar.unscale_(dae_optimizer)
                th.nn.utils.clip_grad_norm_(dae.parameters(), max_norm=args.grad_clip_max_norm)
            grad_scalar.step(dae_optimizer)

            # p objective from the (undefined) l2_term. In LSGM this precedes
            # the loss computation above; here it comes after.
            if args.iw_sample_q in ['ll_uniform', 'll_iw']:
                l2_term_p, l2_term_q = th.chunk(l2_term, chunks=2, dim=0)
                p_objective = th.sum(obj_weight_t_p * l2_term_p, dim=[1, 2, 3])
            else:
                p_objective = th.sum(obj_weight_t_p * l2_term, dim=[1, 2, 3])
            # -------------------- end of LSGM fragment --------------------

            # Denoising loss on the latent. `eps` is passed positionally after
            # the model, presumably as x_start; confirm against training_losses.
            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                eps,
                t,
                model_kwargs=model_kwargs,
                return_detail=True)

            # Skip DDP gradient all-reduce on all but the last micro-batch.
            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach())

            # Importance weights from the schedule sampler.
            denoise_loss = (losses["loss"] * weights).mean()

            # Detail tensors are removed before logging. x_t and
            # model_output are used for the visualisation below; the other
            # two are unused.
            x_t = losses.pop('x_t')
            model_output = losses.pop('model_output')
            diffusion_target = losses.pop('diffusion_target')
            alpha_bar = losses.pop('alpha_bar')

            log_loss_dict(self.diffusion, t,
                          {k: v * weights
                           for k, v in losses.items()})

            # Total loss is computed inside autocast.
            loss = vae_nelbo_loss + denoise_loss + vision_aided_loss

        # Backward outside autocast. The same graph is back-propagated through
        # both trainers' backward helpers; presumably each handles its own
        # loss scaling. TODO confirm that calling backward twice on one loss
        # is intended (a second backward normally needs retain_graph).
        self.mp_trainer_rec.backward(loss)
        self.mp_trainer.backward(loss)

        # ======================= denoised AE log part =======================
        # Runs on rank 0 every 500 steps, once per micro-batch. behaviour can
        # only be one of the two diffusion_step_* values here (others raised
        # above), so `behaviour != 'diff'` is always true.
        if dist_util.get_rank() == 0 and self.step % 500 == 0 and behaviour != 'diff':
            with th.no_grad():

                # Min-max normalise depth to [0, 1] for display.
                gt_depth = micro['depth']
                if gt_depth.ndim == 3:
                    gt_depth = gt_depth.unsqueeze(1)
                gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
                                                          gt_depth.min())

                pred_depth = pred['image_depth']
                pred_depth = (pred_depth - pred_depth.min()) / (
                    pred_depth.max() - pred_depth.min())
                pred_img = pred['image_raw']
                gt_img = micro['img']

                # Top row: the GT image three times (gt_img is micro['img'])
                # and the GT depth, side by side along width.
                # NOTE(review): in the nvs branch, pred is rendered from
                # novel_view_c but is shown against these source-view GTs.
                gt_vis = th.cat(
                    [
                        gt_img, micro['img'], micro['img'],
                        gt_depth.repeat_interleave(3, dim=1)
                    ],
                    dim=-1)[0:1]

                # Render the noised latent. It is multiplied by
                # triplane_scaling_divider, presumably undoing a latent
                # scaling; confirm.
                noised_ae_pred = self.ddp_rec_model(
                    img=None,
                    c=micro['c'][0:1],
                    latent=x_t[0:1] * self.
                    triplane_scaling_divider,
                    behaviour=self.render_latent_behaviour)

                # Recover the predicted clean latent from the model output,
                # according to the diffusion parameterisation.
                if self.diffusion.model_mean_type == ModelMeanType.START_X:
                    pred_xstart = model_output
                else:
                    pred_xstart = self.diffusion._predict_xstart_from_eps(
                        x_t=x_t, t=t, eps=model_output)

                denoised_ae_pred = self.ddp_rec_model(
                    img=None,
                    c=micro['c'][0:1],
                    latent=pred_xstart[0:1] * self.
                    triplane_scaling_divider,
                    behaviour=self.render_latent_behaviour)

                # Bottom row: AE render, noised render, denoised render,
                # predicted depth.
                pred_vis = th.cat([
                    pred_img[0:1], noised_ae_pred['image_raw'][0:1],
                    denoised_ae_pred['image_raw'][0:1],
                    pred_depth[0:1].repeat_interleave(3, dim=1)
                ],
                                  dim=-1)

                # Stack the rows along height and take the first sample as HWC.
                vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
                    1, 2, 0).cpu()

                # Map to uint8. This assumes images are in [-1, 1]; depth is
                # in [0, 1] and so lands in the upper half. TODO confirm.
                vis = vis.numpy() * 127.5 + 127.5
                vis = vis.clip(0, 255).astype(np.uint8)
                Image.fromarray(vis).save(
                    f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t[0].item()}_{behaviour}.jpg'
                )
                print(
                    'log denoised vis to: ',
                    f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t[0].item()}_{behaviour}.jpg'
                )

                th.cuda.empty_cache()
|
|