# -*- coding: utf-8 -*-
import torch
from tqdm import tqdm
from typing import Tuple, List, Union, Optional

from diffusers.schedulers import DDIMScheduler

__all__ = ["ddim_sample"]


def ddim_sample(ddim_scheduler: DDIMScheduler,
                diffusion_model: torch.nn.Module,
                shape: Union[List[int], Tuple[int]],
                cond: torch.FloatTensor,
                steps: int,
                eta: float = 0.0,
                guidance_scale: float = 3.0,
                do_classifier_free_guidance: bool = True,
                generator: Optional[torch.Generator] = None,
                device: Union[torch.device, str] = "cuda:0",
                disable_prog: bool = True):
    assert steps > 0, f"steps must be > 0, got {steps}."

    # init latents; under classifier-free guidance, `cond` stacks the
    # unconditional and conditional embeddings along the batch dimension,
    # so the effective batch size is half of cond.shape[0]
    bsz = cond.shape[0]
    if do_classifier_free_guidance:
        bsz = bsz // 2

    latents = torch.randn(
        (bsz, *shape),
        generator=generator,
        device=cond.device,
        dtype=cond.dtype,
    )
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * ddim_scheduler.init_noise_sigma

    # set timesteps
    ddim_scheduler.set_timesteps(steps)
    timesteps = ddim_scheduler.timesteps.to(device)

    # prepare extra kwargs for the scheduler step, since not all schedulers
    # have the same signature; eta (η) is only used with the DDIMScheduler
    # and must lie in [0, 1]
    extra_step_kwargs = {
        "eta": eta,
        "generator": generator,
    }
    # reverse diffusion: iterate from x_T down to x_0
    for i, t in enumerate(
        tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)
    ):
        # expand the latents if we are doing classifier-free guidance
        latent_model_input = (
            torch.cat([latents] * 2)
            if do_classifier_free_guidance
            else latents
        )
        # note: DDIMScheduler.scale_model_input is the identity, so it is
        # safe to skip here; call it if you swap in another scheduler

        # predict the noise residual
        timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
        timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
        noise_pred = diffusion_model(latent_model_input, timestep_tensor, cond)

        # perform guidance: eps = eps_uncond + s * (eps_cond - eps_uncond)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # compute the previous noisy sample x_t -> x_{t-1}
        latents = ddim_scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

        yield latents, t
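

# Illustrative usage sketch (not part of the original module): drives the
# generator above end-to-end with a dummy denoiser on CPU. `ToyDenoiser`
# and all shapes below are assumptions for demonstration only; a real model
# must accept (latents, timesteps, cond) and return a noise prediction with
# the same shape as its latent input.
#
#     class ToyDenoiser(torch.nn.Module):
#         def forward(self, x, t, cond):
#             return torch.zeros_like(x)  # dummy noise prediction
#
#     scheduler = DDIMScheduler(num_train_timesteps=1000)
#     # cond stacks [unconditional; conditional] embeddings, so a batch of
#     # 2 embeddings yields one sample after halving
#     cond = torch.randn(2, 77, 768)
#     for latents, t in ddim_sample(scheduler, ToyDenoiser(),
#                                   shape=[4, 32, 32], cond=cond,
#                                   steps=50, device="cpu"):
#         pass  # `latents` after the final step is the sample x_0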


def karras_sample():
    # placeholder for a Karras et al. (2022)-style sampler; not implemented
    pass
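

# A minimal sketch of the noise schedule such a sampler would need, under
# the assumption that `karras_sample` is meant to implement Karras et al.
# (2022), "Elucidating the Design Space of Diffusion-Based Generative
# Models". `_karras_sigmas` is a hypothetical helper, not part of the
# original module; it computes the rho-spaced sigma ramp of eq. (5).
def _karras_sigmas(n: int,
                   sigma_min: float = 0.002,
                   sigma_max: float = 80.0,
                   rho: float = 7.0,
                   device: Union[torch.device, str] = "cpu") -> torch.Tensor:
    # evenly spaced ramp in [0, 1], warped through the rho power law so the
    # schedule spends more of its steps near sigma_min
    ramp = torch.linspace(0, 1, n, device=device)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    # append sigma = 0 for the final, fully denoised sample
    return torch.cat([sigmas, torch.zeros(1, device=device)])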