# Maikou's picture / all files first commit / 9c3a994  (appears to be file-viewer residue, not code)
# -*- coding: utf-8 -*-
import torch
from tqdm import tqdm
from typing import Tuple, List, Union, Optional
from diffusers.schedulers import DDIMScheduler
# Names exported by a star-import of this module: only ``ddim_sample`` is listed.
__all__ = ["ddim_sample"]
def ddim_sample(ddim_scheduler: DDIMScheduler,
                diffusion_model: torch.nn.Module,
                shape: Union[List[int], Tuple[int]],
                cond: torch.FloatTensor,
                steps: int,
                eta: float = 0.0,
                guidance_scale: float = 3.0,
                do_classifier_free_guidance: bool = True,
                generator: Optional[torch.Generator] = None,
                device: torch.device = "cuda:0",
                disable_prog: bool = True):
    """Run the DDIM reverse process, yielding the latents after every step.

    This is a generator function: no work is done (and the ``steps`` check
    is not evaluated) until the returned generator is first iterated.

    Args:
        ddim_scheduler: Scheduler supplying ``init_noise_sigma``,
            ``set_timesteps``, ``timesteps`` and ``step``.
        diffusion_model: Noise-prediction network, invoked as
            ``diffusion_model(latents, timesteps, cond)``.
        shape: Shape of a single latent, without the batch dimension.
        cond: Conditioning tensor handed to the model unchanged. Its first
            dimension determines the batch size, and its device and dtype
            are used for the initial noise. With classifier-free guidance
            enabled the latent batch is ``cond.shape[0] // 2`` and the first
            half of the model output is treated as the unconditional
            prediction, the second half as the conditional one.
        steps: Number of inference steps given to ``set_timesteps``;
            must be positive.
        eta: Forwarded to ``ddim_scheduler.step`` as ``eta``.
        guidance_scale: Weight applied to ``(conditional - unconditional)``
            when classifier-free guidance is enabled.
        do_classifier_free_guidance: Whether to duplicate the latents and
            combine the two halves of the prediction.
        generator: Random generator used for the initial noise and
            forwarded to ``ddim_scheduler.step``.
        device: Device the timestep tensors are placed on. The latents are
            created on ``cond.device`` instead, so the two presumably need
            to refer to the same device -- verify against the caller.
        disable_prog: When true, the tqdm progress bar is disabled.

    Yields:
        Tuples ``(latents, t)``: the ``prev_sample`` returned by the
        scheduler for the current step, and the timestep that was processed.

    Raises:
        AssertionError: If ``steps`` is not greater than zero (raised on
            first iteration of the generator).
    """
    assert steps > 0, f"{steps} must > 0."

    # Initial latents. Under classifier-free guidance ``cond`` carries two
    # entries per sample, so the latent batch is half of cond's batch.
    bsz = cond.shape[0]
    if do_classifier_free_guidance:
        bsz = bsz // 2
    latents = torch.randn(
        (bsz, *shape),
        generator=generator,
        device=cond.device,
        dtype=cond.dtype,
    )
    # Scale the initial noise by the standard deviation required by the scheduler.
    latents = latents * ddim_scheduler.init_noise_sigma

    # Set timesteps.
    ddim_scheduler.set_timesteps(steps)
    timesteps = ddim_scheduler.timesteps.to(device)

    # Extra keyword arguments for every scheduler step.
    extra_step_kwargs = {
        "eta": eta,
        "generator": generator,
    }

    # Reverse process.
    for t in tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False):
        # Expand the latents if we are doing classifier-free guidance.
        # NOTE: the scheduler's ``scale_model_input`` is not applied here.
        latent_model_input = (
            torch.cat([latents] * 2)
            if do_classifier_free_guidance
            else latents
        )

        # One timestep entry per row of the model input.
        timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
        timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])

        # Predict the noise residual. Call the module itself rather than
        # ``.forward`` so that registered module hooks are executed.
        noise_pred = diffusion_model(latent_model_input, timestep_tensor, cond)

        # Perform guidance.
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # Compute the previous noisy sample x_t -> x_t-1.
        latents = ddim_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        yield latents, t
def karra_sample():
    """Unimplemented sampler stub: takes no arguments and returns ``None``.

    NOTE(review): the name suggests a Karras-style sampler is planned here;
    confirm the intent before relying on this function.
    """
    return None