from src.utils import *
from src.flow_utils import warp_tensor
import torch
import torchvision
import gc

"""
==========================================================================
* step(): one DDPM step with background smoothing
* inference(): translate one batch with FRESCO and background smoothing
==========================================================================
"""
def step(pipe, model_output, timestep, sample, generator, repeat_noise=False,
         visualize_pipeline=False, flows=None, occs=None, saliency=None):
    """
    DDPM step with background smoothing
    * background smoothing: warp the background region of the previous frame to the current frame
    """
    scheduler = pipe.scheduler

    # 1. get previous step value (=t-1)
    prev_timestep = scheduler.previous_timestep(timestep)

    # 2. compute alphas, betas
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.one
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 3. compute the predicted original sample from the predicted noise, also called
    # "predicted x_0" in formula (12) of https://arxiv.org/pdf/2010.02502.pdf
    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
""" | |
[HACK] add background smoothing | |
decode the feature | |
warp the feature of f_{i-1} | |
fuse the warped f_{i-1} with f_{i} in the non-salient region (i.e., background) | |
encode the fused feature | |
""" | |
    if saliency is not None and flows is not None and occs is not None:
        image = pipe.vae.decode(pred_original_sample / pipe.vae.config.scaling_factor).sample
        image = warp_tensor(image, flows, occs, saliency, unet_chunk_size=1)
        pred_original_sample = pipe.vae.config.scaling_factor * pipe.vae.encode(image).latent_dist.sample()
    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
    current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
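    # 6. Add noise with the posterior variance of Eq. (7):
    #   tilde_beta_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t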
    variance = beta_prod_t_prev / beta_prod_t * current_beta_t
    variance = torch.clamp(variance, min=1e-20)
    variance = (variance ** 0.5) * torch.randn(model_output.shape, generator=generator,
                                               device=model_output.device, dtype=model_output.dtype)
""" | |
[HACK] background smoothing | |
applying the same noise could be good for static background | |
""" | |
if repeat_noise: | |
variance = variance[0:1].repeat(model_output.shape[0],1,1,1) | |
if visualize_pipeline: # for debug | |
image = pipe.vae.decode(pred_original_sample / pipe.vae.config.scaling_factor).sample | |
viz = torchvision.utils.make_grid(torch.clamp(image, -1, 1), image.shape[0], 1) | |
visualize(viz.cpu(), 90) | |
pred_prev_sample = pred_prev_sample + variance | |
return (pred_prev_sample, pred_original_sample) | |
def inference(pipe, controlnet, frescoProc,
              imgs, prompt_embeds, edges, timesteps,
              cond_scale=[0.7]*20, num_inference_steps=20, num_warmup_steps=6,
              do_classifier_free_guidance=True, seed=0, guidance_scale=7.5, use_controlnet=True,
              record_latents=[], propagation_mode=False, visualize_pipeline=False,
              flows=None, occs=None, saliency=None, repeat_noise=False,
              num_intraattn_steps=1, step_interattn_end=350, bg_smoothing_steps=[16, 17]):
""" | |
video-to-video translation inference pipeline with FRESCO | |
* add controlnet and SDEdit | |
* add FRESCO-guided attention | |
* add FRESCO-guided optimization | |
* add background smoothing | |
* add support for inter-batch long video translation | |
[input of the original pipe] | |
pipe: base diffusion model | |
imgs: a batch of the input frames | |
prompt_embeds: prompts | |
num_inference_steps: number of DDPM steps | |
timesteps: generated by pipe.scheduler.set_timesteps(num_inference_steps) | |
do_classifier_free_guidance: cfg, should be always true | |
guidance_scale: cfg scale | |
seed | |
[input of SDEdit] | |
num_warmup_steps: skip the first num_warmup_steps DDPM steps | |
[input of controlnet] | |
use_controlnet: bool, whether using controlnet | |
controlnet: controlnet model | |
edges: input for controlnet (edge/stroke/depth, etc.) | |
cond_scale: controlnet scale | |
[input of FRESCO] | |
frescoProc: FRESCO attention controller | |
flows: optical flows | |
occs: occlusion mask | |
num_intraattn_steps: apply num_interattn_steps steps of spatial-guided attention | |
step_interattn_end: apply temporal-guided attention in [step_interattn_end, 1000] steps | |
[input for background smoothing] | |
saliency: saliency mask | |
repeat_noise: bool, use the same noise for all frames | |
bg_smoothing_steps: apply background smoothing in bg_smoothing_steps | |
[input for long video translation] | |
record_latents: recorded latents in the last batch | |
propagation_mode: bool, whether this is not the first batch | |
[output] | |
latents: a batch of latents of the translated frames | |
""" | |
    gc.collect()
    torch.cuda.empty_cache()

    device = pipe._execution_device
    noise_scheduler = pipe.scheduler
    generator = torch.Generator(device=device).manual_seed(seed)
    B, C, H, W = imgs.shape
    latents = pipe.prepare_latents(
        B,
        pipe.unet.config.in_channels,
        H,
        W,
        prompt_embeds.dtype,
        device,
        generator,
        latents=None,
    )
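    # prepare_latents draws i.i.d. Gaussian latents; for a Stable Diffusion pipe this
    # gives shape [B, in_channels, H // vae_scale_factor, W // vae_scale_factor]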
    if repeat_noise:
        latents = latents[0:1].repeat(B, 1, 1, 1).detach()

    if num_warmup_steps < 0:
        latents_init = latents.detach()
        num_warmup_steps = 0
    else:
        # SDEdit: use the noisy latents of the input images rather than pure Gaussian noise
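        # add_noise() follows the forward diffusion process:
        #   x_{t0} = sqrt(alpha_bar_{t0}) * x_0 + sqrt(1 - alpha_bar_{t0}) * noise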
        latent_x0 = pipe.vae.config.scaling_factor * pipe.vae.encode(imgs.to(pipe.unet.dtype)).latent_dist.sample()
        latents_init = noise_scheduler.add_noise(latent_x0, latents, timesteps[num_warmup_steps]).detach()
    # SDEdit: run num_inference_steps - num_warmup_steps steps
    with pipe.progress_bar(total=num_inference_steps-num_warmup_steps) as progress_bar:
        latents = latents_init
        for i, t in enumerate(timesteps[num_warmup_steps:]):
""" | |
[HACK] control the steps to apply spatial/temporal-guided attention | |
[HACK] record and restore latents from previous batch | |
""" | |
if i >= num_intraattn_steps: | |
frescoProc.controller.disable_intraattn() | |
if t < step_interattn_end: | |
frescoProc.controller.disable_interattn() | |
if propagation_mode: # restore latent from previous batch and record latent of the current batch | |
latents[0:2] = record_latents[i].detach().clone() | |
record_latents[i] = latents[[0,len(latents)-1]].detach().clone() | |
else: # frist batch, record_latents[0][t] = [x_1,t, x_{N,t}] | |
record_latents += [latents[[0,len(latents)-1]].detach().clone()] | |
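            # i.e., every batch after the first reuses the recorded latents of the previous
            # batch's first and last frames as its own first two frames, anchoring the
            # new batch to already-translated keyframes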
            # expand the latents if we are doing classifier-free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            if use_controlnet:
                control_model_input = latent_model_input
                controlnet_prompt_embeds = prompt_embeds
                down_block_res_samples, mid_block_res_sample = controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=edges,
                    conditioning_scale=cond_scale[i+num_warmup_steps],
                    guess_mode=False,
                    return_dict=False,
                )
            else:
                down_block_res_samples, mid_block_res_sample = None, None
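            # the controlnet residuals are added onto the unet's down-block skip
            # connections and mid-block output via the *_additional_residual(s) kwargs below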
            # predict the noise residual
            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=None,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            # compute the previous noisy sample x_t -> x_{t-1}
            """
            [HACK] background smoothing
            Note: bg_smoothing_steps should be rescaled based on num_inference_steps;
            the current [16, 17] is based on num_inference_steps=20
            """
            if i + num_warmup_steps in bg_smoothing_steps:
                latents = step(pipe, noise_pred, t, latents, generator,
                               visualize_pipeline=visualize_pipeline,
                               flows=flows, occs=occs, saliency=saliency)[0]
            else:
                latents = step(pipe, noise_pred, t, latents, generator,
                               visualize_pipeline=visualize_pipeline)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > 0 and (i + 1) % pipe.scheduler.order == 0):
                progress_bar.update()

    return latents
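
"""
[usage sketch] a hypothetical driver for a single batch, assuming `pipe`,
`controlnet`, and `frescoProc` have already been built (see src.utils) and that
`imgs`/`edges` are [B, C, H, W] tensors on the right device; the variable names
here are illustrative, not part of this file:

    num_inference_steps = 20
    pipe.scheduler.set_timesteps(num_inference_steps, device=pipe._execution_device)
    timesteps = pipe.scheduler.timesteps
    record_latents = []
    latents = inference(pipe, controlnet, frescoProc,
                        imgs, prompt_embeds, edges, timesteps,
                        num_inference_steps=num_inference_steps, num_warmup_steps=6,
                        flows=flows, occs=occs, saliency=saliency,
                        record_latents=record_latents, propagation_mode=False)
    frames = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
"""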