import torch
from tqdm import tqdm
from typing import Tuple, List, Union, Optional
from diffusers.schedulers import DDIMScheduler

__all__ = ["ddim_sample"]


def ddim_sample(ddim_scheduler: DDIMScheduler,
                diffusion_model: torch.nn.Module,
                shape: Union[List[int], Tuple[int, ...]],
                cond: torch.FloatTensor,
                steps: int,
                eta: float = 0.0,
                guidance_scale: float = 3.0,
                do_classifier_free_guidance: bool = True,
                generator: Optional[torch.Generator] = None,
                device: Union[torch.device, str] = "cuda:0",
                disable_prog: bool = True):
    """Run DDIM sampling, optionally with classifier-free guidance.

    When ``do_classifier_free_guidance`` is True, ``cond`` is expected to hold
    the unconditional and conditional embeddings stacked along the batch
    dimension. This function is a generator: it yields ``(latents, t)`` after
    every scheduler step, and the latents from the final yield are the
    denoised sample.
    """
    assert steps > 0, f"steps must be > 0, got {steps}."

    # With classifier-free guidance, ``cond`` stacks the unconditional and
    # conditional embeddings, so the latent batch is half the size of cond.
    bsz = cond.shape[0]
    if do_classifier_free_guidance:
        bsz = bsz // 2

    # Start from Gaussian noise scaled to the scheduler's initial sigma.
    latents = torch.randn(
        (bsz, *shape),
        generator=generator,
        device=cond.device,
        dtype=cond.dtype,
    )
    latents = latents * ddim_scheduler.init_noise_sigma

    ddim_scheduler.set_timesteps(steps)
    timesteps = ddim_scheduler.timesteps.to(device)

    # eta controls the stochasticity of each DDIM step; eta = 0.0 gives the
    # fully deterministic DDIM update.
    extra_step_kwargs = {
        "eta": eta,
        "generator": generator,
    }

    for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
        # Duplicate the latents so the unconditional and conditional branches
        # are predicted in a single forward pass.
        latent_model_input = (
            torch.cat([latents] * 2)
            if do_classifier_free_guidance
            else latents
        )

        # Broadcast the current timestep across the batch.
        timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
        timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
        noise_pred = diffusion_model(latent_model_input, timestep_tensor, cond)

        # Classifier-free guidance: push the prediction away from the
        # unconditional branch and towards the conditional one.
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # One reverse-diffusion step: x_t -> x_{t-1}.
        latents = ddim_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        yield latents, t
|
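
# Example usage (a minimal sketch): ``ToyDenoiser`` and the conditioning
# shapes below are hypothetical stand-ins for whatever epsilon-prediction
# model this module is paired with; only ``ddim_sample`` and ``DDIMScheduler``
# come from this file / diffusers.
#
#     scheduler = DDIMScheduler()
#     model = ToyDenoiser().to("cuda:0").eval()
#     # one sample => cond carries uncond + cond embeddings (batch of 2)
#     cond = torch.randn(2, 77, 768, device="cuda:0")
#     final_latents = None
#     for latents, t in ddim_sample(scheduler, model, [4, 32, 32], cond,
#                                   steps=50, guidance_scale=7.5):
#         final_latents = latents  # latents after the last step = result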


def karra_sample():
    # Placeholder for a Karras-schedule sampler; not implemented yet.
    pass