from modules.loader.module_loader import GenericModuleLoader from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams import torch from modules.params.diffusion.inference_params import InferenceParams from utils import result_processor from modules.loader.module_loader import GenericModuleLoader from tqdm import tqdm from PIL import Image, ImageFilter from utils.inference_utils import resize_and_crop,get_padding_for_aspect_ratio import numpy as np from safetensors.torch import load_file as load_safetensors import math from einops import repeat, rearrange from torchvision.transforms import ToTensor from models.svd.sgm.modules.autoencoding.temporal_ae import VideoDecoder import PIL from modules.params.vfi import VFIParams from modules.params.i2v_enhance import I2VEnhanceParams from typing import List,Union from models.diffusion.wrappers import StreamingWrapper from diffusion_trainer.abstract_trainer import AbstractTrainer from utils.loader import download_ckpt import torchvision.transforms.functional as TF from diffusers import AutoPipelineForInpainting, DEISMultistepScheduler from transformers import BlipProcessor, BlipForConditionalGeneration class StreamingSVD(AbstractTrainer): def __init__(self, module_loader: GenericModuleLoader, diff_trainer_params: DiffusionTrainerParams, inference_params: InferenceParams, vfi: VFIParams, i2v_enhance: I2VEnhanceParams, ): super().__init__(inference_params=inference_params, diff_trainer_params=diff_trainer_params, module_loader=module_loader, ) # network config is wrapped by OpenAIWrapper, so we dont need a direct reference anymore # this corresponds to the config yaml defined at model.module_loader.module_config.model.dependent_modules del self.network_config self.diff_trainer_params: DiffusionTrainerParams self.vfi = vfi self.i2v_enhance = i2v_enhance def on_inference_epoch_start(self): super().on_inference_epoch_start() # for StreamingSVD we use a model wrapper that combines the base SVD model and the control model. self.inference_model = StreamingWrapper( diffusion_model=self.model.diffusion_model, controlnet=self.controlnet, num_frame_conditioning=self.inference_params.num_conditional_frames ) def post_init(self): self.svd_pipeline.set_progress_bar_config(disable=True) if self.device.type != "cpu": self.svd_pipeline.enable_model_cpu_offload(gpu_id = self.device.index) # re-use the open clip already loaded for image conditioner for image_encoder_apm embedders = self.conditioner.embedders for embedder in embedders: if hasattr(embedder,"input_key") and embedder.input_key == "cond_frames_without_noise": self.image_encoder_apm = embedder.open_clip self.first_stage_model.to("cpu") self.conditioner.embedders[3].encoder.to("cpu") self.conditioner.embedders[0].open_clip.to("cpu") pipe = AutoPipelineForInpainting.from_pretrained( 'Lykon/dreamshaper-8-inpainting', torch_dtype=torch.float16, variant="fp16", safety_checker=None, requires_safety_checker=False) pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config) pipe = pipe.to(self.device) pipe.enable_model_cpu_offload(gpu_id = self.device.index) self.inpaint_pipe = pipe processor = BlipProcessor.from_pretrained( "Salesforce/blip-image-captioning-large") model = BlipForConditionalGeneration.from_pretrained( "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(self.device) def blip(x): return processor.decode(model.generate(** processor(x, return_tensors='pt').to("cuda", torch.float16))[0], skip_special_tokens=True) self.blip = blip # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py def get_unique_embedder_keys_from_conditioner(self, conditioner): return list(set([x.input_key for x in conditioner.embedders])) # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py def get_batch_sgm(self, keys, value_dict, N, T, device): batch = {} batch_uc = {} for key in keys: if key == "fps_id": batch[key] = ( torch.tensor([value_dict["fps_id"]]) .to(device) .repeat(int(math.prod(N))) ) elif key == "motion_bucket_id": batch[key] = ( torch.tensor([value_dict["motion_bucket_id"]]) .to(device) .repeat(int(math.prod(N))) ) elif key == "cond_aug": batch[key] = repeat( torch.tensor([value_dict["cond_aug"]]).to(device), "1 -> b", b=math.prod(N), ) elif key == "cond_frames": batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0]) elif key == "cond_frames_without_noise": batch[key] = repeat( value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0] ) else: batch[key] = value_dict[key] if T is not None: batch["num_video_frames"] = T for key in batch.keys(): if key not in batch_uc and isinstance(batch[key], torch.Tensor): batch_uc[key] = torch.clone(batch[key]) return batch, batch_uc # Adapted from https://github.com/Stability-AI/generative-models/blob/main/sgm/models/diffusion.py @torch.no_grad() def decode_first_stage(self, z): self.first_stage_model.to(self.device) z = 1.0 / self.diff_trainer_params.scale_factor * z #n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) n_samples = min(z.shape[0],8) #print("SVD decoder started") import time start = time.time() n_rounds = math.ceil(z.shape[0] / n_samples) all_out = [] with torch.autocast("cuda", enabled=not self.diff_trainer_params.disable_first_stage_autocast): for n in range(n_rounds): if isinstance(self.first_stage_model.decoder, VideoDecoder): kwargs = {"timesteps": len( z[n * n_samples: (n + 1) * n_samples])} else: kwargs = {} out = self.first_stage_model.decode( z[n * n_samples: (n + 1) * n_samples], **kwargs ) all_out.append(out) out = torch.cat(all_out, dim=0) # print(f"SVD decoder finished after {time.time()-start} seconds.") self.first_stage_model.to("cpu") return out # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py def _generate_conditional_output(self, svd_input_frame, inference_params: InferenceParams, **params): C = 4 F = 8 # spatial compression TODO read from model H = svd_input_frame.shape[-2] W = svd_input_frame.shape[-1] num_frames = self.sampler.guider.num_frames shape = (num_frames, C, H // F, W // F) batch_size = 1 image = svd_input_frame[None,:] cond_aug = 0.02 value_dict = {} value_dict["motion_bucket_id"] = 127 value_dict["fps_id"] = 6 value_dict["cond_aug"] = cond_aug value_dict["cond_frames_without_noise"] = image value_dict["cond_frames"] =image + cond_aug * torch.rand_like(image) batch, batch_uc = self.get_batch_sgm( self.get_unique_embedder_keys_from_conditioner( self.conditioner), value_dict, [1, num_frames], T=num_frames, device=self.device, ) self.conditioner.embedders[3].encoder.to(self.device) self.conditioner.embedders[0].open_clip.to(self.device) c, uc = self.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=[ "cond_frames", "cond_frames_without_noise", ], ) self.conditioner.embedders[3].encoder.to("cpu") self.conditioner.embedders[0].open_clip.to("cpu") for k in ["crossattn", "concat"]: uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames) randn = torch.randn(shape, device=self.device) additional_model_inputs = {} additional_model_inputs["image_only_indicator"] = torch.zeros(2*batch_size,num_frames).to(self.device) additional_model_inputs["num_video_frames"] = batch["num_video_frames"] # StreamingSVD inputs additional_model_inputs["batch_size"] = 2*batch_size additional_model_inputs["num_conditional_frames"] = self.inference_params.num_conditional_frames additional_model_inputs["ctrl_frames"] = params["ctrl_frames"] self.inference_model.diffusion_model = self.inference_model.diffusion_model.to( self.device) self.inference_model.controlnet = self.inference_model.controlnet.to( self.device) c["vector"] = c["vector"].to(randn.dtype) uc["vector"] = uc["vector"].to(randn.dtype) def denoiser(input, sigma, c): return self.denoiser(self.inference_model,input,sigma,c, **additional_model_inputs) samples_z = self.sampler(denoiser,randn,cond=c,uc=uc) self.inference_model.diffusion_model = self.inference_model.diffusion_model.to( "cpu") self.inference_model.controlnet = self.inference_model.controlnet.to("cpu") samples_x = self.decode_first_stage(samples_z) samples = torch.clamp(samples_x,min=-1.0,max=1.0) return samples def extract_anchor_frames(self, video, input_range,inference_params: InferenceParams): """ Extracts anchor frames from the input video based on the provided inference parameters. Parameters: - video: torch.Tensor The input video tensor. - input_range: list The pixel value range of input video. - inference_params: InferenceParams An object containing inference parameters. - anchor_frames: str Specifies how the anchor frames are encoded. It can be either a single number specifying which frame is used as the anchor frame, or a range in the format "a:b" indicating that frames from index a up to index b (inclusive) are used as anchor frames. Returns: - torch.Tensor The extracted anchor frames from the input video. """ video = result_processor.convert_range(video=video.clone(),input_range=input_range,output_range=[-1,1]) if video.shape[1] == 3 and video.shape[0]>3: video = rearrange(video,"F C W H -> 1 F C W H") elif video.shape[0]>3 and video.shape[-1] == 3: video = rearrange(video,"F W H C -> 1 F C W H") else: raise NotImplementedError(f"Unexpected video input format: {video.shape}") if ":" in inference_params.anchor_frames: anchor_frames = inference_params.anchor_frames.split(":") anchor_frames = [int(anchor_frame) for anchor_frame in anchor_frames] assert len(anchor_frames) == 2,"Anchor frames encoding wrong." anchor = video[:,anchor_frames[0]:anchor_frames[1]] else: anchor_frame = int(inference_params.anchor_frames) anchor = video[:, anchor_frame].unsqueeze(0) return anchor def extract_ctrl_frames(self,video: torch.FloatType, input_range: List[int], inference_params: InferenceParams): """ Extracts control frames from the input video. Parameters: - video: torch.Tensor The input video tensor. - input_range: list The pixel value range of input video. - inference_params: InferenceParams An object containing inference parameters. Returns: - torch.Tensor The extracted control image encoding frames from the input video. """ video = result_processor.convert_range(video=video.clone(), input_range=input_range, output_range=[-1, 1]) if video.shape[1] == 3 and video.shape[0] > 3: video = rearrange(video, "F C W H -> 1 F C W H") elif video.shape[0] > 3 and video.shape[-1] == 3: video = rearrange(video, "F W H C -> 1 F C W H") else: raise NotImplementedError( f"Unexpected video input format: {video.shape}") # return the last num_conditional_frames frames video = video[:, -inference_params.num_conditional_frames:] return video def _autoregressive_generation(self,initial_generation: Union[torch.FloatType,List[torch.FloatType]], inference_params:InferenceParams): """ Perform autoregressive generation of video chunks based on the initial generation and inference parameters. Parameters: - initial_generation: torch.Tensor or list of torch.Tensor The initial generation or list of initial generation video chunks. - inference_params: InferenceParams An object containing inference parameters. Returns: - torch.Tensor The generated video resulting from autoregressive generation. """ # input is [-1,1] float result_chunks = initial_generation if not isinstance(result_chunks,list): result_chunks = [result_chunks] # make sure if (result_chunks[0].shape[1] >3) and (result_chunks[0].shape[-1] == 3): result_chunks = [rearrange(result_chunks[0],"F W H C -> F C W H")] # generating chunk by conditioning on the previous chunks for _ in tqdm(list(range(inference_params.n_autoregressive_generations)),desc="StreamingSVD"): # extract anchor frames based on the entire, so far generated, video # note that we do note use anchor frame in StreamingSVD (apart from the anchor frame already used by SVD). anchor_frames = self.extract_anchor_frames( video = torch.cat(result_chunks), inference_params=inference_params, input_range=[-1, 1], ) # extract control frames based on the last generated chunk ctrl_frames = self.extract_ctrl_frames( video = result_chunks[-1], input_range=[-1, 1], inference_params=inference_params, ) # select the anchor frame for svd svd_input_frame = result_chunks[0][int(inference_params.anchor_frames)] # generate the next chunk # result is [F, C, H, W], range is [-1,1] float. result = self._generate_conditional_output( svd_input_frame = svd_input_frame, inference_params=inference_params, anchor_frames=anchor_frames, ctrl_frames=ctrl_frames, ) # from each generation, we keep all frames except for the first frames result = result[inference_params.num_conditional_frames:] result_chunks.append(result) torch.cuda.empty_cache() # concat all chunks to one long video result_chunks = [result_processor.convert_range(chunk,output_range=[0,255],input_range=[-1,1]) for chunk in result_chunks] result = result_processor.concat_chunks(result_chunks) torch.cuda.empty_cache() return result def ensure_image_ratio(self,source_image: PIL,target_aspect_ratio = 16/9): if source_image.width / source_image.height == target_aspect_ratio: return source_image, None image = source_image.copy().convert("RGBA") mask = image.split()[-1] image = image.convert("RGB") padding = get_padding_for_aspect_ratio(image) mask_padded = TF.pad(mask, padding) mask_padded_size = mask_padded.size mask_padded_resized = TF.resize(mask_padded, (512, 512), interpolation=TF.InterpolationMode.NEAREST) mask_padded_resized = TF.invert(mask_padded_resized) # image padded_input_image = TF.pad(image, padding, padding_mode="reflect") resized_image = TF.resize(padded_input_image, (512, 512)) image_tensor = (self.inpaint_pipe.image_processor.preprocess( resized_image).cuda().half()) latent_tensor = self.inpaint_pipe._encode_vae_image(image_tensor, None) self.inpaint_pipe.scheduler.set_timesteps(999) noisy_latent_tensor = self.inpaint_pipe.scheduler.add_noise( latent_tensor, torch.randn_like(latent_tensor), self.inpaint_pipe.scheduler.timesteps[:1], ) prompt = self.blip(source_image) if prompt.startswith("there is "): prompt = prompt[len("there is "):] output_image_normalized_size = self.inpaint_pipe( prompt=prompt, image=resized_image, mask_image=mask_padded_resized, latents=noisy_latent_tensor, ).images[0] output_image_extended_size = TF.resize( output_image_normalized_size, mask_padded_size[::-1]) blured_outpainting_mask = TF.invert(mask_padded).filter( ImageFilter.GaussianBlur(radius=5)) final_image = Image.composite( output_image_extended_size, padded_input_image, blured_outpainting_mask) return final_image, TF.invert(mask_padded) def image_to_video(self, batch, inference_params: InferenceParams, batch_idx): """ Performs image to video based on the input batch and inference parameters. It runs SVD-XT one to generate the first chunk, then auto-regressively applies StreamingSVD. Parameters: - batch: dict The input batch containing the start image for generating the video. - inference_params: InferenceParams An object containing inference parameters. - batch_idx: int The index of the batch. Returns: - torch.Tensor The generated video based on the image image. """ batch_key = "image" assert batch_key == "image", f"Generating video from {batch_key} not implemented." input_image = PIL.Image.fromarray(batch[batch_key][0].cpu().numpy()) # TODO remove conversion forth and back outpainted_image, _ = self.ensure_image_ratio(input_image) #image = Image.fromarray(np.uint8(image)) ''' if image.width/image.height != 16/9: print(f"Warning! For best results, we assume the aspect ratio of the input image to be 16:9. Found ratio {image.width}:{image.height}.") ''' scaled_outpainted_image, expanded_size = resize_and_crop(outpainted_image) assert scaled_outpainted_image.width == 1024 and scaled_outpainted_image.height == 576, f"Wrong shape for file {batch[batch_key]} with shape {scaled_outpainted_image.width}:{scaled_outpainted_image.height}." # Generating first chunk with torch.autocast(device_type="cuda",enabled=False): video_chunks = self.svd_pipeline( scaled_outpainted_image, decode_chunk_size=8).frames[0] video_chunks = torch.stack([ToTensor()(frame) for frame in video_chunks]) video_chunks = video_chunks * 2.0 - 1 # [-1,1], float video_chunks = video_chunks.to(self.device) video = self._autoregressive_generation( initial_generation=video_chunks, inference_params=inference_params) return video, scaled_outpainted_image, expanded_size def generate_output(self, batch, batch_idx,inference_params: InferenceParams): """ Generate output video based on the input batch and inference parameters. Parameters: - batch: dict The input batch containing data for generating the output video. - batch_idx: int The index of the batch. - inference_params: InferenceParams An object containing inference parameters. Returns: - torch.Tensor The generated video. Note the result is also accessible via self.trainer.generated_video """ sample_id = batch["sample_id"].item() video, scaled_outpainted_image, expanded_size = self.image_to_video( batch, inference_params=inference_params, batch_idx=sample_id) self.trainer.generated_video = video.numpy() self.trainer.expanded_size = expanded_size self.trainer.scaled_outpainted_image = scaled_outpainted_image return video