|
from typing import List |
|
from functools import partial |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
|
|
from .modules.diffusionmodules.util import ( |
|
make_beta_schedule, |
|
extract_into_tensor, |
|
enforce_zero_terminal_snr, |
|
noise_like, |
|
) |
|
from .util import exists, default, instantiate_from_config |
|
from .modules.distributions.distributions import DiagonalGaussianDistribution |
|
|
|
|
|
class DiffusionWrapper(nn.Module):
    """Thin ``nn.Module`` shim that forwards every call to a wrapped model.

    Holding the model as a submodule attribute makes its parameters and
    buffers visible to this wrapper (``state_dict``, ``.to()``, etc.).
    """

    def __init__(self, diffusion_model):
        super().__init__()
        self.diffusion_model = diffusion_model

    def forward(self, *args, **kwargs):
        """Delegate directly to the underlying diffusion model."""
        model = self.diffusion_model
        return model(*args, **kwargs)
|
|
|
|
|
class LatentDiffusionInterface(nn.Module):
    """a simple interface class for LDM inference

    Bundles the three sub-models of a latent diffusion pipeline — a UNet
    denoiser (wrapped in :class:`DiffusionWrapper`), a CLIP text/image
    encoder, and a VAE first stage — together with the precomputed noise
    schedule and the standard q/p arithmetic used by external samplers.
    """

    def __init__(
        self,
        unet_config,
        clip_config,
        vae_config,
        parameterization="eps",
        scale_factor=0.18215,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=0.00085,
        linear_end=0.0120,
        cosine_s=8e-3,
        given_betas=None,
        zero_snr=False,
        *args,
        **kwargs,
    ):
        """Build the three sub-models and register the diffusion schedule.

        Args:
            unet_config: config dict for `instantiate_from_config` — the UNet.
            clip_config: config dict — the CLIP conditioning model.
            vae_config: config dict — the VAE first-stage model.
            parameterization: what the UNet predicts ("eps" or "v"); stored
                only, read by callers.
            scale_factor: multiplier applied to VAE latents (0.18215 is the
                usual SD value).
            beta_schedule / timesteps / linear_start / linear_end / cosine_s:
                passed through to `register_schedule`.
            given_betas: optional precomputed beta array; overrides the
                schedule parameters when provided.
            zero_snr: if True, rescale betas to enforce zero terminal SNR.
        """
        super().__init__()

        unet = instantiate_from_config(unet_config)
        self.model = DiffusionWrapper(unet)
        self.clip_model = instantiate_from_config(clip_config)
        self.vae_model = instantiate_from_config(vae_config)

        self.parameterization = parameterization
        self.scale_factor = scale_factor
        self.register_schedule(
            given_betas=given_betas,
            beta_schedule=beta_schedule,
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
            zero_snr=zero_snr
        )

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        zero_snr=False
    ):
        """Precompute the diffusion schedule and register it as buffers.

        All derived quantities (cumulative alpha products, their square
        roots/logs/reciprocals, posterior coefficients) are stored via
        ``register_buffer`` so they follow the module across devices and
        appear in the state dict.
        """
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(
                beta_schedule,
                timesteps,
                linear_start=linear_start,
                linear_end=linear_end,
                cosine_s=cosine_s,
            )
        if zero_snr:
            print("--- using zero snr---")
            # NOTE(review): the trailing .numpy() implies
            # enforce_zero_terminal_snr consumes/returns a torch tensor;
            # confirm it also accepts a numpy `betas` from
            # make_beta_schedule.
            betas = enforce_zero_terminal_snr(betas).numpy()
        alphas = 1.0 - betas
        # alpha_bar_t = prod_{s<=t} alpha_s
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # alpha_bar_{t-1}; prepend 1.0 so the t=0 step has a noise-free prior.
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        # Derive the actual number of steps from the betas, which may have
        # come from `given_betas` rather than the `timesteps` argument.
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # Precomputed roots used by q_sample / the x0-eps-v conversions below.
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
        )

        # Interpolation weight between the analytic posterior variance and
        # beta_t; fixed at 0 here, i.e. the pure analytic posterior.
        self.v_posterior = 0
        # posterior variance of q(x_{t-1} | x_t, x_0):
        # (1 - w) * beta_t * (1 - abar_{t-1}) / (1 - abar_t) + w * beta_t
        posterior_variance = (1 - self.v_posterior) * betas * (
            1.0 - alphas_cumprod_prev
        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas

        self.register_buffer("posterior_variance", to_torch(posterior_variance))

        # Clip at 1e-20 before log: the variance is 0 at t=0.
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        # Coefficients of x_0 and x_t in the posterior mean.
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion: sample x_t ~ q(x_t | x_0).

        x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise, with fresh
        Gaussian noise drawn when `noise` is not given.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def get_v(self, x, noise, t):
        """v-parameterization target: v = sqrt(abar_t)*noise - sqrt(1-abar_t)*x."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and a predicted eps (inverse of q_sample)."""
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
            * noise
        )

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Recover x_0 from x_t and a predicted v."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def predict_eps_from_z_and_v(self, x_t, t, v):
        """Recover eps from x_t and a predicted v."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
            * x_t
        )

    def apply_model(self, x_noisy, t, cond, **kwargs):
        """Run the UNet; `cond` is a dict splatted into keyword arguments."""
        assert isinstance(cond, dict), "cond has to be a dictionary"
        return self.model(x_noisy, t, **cond, **kwargs)

    def get_learned_conditioning(self, prompts: List[str]):
        """Encode text prompts into conditioning embeddings via CLIP."""
        return self.clip_model(prompts)

    def get_learned_image_conditioning(self, images):
        """Encode images into conditioning embeddings via CLIP's image path."""
        return self.clip_model.forward_image(images)

    def get_first_stage_encoding(self, encoder_posterior):
        """Turn a VAE encoder output into a scaled latent tensor.

        Accepts either a DiagonalGaussianDistribution (sampled) or an
        already-materialized tensor; either way the result is multiplied by
        `scale_factor`.
        """
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def encode_first_stage(self, x):
        """Encode pixels with the VAE (unscaled — see get_first_stage_encoding)."""
        return self.vae_model.encode(x)

    def decode_first_stage(self, z):
        """Undo the latent scaling, then decode back to pixel space."""
        z = 1.0 / self.scale_factor * z
        return self.vae_model.decode(z)
|