# StreamingSVD / diffusion_trainer / streaming_svd.py
# Committed by lev1: "Initial commit" (8fd2f2f)
from modules.loader.module_loader import GenericModuleLoader
from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams
import torch
from modules.params.diffusion.inference_params import InferenceParams
from utils import result_processor
from modules.loader.module_loader import GenericModuleLoader
from tqdm import tqdm
from PIL import Image, ImageFilter
from utils.inference_utils import resize_and_crop,get_padding_for_aspect_ratio
import numpy as np
from safetensors.torch import load_file as load_safetensors
import math
from einops import repeat, rearrange
from torchvision.transforms import ToTensor
from models.svd.sgm.modules.autoencoding.temporal_ae import VideoDecoder
import PIL
from modules.params.vfi import VFIParams
from modules.params.i2v_enhance import I2VEnhanceParams
from typing import List,Union
from models.diffusion.wrappers import StreamingWrapper
from diffusion_trainer.abstract_trainer import AbstractTrainer
from utils.loader import download_ckpt
import torchvision.transforms.functional as TF
from diffusers import AutoPipelineForInpainting, DEISMultistepScheduler
from transformers import BlipProcessor, BlipForConditionalGeneration
class StreamingSVD(AbstractTrainer):
    """Image-to-video generation with SVD-XT followed by auto-regressive StreamingSVD.

    ``image_to_video`` first makes the input image 16:9, outpainting it if
    needed (``ensure_image_ratio``). It then generates a first chunk with
    ``self.svd_pipeline`` and extends the video chunk by chunk with the
    StreamingSVD model (``_autoregressive_generation``).

    Attributes such as ``self.model``, ``self.controlnet``, ``self.svd_pipeline``,
    ``self.conditioner``, ``self.first_stage_model``, ``self.sampler``,
    ``self.denoiser`` and ``self.device`` are not set in this class. They are
    presumably provided by ``AbstractTrainer`` / ``module_loader`` (to confirm).
    """

    def __init__(self,
                 module_loader: GenericModuleLoader,
                 diff_trainer_params: DiffusionTrainerParams,
                 inference_params: InferenceParams,
                 vfi: VFIParams,
                 i2v_enhance: I2VEnhanceParams,
                 ):
        super().__init__(inference_params=inference_params,
                         diff_trainer_params=diff_trainer_params,
                         module_loader=module_loader,
                         )

        # network config is wrapped by OpenAIWrapper, so we dont need a direct reference anymore
        # this corresponds to the config yaml defined at model.module_loader.module_config.model.dependent_modules
        del self.network_config
        self.diff_trainer_params: DiffusionTrainerParams
        # vfi / i2v_enhance are only stored here; they are not read in this class.
        self.vfi = vfi
        self.i2v_enhance = i2v_enhance

    def on_inference_epoch_start(self):
        """Build ``self.inference_model``, which wraps the SVD UNet and the control model."""
        super().on_inference_epoch_start()

        # for StreamingSVD we use a model wrapper that combines the base SVD model and the control model.
        self.inference_model = StreamingWrapper(
            diffusion_model=self.model.diffusion_model,
            controlnet=self.controlnet,
            num_frame_conditioning=self.inference_params.num_conditional_frames,
        )

    def post_init(self):
        """Finish setup once all modules are loaded.

        This method:
        - configures CPU offloading for the SVD pipeline (GPU only);
        - re-uses the conditioner's OpenCLIP model as ``self.image_encoder_apm``;
        - parks the autoencoder and conditioner encoders on the CPU;
        - loads the inpainting pipeline (``self.inpaint_pipe``) and the BLIP
          captioner (``self.blip``) used by ``ensure_image_ratio``.
        """
        on_gpu = self.device.type != "cpu"

        self.svd_pipeline.set_progress_bar_config(disable=True)
        if on_gpu:
            self.svd_pipeline.enable_model_cpu_offload(gpu_id=self.device.index)

        # re-use the open clip already loaded for image conditioner for image_encoder_apm
        for embedder in self.conditioner.embedders:
            if getattr(embedder, "input_key", None) == "cond_frames_without_noise":
                self.image_encoder_apm = embedder.open_clip

        # These modules are kept on the CPU and moved to self.device only while
        # they are used (see _generate_conditional_output / decode_first_stage).
        self.first_stage_model.to("cpu")
        self.conditioner.embedders[3].encoder.to("cpu")
        self.conditioner.embedders[0].open_clip.to("cpu")

        pipe = AutoPipelineForInpainting.from_pretrained(
            'Lykon/dreamshaper-8-inpainting', torch_dtype=torch.float16, variant="fp16",
            safety_checker=None, requires_safety_checker=False)
        pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
        # Mirror the svd_pipeline handling above. Model CPU offload manages device
        # placement itself, so an explicit .to(device) is only needed without it.
        if on_gpu:
            pipe.enable_model_cpu_offload(gpu_id=self.device.index)
        else:
            pipe = pipe.to(self.device)
        self.inpaint_pipe = pipe

        caption_processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large")
        captioner = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(self.device)
        caption_device = self.device

        def blip(x):
            """Return a BLIP caption for image ``x``."""
            # The inputs must be on the captioner's device and dtype. A bare "cuda"
            # would break whenever the model lives on cuda:N with N != 0.
            inputs = caption_processor(x, return_tensors='pt').to(caption_device, torch.float16)
            return caption_processor.decode(captioner.generate(**inputs)[0], skip_special_tokens=True)

        self.blip = blip

    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
    def get_unique_embedder_keys_from_conditioner(self, conditioner):
        """Return the distinct ``input_key`` values of ``conditioner.embedders`` (unordered)."""
        return list(set([x.input_key for x in conditioner.embedders]))

    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
    def get_batch_sgm(self, keys, value_dict, N, T, device):
        """Build the conditional and unconditional batches for the conditioner.

        Parameters:
        - keys: iterable of str
            Batch keys to fill (the conditioner's embedder input keys).
        - value_dict: dict
            Source values for the keys.
        - N: list of int
            Batch dimensions. Scalar keys are repeated ``prod(N)`` times;
            frame keys are repeated ``N[0]`` times along dim 0.
        - T: int or None
            If not None, stored as ``batch["num_video_frames"]``.
        - device: torch.device
            Device for the scalar conditioning tensors.
        Returns:
        - (dict, dict)
            ``batch`` and ``batch_uc``. ``batch_uc`` holds a clone of every tensor
            entry of ``batch``.
        """
        num_samples = int(math.prod(N))
        batch = {}
        for key in keys:
            if key in ("fps_id", "motion_bucket_id"):
                batch[key] = torch.tensor([value_dict[key]]).to(device).repeat(num_samples)
            elif key == "cond_aug":
                batch[key] = repeat(
                    torch.tensor([value_dict["cond_aug"]]).to(device),
                    "1 -> b",
                    b=num_samples,
                )
            elif key in ("cond_frames", "cond_frames_without_noise"):
                batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0])
            else:
                batch[key] = value_dict[key]

        if T is not None:
            batch["num_video_frames"] = T

        batch_uc = {
            key: torch.clone(value)
            for key, value in batch.items()
            if isinstance(value, torch.Tensor)
        }
        return batch, batch_uc

    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/sgm/models/diffusion.py
    @torch.no_grad()
    def decode_first_stage(self, z):
        """Decode latents ``z`` to pixel space with the first-stage autoencoder.

        ``z`` is first divided by ``diff_trainer_params.scale_factor`` and then
        decoded in slices of at most 8 along dim 0. The autoencoder is moved to
        ``self.device`` for decoding and back to the CPU afterwards, including
        when decoding raises.
        """
        self.first_stage_model.to(self.device)
        try:
            z = 1.0 / self.diff_trainer_params.scale_factor * z
            n_samples = min(z.shape[0], 8)
            # A VideoDecoder needs the number of frames in each slice.
            is_video_decoder = isinstance(self.first_stage_model.decoder, VideoDecoder)
            all_out = []
            with torch.autocast("cuda", enabled=not self.diff_trainer_params.disable_first_stage_autocast):
                for start in range(0, z.shape[0], n_samples):
                    z_slice = z[start:start + n_samples]
                    kwargs = {"timesteps": len(z_slice)} if is_video_decoder else {}
                    all_out.append(self.first_stage_model.decode(z_slice, **kwargs))
            out = torch.cat(all_out, dim=0)
        finally:
            self.first_stage_model.to("cpu")
        return out

    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
    def _generate_conditional_output(self, svd_input_frame, inference_params: InferenceParams, **params):
        """Generate one video chunk conditioned on an image and on control frames.

        Parameters:
        - svd_input_frame: torch.Tensor
            Image used as SVD image conditioning. A batch dim is added with
            ``[None, :]``, and its last two dims are read as height and width.
        - inference_params: InferenceParams
            Not read here; ``self.inference_params`` is used, matching the
            StreamingWrapper built in on_inference_epoch_start.
        - params:
            Must contain ``ctrl_frames``. Other keys (e.g. ``anchor_frames``) are ignored.
        Returns:
        - torch.Tensor
            Decoded frames clamped to [-1, 1].
        """
        C = 4  # channels of the sampled latent
        F = 8  # spatial compression TODO read from model
        H, W = svd_input_frame.shape[-2:]
        num_frames = self.sampler.guider.num_frames
        shape = (num_frames, C, H // F, W // F)
        batch_size = 1

        image = svd_input_frame[None, :]
        cond_aug = 0.02
        value_dict = {
            "motion_bucket_id": 127,
            "fps_id": 6,
            "cond_aug": cond_aug,
            "cond_frames_without_noise": image,
            # Gaussian noise augmentation of the conditioning frame (strength
            # cond_aug), as in the referenced SGM sampling script.
            "cond_frames": image + cond_aug * torch.randn_like(image),
        }

        batch, batch_uc = self.get_batch_sgm(
            self.get_unique_embedder_keys_from_conditioner(self.conditioner),
            value_dict,
            [1, num_frames],
            T=num_frames,
            device=self.device,
        )

        # The conditioner encoders are parked on the CPU between calls (see post_init).
        self.conditioner.embedders[3].encoder.to(self.device)
        self.conditioner.embedders[0].open_clip.to(self.device)
        try:
            c, uc = self.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=[
                    "cond_frames",
                    "cond_frames_without_noise",
                ],
            )
        finally:
            self.conditioner.embedders[3].encoder.to("cpu")
            self.conditioner.embedders[0].open_clip.to("cpu")

        # Repeat the per-video conditioning once per frame and flatten it to (b*t, ...).
        for k in ["crossattn", "concat"]:
            uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
            uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
            c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
            c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

        randn = torch.randn(shape, device=self.device)

        # The batch is 2 * batch_size, presumably because the guider evaluates the
        # conditional and unconditional inputs together (to confirm in the sampler).
        additional_model_inputs = {
            "image_only_indicator": torch.zeros(2 * batch_size, num_frames).to(self.device),
            "num_video_frames": batch["num_video_frames"],
            # StreamingSVD inputs
            "batch_size": 2 * batch_size,
            "num_conditional_frames": self.inference_params.num_conditional_frames,
            "ctrl_frames": params["ctrl_frames"],
        }

        self.inference_model.diffusion_model = self.inference_model.diffusion_model.to(self.device)
        self.inference_model.controlnet = self.inference_model.controlnet.to(self.device)
        c["vector"] = c["vector"].to(randn.dtype)
        uc["vector"] = uc["vector"].to(randn.dtype)

        def denoiser(input, sigma, c):
            return self.denoiser(self.inference_model, input, sigma, c, **additional_model_inputs)

        try:
            samples_z = self.sampler(denoiser, randn, cond=c, uc=uc)
        finally:
            # Free GPU memory for the decoder even if sampling fails.
            self.inference_model.diffusion_model = self.inference_model.diffusion_model.to("cpu")
            self.inference_model.controlnet = self.inference_model.controlnet.to("cpu")

        samples_x = self.decode_first_stage(samples_z)
        return torch.clamp(samples_x, min=-1.0, max=1.0)

    @staticmethod
    def _to_batched_video(video: torch.Tensor, input_range: List[int]) -> torch.Tensor:
        """Convert ``video`` to the [-1, 1] range and the (1, F, C, H, W) layout.

        Accepts channels-first (F, 3, H, W) or channels-last (F, H, W, 3) input
        with more than 3 frames. Any other shape raises NotImplementedError.
        """
        video = result_processor.convert_range(video=video.clone(), input_range=input_range, output_range=[-1, 1])
        if video.shape[1] == 3 and video.shape[0] > 3:
            return rearrange(video, "F C H W -> 1 F C H W")
        if video.shape[0] > 3 and video.shape[-1] == 3:
            return rearrange(video, "F H W C -> 1 F C H W")
        raise NotImplementedError(f"Unexpected video input format: {video.shape}")

    def extract_anchor_frames(self, video, input_range, inference_params: InferenceParams):
        """
        Extracts anchor frames from the input video based on the provided inference parameters.

        Parameters:
        - video: torch.Tensor
            The input video tensor, (F, C, H, W) or (F, H, W, C).
        - input_range: list
            The pixel value range of input video.
        - inference_params: InferenceParams
            An object containing inference parameters.
            - anchor_frames: str
                Specifies how the anchor frames are encoded. It is either a single number
                giving the index of the anchor frame, or a range "a:b" selecting frames
                from index a up to, but not including, index b (Python slice semantics).

        Returns:
        - torch.Tensor
            The extracted anchor frames in range [-1, 1]: (1, b-a, C, H, W) for a range,
            (1, 1, C, H, W) for a single index.

        Raises:
        - ValueError
            If a range encoding does not consist of exactly two integers.
        """
        video = self._to_batched_video(video, input_range)
        spec = inference_params.anchor_frames

        if ":" in spec:
            bounds = [int(bound) for bound in spec.split(":")]
            if len(bounds) != 2:
                raise ValueError("Anchor frames encoding wrong.")
            start, end = bounds
            return video[:, start:end]

        anchor_frame = int(spec)
        return video[:, anchor_frame].unsqueeze(0)

    def extract_ctrl_frames(self, video: torch.Tensor, input_range: List[int], inference_params: InferenceParams):
        """
        Extracts control frames from the input video.

        Parameters:
        - video: torch.Tensor
            The input video tensor, (F, C, H, W) or (F, H, W, C).
        - input_range: list
            The pixel value range of input video.
        - inference_params: InferenceParams
            An object containing inference parameters (reads num_conditional_frames).

        Returns:
        - torch.Tensor
            The last ``num_conditional_frames`` frames as (1, n, C, H, W), in range [-1, 1].
        """
        video = self._to_batched_video(video, input_range)
        # Use an explicit start index: the slice [-0:] would return every frame
        # instead of none when num_conditional_frames == 0.
        start = max(video.shape[1] - inference_params.num_conditional_frames, 0)
        return video[:, start:]

    def _autoregressive_generation(self, initial_generation: Union[torch.Tensor, List[torch.Tensor]], inference_params: InferenceParams):
        """
        Perform autoregressive generation of video chunks based on the initial generation and inference parameters.

        Parameters:
        - initial_generation: torch.Tensor or list of torch.Tensor
            The initial video chunk or chunks, float in [-1, 1]. Each chunk is
            (F, C, H, W) or channels-last (F, H, W, C). A list passed in is not modified.
        - inference_params: InferenceParams
            An object containing inference parameters.

        Returns:
        - torch.Tensor
            The generated video, produced by result_processor.concat_chunks from
            all chunks converted to range [0, 255].
        """
        # input is [-1,1] float
        if isinstance(initial_generation, list):
            # Copy so that appending new chunks does not mutate the caller's list.
            result_chunks = list(initial_generation)
        else:
            result_chunks = [initial_generation]

        # Make sure every chunk is channels-first (F, C, H, W).
        result_chunks = [
            rearrange(chunk, "F H W C -> F C H W")
            if chunk.shape[1] > 3 and chunk.shape[-1] == 3 else chunk
            for chunk in result_chunks
        ]

        # generating chunk by conditioning on the previous chunks
        for _ in tqdm(range(inference_params.n_autoregressive_generations), desc="StreamingSVD"):
            # Extract anchor frames from the entire video generated so far.
            # NOTE: StreamingSVD does not use them (apart from the anchor frame already used by SVD);
            # _generate_conditional_output ignores this keyword.
            anchor_frames = self.extract_anchor_frames(
                video=torch.cat(result_chunks),
                inference_params=inference_params,
                input_range=[-1, 1],
            )

            # extract control frames based on the last generated chunk
            ctrl_frames = self.extract_ctrl_frames(
                video=result_chunks[-1],
                input_range=[-1, 1],
                inference_params=inference_params,
            )

            # select the anchor frame for svd
            svd_input_frame = result_chunks[0][int(inference_params.anchor_frames)]

            # generate the next chunk
            # result is [F, C, H, W], range is [-1,1] float.
            result = self._generate_conditional_output(
                svd_input_frame=svd_input_frame,
                inference_params=inference_params,
                anchor_frames=anchor_frames,
                ctrl_frames=ctrl_frames,
            )

            # from each generation, we keep all frames except for the first <num_conditional_frames> frames
            result_chunks.append(result[inference_params.num_conditional_frames:])
            torch.cuda.empty_cache()

        # concat all chunks to one long video
        result_chunks = [
            result_processor.convert_range(chunk, output_range=[0, 255], input_range=[-1, 1])
            for chunk in result_chunks
        ]
        result = result_processor.concat_chunks(result_chunks)
        torch.cuda.empty_cache()
        return result

    def ensure_image_ratio(self, source_image: Image.Image, target_aspect_ratio=16/9):
        """Outpaint ``source_image`` if its width/height differs from ``target_aspect_ratio``.

        If the ratio already matches exactly, returns ``(source_image, None)``.
        Otherwise this method:
        1. pads the image (reflect padding) by ``get_padding_for_aspect_ratio``;
        2. outpaints the padded area at 512x512 with ``self.inpaint_pipe``, using a
           BLIP caption of the source image as the prompt;
        3. resizes the result back and blends it over the padded input with a
           Gaussian-blurred mask.

        Returns:
        - (PIL.Image.Image, PIL.Image.Image) or (PIL.Image.Image, None)
            The final image and the outpainting mask. The mask is white (255) over
            the padded border and any originally transparent pixels.
        """
        if source_image.width / source_image.height == target_aspect_ratio:
            return source_image, None

        image = source_image.copy().convert("RGBA")
        mask = image.split()[-1]  # alpha channel: 0 outside the original opaque content
        image = image.convert("RGB")

        # NOTE(review): target_aspect_ratio is not passed to the padding helper; confirm
        # that its built-in ratio matches whenever a non-default target is used.
        padding = get_padding_for_aspect_ratio(image)

        # mask
        mask_padded = TF.pad(mask, padding)
        mask_padded_size = mask_padded.size
        mask_padded_resized = TF.resize(mask_padded, (512, 512),
                                        interpolation=TF.InterpolationMode.NEAREST)
        mask_padded_resized = TF.invert(mask_padded_resized)

        # image
        padded_input_image = TF.pad(image, padding, padding_mode="reflect")
        resized_image = TF.resize(padded_input_image, (512, 512))

        # Start the inpainting from the padded image's latent noised at the first
        # scheduler timestep, rather than from pure noise.
        image_tensor = self.inpaint_pipe.image_processor.preprocess(
            resized_image).to(self.device, torch.float16)
        latent_tensor = self.inpaint_pipe._encode_vae_image(image_tensor, None)
        self.inpaint_pipe.scheduler.set_timesteps(999)
        noisy_latent_tensor = self.inpaint_pipe.scheduler.add_noise(
            latent_tensor,
            torch.randn_like(latent_tensor),
            self.inpaint_pipe.scheduler.timesteps[:1],
        )

        prompt = self.blip(source_image)
        if prompt.startswith("there is "):
            prompt = prompt[len("there is "):]

        output_image_normalized_size = self.inpaint_pipe(
            prompt=prompt,
            image=resized_image,
            mask_image=mask_padded_resized,
            latents=noisy_latent_tensor,
        ).images[0]
        # PIL size is (W, H); TF.resize expects (H, W).
        output_image_extended_size = TF.resize(
            output_image_normalized_size, mask_padded_size[::-1])

        blured_outpainting_mask = TF.invert(mask_padded).filter(
            ImageFilter.GaussianBlur(radius=5))

        final_image = Image.composite(
            output_image_extended_size, padded_input_image, blured_outpainting_mask)
        return final_image, TF.invert(mask_padded)

    def image_to_video(self, batch, inference_params: InferenceParams, batch_idx):
        """
        Performs image to video based on the input batch and inference parameters.
        It runs SVD-XT once to generate the first chunk, then auto-regressively applies StreamingSVD.

        Parameters:
        - batch: dict
            The input batch; ``batch["image"][0]`` is the start image as an array-like tensor.
        - inference_params: InferenceParams
            An object containing inference parameters.
        - batch_idx: int
            The index of the batch.

        Returns:
        - tuple
            ``(video, scaled_outpainted_image, expanded_size)``: the generated video,
            the 1024x576 image used as SVD input, and the ``expanded_size`` value
            returned by ``resize_and_crop``.
        """
        batch_key = "image"
        input_image = PIL.Image.fromarray(batch[batch_key][0].cpu().numpy())

        # TODO remove conversion forth and back
        outpainted_image, _ = self.ensure_image_ratio(input_image)

        scaled_outpainted_image, expanded_size = resize_and_crop(outpainted_image)
        assert scaled_outpainted_image.width == 1024 and scaled_outpainted_image.height == 576, f"Wrong shape for file {batch[batch_key]} with shape {scaled_outpainted_image.width}:{scaled_outpainted_image.height}."

        # Generating first chunk
        with torch.autocast(device_type="cuda", enabled=False):
            frames = self.svd_pipeline(
                scaled_outpainted_image, decode_chunk_size=8).frames[0]

        to_tensor = ToTensor()
        video_chunks = torch.stack([to_tensor(frame) for frame in frames])
        video_chunks = video_chunks * 2.0 - 1  # [-1,1], float
        video_chunks = video_chunks.to(self.device)

        video = self._autoregressive_generation(
            initial_generation=video_chunks,
            inference_params=inference_params)

        return video, scaled_outpainted_image, expanded_size

    def generate_output(self, batch, batch_idx, inference_params: InferenceParams):
        """
        Generate output video based on the input batch and inference parameters.

        Parameters:
        - batch: dict
            The input batch containing data for generating the output video
            (including a one-element ``sample_id`` tensor).
        - batch_idx: int
            The index of the batch (unused; ``batch["sample_id"]`` is passed on instead).
        - inference_params: InferenceParams
            An object containing inference parameters.

        Returns:
        - torch.Tensor
            The generated video. Note the result is also accessible, as a numpy array,
            via self.trainer.generated_video.
        """
        sample_id = batch["sample_id"].item()
        video, scaled_outpainted_image, expanded_size = self.image_to_video(
            batch, inference_params=inference_params, batch_idx=sample_id)

        # .cpu() is a no-op for CPU tensors and required before .numpy() for CUDA ones.
        self.trainer.generated_video = video.cpu().numpy()
        self.trainer.expanded_size = expanded_size
        self.trainer.scaled_outpainted_image = scaled_outpainted_image
        return video