import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from ldm.util import default
from ldm.modules.diffusionmodules.util import extract_into_tensor
from .ddpm import DDPM


class LatentDiffusion(DDPM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # hardcoded: never clamp the denoised latent to [-1, 1]
        self.clip_denoised = False

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion q(x_t | x_0): sqrt(a_bar_t) * x_start + sqrt(1 - a_bar_t) * noise."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
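
    # Hypothetical training-time usage of q_sample (the names `z0` / `cond` below are
    # illustrative, not part of this repo): draw a random timestep per sample and noise
    # the clean latent with the closed-form forward process above.
    #
    #   z0    = ...                       # clean VAE latents, e.g. shape (B, 4, 64, 64)
    #   t     = torch.randint(0, self.num_timesteps, (z0.shape[0],), device=z0.device)
    #   noise = torch.randn_like(z0)
    #   z_t   = self.q_sample(z0, t, noise=noise)
    #   # the denoiser is then trained to predict `noise` from (z_t, t, cond)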
"Does not support DDPM sampling anymore. Only do DDIM or PLMS"
# = = = = = = = = = = = = Below is for sampling = = = = = = = = = = = = #
# def predict_start_from_noise(self, x_t, t, noise):
# return ( extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
# extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise )
# def q_posterior(self, x_start, x_t, t):
# posterior_mean = (
# extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
# extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
# )
# posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
# posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
# return posterior_mean, posterior_variance, posterior_log_variance_clipped
# def p_mean_variance(self, model, x, c, t):
# model_out = model(x, t, c)
# x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
# if self.clip_denoised:
# x_recon.clamp_(-1., 1.)
# model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
# return model_mean, posterior_variance, posterior_log_variance, x_recon
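
    # For reference, the commented-out helpers above implement the standard DDPM
    # posterior q(x_{t-1} | x_t, x_0) from Ho et al. (2020), with a_bar_t = alphas_cumprod[t]:
    #
    #   x0_hat    = sqrt(1 / a_bar_t) * x_t - sqrt(1 / a_bar_t - 1) * eps_theta(x_t, t, c)
    #   post_mean = posterior_mean_coef1[t] * x0_hat + posterior_mean_coef2[t] * x_t
    #   post_var  = posterior_variance[t]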
# @torch.no_grad()
# def p_sample(self, model, x, c, t):
# b, *_, device = *x.shape, x.device
# model_mean, _, model_log_variance, x0 = self.p_mean_variance(model, x=x, c=c, t=t, )
# noise = torch.randn_like(x)
# # no noise when t == 0
# nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
# return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
# @torch.no_grad()
# def p_sample_loop(self, model, shape, c):
# device = self.betas.device
# b = shape[0]
# img = torch.randn(shape, device=device)
# iterator = tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps)
# for i in iterator:
# ts = torch.full((b,), i, device=device, dtype=torch.long)
# img, x0 = self.p_sample(model, img, c, ts)
# return img
# @torch.no_grad()
# def sample(self, model, shape, c, uc=None, guidance_scale=None):
# return self.p_sample_loop(model, shape, c)
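
# Per the note above, ancestral DDPM sampling has been retired; generation is expected
# to go through the separate DDIM / PLMS samplers. A rough, hypothetical call sequence
# (assuming a DDIMSampler class lives alongside this file, as in the upstream
# latent-diffusion codebase; the exact constructor and sample() arguments are not
# asserted here):
#
#   from ldm.models.diffusion.ddim import DDIMSampler
#   sampler = DDIMSampler(...)    # wraps the diffusion schedule and the denoising model
#   samples = sampler.sample(...) # deterministic DDIM updates over a shortened step schedule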