# -*- coding: utf-8 -*-

from typing import List, Tuple, Dict, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import pytorch_lightning as pl
from einops import rearrange

from diffusers.schedulers import (
    DDPMScheduler,
    DDIMScheduler,
    KarrasVeScheduler,
    DPMSolverMultistepScheduler
)

from michelangelo.utils import instantiate_from_config
# from michelangelo.models.tsal.tsal_base import ShapeAsLatentPLModule
from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule
from michelangelo.models.asl_diffusion.inference_utils import ddim_sample

SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]

def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self

class ASLDiffuser(pl.LightningModule):
    first_stage_model: Optional[AlignedShapeAsLatentPLModule]
    # cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
    model: nn.Module

    def __init__(self, *,
                 first_stage_config,
                 denoiser_cfg,
                 scheduler_cfg,
                 optimizer_cfg,
                 loss_cfg,
                 first_stage_key: str = "surface",
                 cond_stage_key: str = "image",
                 cond_stage_trainable: bool = True,
                 scale_by_std: bool = False,
                 z_scale_factor: float = 1.0,
                 ckpt_path: Optional[str] = None,
                 ignore_keys: Union[Tuple[str, ...], List[str]] = ()):

        super().__init__()

        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.cond_stage_trainable = cond_stage_trainable

        # 1. initialize the first stage.
        # Note: the condition encoders are contained in the first stage model.
        self.first_stage_config = first_stage_config
        self.first_stage_model = None
        # self.instantiate_first_stage(first_stage_config)

        # 2. initialize the conditional stage.
        # The encoders are dispatched by `cond_stage_key`; each
        # "*_unconditional_embedding" entry produces the null condition
        # used for classifier-free guidance.
        # self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_model = {
            "image": self.encode_image,
            "image_unconditional_embedding": self.empty_img_cond,
            "text": self.encode_text,
            "text_unconditional_embedding": self.empty_text_cond,
            "surface": self.encode_surface,
            "surface_unconditional_embedding": self.empty_surface_cond,
        }
        # 3. diffusion model
        self.model = instantiate_from_config(
            denoiser_cfg, device=None, dtype=None
        )

        self.optimizer_cfg = optimizer_cfg

        # 4. scheduling strategy
        self.scheduler_cfg = scheduler_cfg

        self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
        self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)

        # 5. loss configuration
        self.loss_cfg = loss_cfg

        self.scale_by_std = scale_by_std
        if scale_by_std:
            self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
        else:
            self.z_scale_factor = z_scale_factor

        self.ckpt_path = ckpt_path
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
    def instantiate_first_stage(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        # freeze the first stage: keep it permanently in eval mode
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

        self.first_stage_model = self.first_stage_model.to(self.device)
    # def instantiate_cond_stage(self, config):
    #     if not self.cond_stage_trainable:
    #         if config == "__is_first_stage__":
    #             print("Using first stage also as cond stage.")
    #             self.cond_stage_model = self.first_stage_model
    #         elif config == "__is_unconditional__":
    #             print(f"Training {self.__class__.__name__} as an unconditional model.")
    #             self.cond_stage_model = None
    #             # self.be_unconditional = True
    #         else:
    #             model = instantiate_from_config(config)
    #             self.cond_stage_model = model.eval()
    #             self.cond_stage_model.train = disabled_train
    #             for param in self.cond_stage_model.parameters():
    #                 param.requires_grad = False
    #     else:
    #         assert config != "__is_first_stage__"
    #         assert config != "__is_unconditional__"
    #         model = instantiate_from_config(config)
    #         self.cond_stage_model = model
    def init_from_ckpt(self, path, ignore_keys=()):
        state_dict = torch.load(path, map_location="cpu")["state_dict"]

        keys = list(state_dict.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print(f"Deleting key {k} from state_dict.")
                    del state_dict[k]
                    break  # the key is gone; avoid a KeyError on a second matching prefix

        # strict=False tolerates checkpoint/module mismatches
        # (e.g. the lazily instantiated first stage model).
        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")
    @property
    def zero_rank(self):
        # exposed as a property because `sample()` reads it as an attribute
        if self._trainer:
            zero_rank = self.trainer.local_rank == 0
        else:
            zero_rank = True
        return zero_rank
    def configure_optimizers(self) -> Tuple[List, List]:
        lr = self.learning_rate

        trainable_parameters = list(self.model.parameters())
        # if the conditional encoder is trainable
        # if self.cond_stage_trainable:
        #     conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad]
        #     trainable_parameters += conditioner_params
        #     print(f"number of trainable conditional parameters: {len(conditioner_params)}.")

        if self.optimizer_cfg is None:
            optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
            schedulers = []
        else:
            optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
            scheduler_func = instantiate_from_config(
                self.optimizer_cfg.scheduler,
                max_decay_steps=self.trainer.max_steps,
                lr_max=lr
            )
            scheduler = {
                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
                "interval": "step",
                "frequency": 1
            }
            optimizers = [optimizer]
            schedulers = [scheduler]

        return optimizers, schedulers
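
    # For reference, a minimal sketch of the `optimizer_cfg` shape the branch
    # above expects. Targets and values here are illustrative assumptions, not
    # the project's actual config; the only hard requirements are an
    # `optimizer` node instantiable with a `params` argument and a `scheduler`
    # node whose instance exposes a `.schedule(step)` callable:
    #
    #   optimizer:
    #     target: torch.optim.AdamW
    #     params: {betas: [0.9, 0.99], eps: 1.e-6, weight_decay: 1.e-2}
    #   scheduler:
    #     target: <factory taking max_decay_steps and lr_max, exposing .schedule>
    #     params: {...}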
    def encode_text(self, text):
        b = text.shape[0]
        text_tokens = rearrange(text, "b t l -> (b t) l")
        text_embed = self.first_stage_model.model.encode_text_embed(text_tokens)
        text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
        # average over the t templates, then L2-normalize
        text_embed = text_embed.mean(dim=1)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)

        return text_embed
    def encode_image(self, img):
        return self.first_stage_model.model.encode_image_embed(img)

    def encode_surface(self, surface):
        return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False)

    def empty_text_cond(self, cond):
        return torch.zeros_like(cond)

    def empty_img_cond(self, cond):
        return torch.zeros_like(cond)

    def empty_surface_cond(self, cond):
        return torch.zeros_like(cond)

    def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
        z_q = self.first_stage_model.encode(surface, sample_posterior)
        z_q = self.z_scale_factor * z_q

        return z_q

    def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
        z_q = 1. / self.z_scale_factor * z_q
        latents = self.first_stage_model.decode(z_q, **kwargs)
        return latents
    def on_train_batch_start(self, batch, batch_idx):
        # only for the very first batch
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
                and batch_idx == 0 and self.ckpt_path is None:
            # set the rescale weight to 1. / std of the encodings
            print("### USING STD-RESCALING ###")

            z_q = self.encode_first_stage(batch[self.first_stage_key])
            z = z_q.detach()

            del self.z_scale_factor
            self.register_buffer("z_scale_factor", 1. / z.flatten().std())
            print(f"setting self.z_scale_factor to {self.z_scale_factor}")

            print("### USING STD-RESCALING ###")
    def compute_loss(self, model_outputs, split):
        """
        Args:
            model_outputs (dict):
                - x_0: the clean latents
                - noise: the Gaussian noise added to the latents
                - pred: the denoiser output
            split (str): "train" or "val"

        Returns:
            total_loss (torch.FloatTensor), loss_dict (dict)
        """

        pred = model_outputs["pred"]

        if self.noise_scheduler.prediction_type == "epsilon":
            target = model_outputs["noise"]
        elif self.noise_scheduler.prediction_type == "sample":
            target = model_outputs["x_0"]
        else:
            raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")

        if self.loss_cfg.loss_type == "l1":
            simple = F.l1_loss(pred, target, reduction="mean")
        elif self.loss_cfg.loss_type in ["mse", "l2"]:
            simple = F.mse_loss(pred, target, reduction="mean")
        else:
            raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")

        total_loss = simple

        loss_dict = {
            f"{split}/total_loss": total_loss.clone().detach(),
            f"{split}/simple": simple.detach(),
        }

        return total_loss, loss_dict
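
    # With epsilon prediction this is the standard DDPM simple objective,
    #   L_simple = E_{x_0, eps, t} || eps - eps_theta(z_t, t, c) ||^2,
    # and with sample prediction the target swaps to the clean latents x_0.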
    def forward(self, batch):
        """
        Args:
            batch (dict): must contain `self.first_stage_key` and `self.cond_stage_key`.

        Returns:
            diffusion_outputs (dict): the inputs and predictions consumed by `compute_loss`.
        """

        if self.first_stage_model is None:
            self.instantiate_first_stage(self.first_stage_config)

        latents = self.encode_first_stage(batch[self.first_stage_key])

        # conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
        conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1)

        # randomly drop ~10% of the conditions for classifier-free guidance training
        mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1
        conditions = conditions * mask.to(conditions)

        # Sample noise that we'll add to the latents
        # [batch_size, n_token, latent_dim]
        noise = torch.randn_like(latents)
        bs = latents.shape[0]
        # Sample a random timestep for each sample in the batch
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bs,),
            device=latents.device,
        )
        timesteps = timesteps.long()
        # Add noise to the latents according to the noise magnitude at each timestep
        noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # diffusion model forward
        noise_pred = self.model(noisy_z, timesteps, conditions)

        diffusion_outputs = {
            "x_0": latents,  # the clean latents, used as the target for "sample" prediction
            "noise": noise,
            "pred": noise_pred
        }

        return diffusion_outputs
    def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
                      batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """
        Args:
            batch (dict): the batch sample, which contains:
                - surface (torch.FloatTensor):
                - image (torch.FloatTensor): if provided, [bs, 3, h, w], items in range [0, 1]
                - depth (torch.FloatTensor): if provided, [bs, 1, h, w], items in range [-1, 1]
                - normal (torch.FloatTensor): if provided, [bs, 3, h, w], items in range [-1, 1]
                - text (list of str):
            batch_idx (int):
            optimizer_idx (int):

        Returns:
            loss (torch.FloatTensor):
        """

        diffusion_outputs = self(batch)

        loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss
    def validation_step(self, batch: Dict[str, torch.FloatTensor],
                        batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """
        Args:
            batch (dict): the batch sample, which contains:
                - surface_pc (torch.FloatTensor): [n_pts, 4]
                - surface_feats (torch.FloatTensor): [n_pts, c]
                - text (list of str):
            batch_idx (int):
            optimizer_idx (int):

        Returns:
            loss (torch.FloatTensor):
        """

        diffusion_outputs = self(batch)

        loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss
    @torch.no_grad()
    def sample(self,
               batch: Dict[str, Union[torch.FloatTensor, List[str]]],
               sample_times: int = 1,
               steps: Optional[int] = None,
               guidance_scale: Optional[float] = None,
               eta: float = 0.0,
               return_intermediates: bool = False, **kwargs):

        if self.first_stage_model is None:
            self.instantiate_first_stage(self.first_stage_config)

        if steps is None:
            steps = self.scheduler_cfg.num_inference_steps

        if guidance_scale is None:
            guidance_scale = self.scheduler_cfg.guidance_scale
        do_classifier_free_guidance = guidance_scale > 0

        # conditional encode
        xc = batch[self.cond_stage_key]
        # cond = self.cond_stage_model[self.cond_stage_key](xc)
        cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1)

        if do_classifier_free_guidance:
            """
            Note: there are two kinds of unconditional embedding for text:
            1. use "" as the unconditional text (as in SAL diffusion);
            2. use zeros_like(cond) as the unconditional text (as in MDM).
            """
            # un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
            un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond)
            # un_cond = torch.zeros_like(cond, device=cond.device)
            cond = torch.cat([un_cond, cond], dim=0)

        outputs = []
        latents = None

        if not return_intermediates:
            # decode only the final latents of each sampling run
            for _ in range(sample_times):
                sample_loop = ddim_sample(
                    self.denoise_scheduler,
                    self.model,
                    shape=self.first_stage_model.latent_shape,
                    cond=cond,
                    steps=steps,
                    guidance_scale=guidance_scale,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    device=self.device,
                    eta=eta,
                    disable_prog=not self.zero_rank
                )
                for sample, t in sample_loop:
                    latents = sample
                outputs.append(self.decode_first_stage(latents, **kwargs))
        else:
            # decode `sample_times` evenly spaced intermediates of a single run
            sample_loop = ddim_sample(
                self.denoise_scheduler,
                self.model,
                shape=self.first_stage_model.latent_shape,
                cond=cond,
                steps=steps,
                guidance_scale=guidance_scale,
                do_classifier_free_guidance=do_classifier_free_guidance,
                device=self.device,
                eta=eta,
                disable_prog=not self.zero_rank
            )

            iter_size = steps // sample_times
            i = 0
            for sample, t in sample_loop:
                latents = sample
                if i % iter_size == 0 or i == steps - 1:
                    outputs.append(self.decode_first_stage(latents, **kwargs))
                i += 1

        return outputs
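

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The config path, config keys, and
# tensor shapes below are assumptions for demonstration; they are not defined
# by this module:
#
#   from omegaconf import OmegaConf
#
#   cfg = OmegaConf.load("configs/asl_diffuser.yaml")       # hypothetical path
#   model = ASLDiffuser(**cfg.model.params).eval().cuda()   # hypothetical keys
#
#   batch = {"image": torch.randn(1, 3, 224, 224).cuda()}   # shape depends on the image encoder
#   outputs = model.sample(batch, sample_times=1, steps=50, guidance_scale=7.5)
# ----------------------------------------------------------------------------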