import os
import torch
import einops

from diffusers import DiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from huggingface_hub import snapshot_download

from diffusers_vdm.vae import VideoAutoencoderKL
from diffusers_vdm.projection import Resampler
from diffusers_vdm.unet import UNet3DModel
from diffusers_vdm.improved_clip_vision import ImprovedCLIPVisionModelWithProjection
from diffusers_vdm.dynamic_tsnr_sampler import SamplerDynamicTSNR


class LatentVideoDiffusionPipeline(DiffusionPipeline):
    def __init__(self, tokenizer, text_encoder, image_encoder, vae, image_projection, unet, fp16=True, eval=True):
        super().__init__()

        self.loading_components = dict(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            image_encoder=image_encoder,
            image_projection=image_projection
        )

        for k, v in self.loading_components.items():
            setattr(self, k, v)

        if fp16:
            self.vae.half()
            self.text_encoder.half()
            self.unet.half()
            self.image_encoder.half()
            self.image_projection.half()

        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.image_encoder.requires_grad_(False)

        self.vae.eval()
        self.text_encoder.eval()
        self.image_encoder.eval()

        if eval:
            self.unet.eval()
            self.image_projection.eval()
        else:
            self.unet.train()
            self.image_projection.train()

    def to(self, *args, **kwargs):
        for k, v in self.loading_components.items():
            if hasattr(v, 'to'):
                v.to(*args, **kwargs)
        return self

    def save_pretrained(self, save_directory, **kwargs):
        for k, v in self.loading_components.items():
            folder = os.path.join(save_directory, k)
            os.makedirs(folder, exist_ok=True)
            v.save_pretrained(folder)
        return

    @classmethod
    def from_pretrained(cls, repo_id, fp16=True, eval=True, token=None):
        local_folder = snapshot_download(repo_id=repo_id, token=token)
        return cls(
            tokenizer=CLIPTokenizer.from_pretrained(os.path.join(local_folder, "tokenizer")),
            text_encoder=CLIPTextModel.from_pretrained(os.path.join(local_folder, "text_encoder")),
            image_encoder=ImprovedCLIPVisionModelWithProjection.from_pretrained(os.path.join(local_folder, "image_encoder")),
            vae=VideoAutoencoderKL.from_pretrained(os.path.join(local_folder, "vae")),
            image_projection=Resampler.from_pretrained(os.path.join(local_folder, "image_projection")),
            unet=UNet3DModel.from_pretrained(os.path.join(local_folder, "unet")),
            fp16=fp16,
            eval=eval
        )

    @torch.inference_mode()
    def encode_cropped_prompt_77tokens(self, prompt: str):
        cond_ids = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        ).input_ids.to(self.text_encoder.device)
        cond = self.text_encoder(cond_ids, attention_mask=None).last_hidden_state
        return cond

    @torch.inference_mode()
    def encode_clip_vision(self, frames):
        b, c, t, h, w = frames.shape
        frames = einops.rearrange(frames, 'b c t h w -> (b t) c h w')
        clipvision_embed = self.image_encoder(frames).last_hidden_state
        clipvision_embed = einops.rearrange(clipvision_embed, '(b t) d c -> b t d c', t=t)
        return clipvision_embed

    @torch.inference_mode()
    def encode_latents(self, videos, return_hidden_states=True):
        b, c, t, h, w = videos.shape
        x = einops.rearrange(videos, 'b c t h w -> (b t) c h w')
        encoder_posterior, hidden_states = self.vae.encode(x, return_hidden_states=return_hidden_states)
        z = encoder_posterior.mode() * self.vae.scale_factor
        z = einops.rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)

        if not return_hidden_states:
            return z

        hidden_states = [einops.rearrange(h, '(b t) c h w -> b c t h w', b=b) for h in hidden_states]
        hidden_states = [h[:, :, [0, -1], :, :] for h in hidden_states]  # only need first and last

        return z, hidden_states

    @torch.inference_mode()
    def decode_latents(self, latents, hidden_states):
        B, C, T, H, W = latents.shape
        latents = einops.rearrange(latents, 'b c t h w -> (b t) c h w')
        latents = latents.to(device=self.vae.device, dtype=self.vae.dtype) / self.vae.scale_factor
        pixels = self.vae.decode(latents, ref_context=hidden_states, timesteps=T)
        pixels = einops.rearrange(pixels, '(b t) c h w -> b c t h w', b=B, t=T)
        return pixels

    @torch.inference_mode()
    def __call__(
            self,
            batch_size: int = 1,
            steps: int = 50,
            guidance_scale: float = 5.0,
            positive_text_cond=None,
            negative_text_cond=None,
            positive_image_cond=None,
            negative_image_cond=None,
            concat_cond=None,
            fs=3,
            progress_tqdm=None,
    ):
        unet_is_training = self.unet.training

        if unet_is_training:
            self.unet.eval()

        device = self.unet.device
        dtype = self.unet.dtype

        dynamic_tsnr_model = SamplerDynamicTSNR(self.unet)

        # Batch
        concat_cond = concat_cond.repeat(batch_size, 1, 1, 1, 1).to(device=device, dtype=dtype)  # b, c, t, h, w
        positive_text_cond = positive_text_cond.repeat(batch_size, 1, 1).to(concat_cond)  # b, f, c
        negative_text_cond = negative_text_cond.repeat(batch_size, 1, 1).to(concat_cond)  # b, f, c
        positive_image_cond = positive_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond)  # b, t, l, c
        negative_image_cond = negative_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond)

        if isinstance(fs, torch.Tensor):
            fs = fs.repeat(batch_size).to(dtype=torch.long, device=device)  # b
        else:
            fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=device)  # b

        # Initial latents
        latent_shape = concat_cond.shape

        # Feeds
        sampler_kwargs = dict(
            cfg_scale=guidance_scale,
            positive=dict(
                context_text=positive_text_cond,
                context_img=positive_image_cond,
                fs=fs,
                concat_cond=concat_cond
            ),
            negative=dict(
                context_text=negative_text_cond,
                context_img=negative_image_cond,
                fs=fs,
                concat_cond=concat_cond
            )
        )

        # Sample
        results = dynamic_tsnr_model(latent_shape, steps, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm)

        if unet_is_training:
            self.unet.train()

        return results
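

# --- Usage sketch (not part of the pipeline class) ---------------------------
# A minimal, hedged example of how this pipeline is typically driven end to
# end: load components, encode text/image conditions, sample latents, and
# decode to pixels. The repo id, prompt, frame count, and resolution below are
# placeholders, and the call to `pipe.image_projection(...)` for building
# `context_img` is an assumption about the Resampler's interface in this repo;
# adapt everything to the checkpoint you actually load.
#
#     pipe = LatentVideoDiffusionPipeline.from_pretrained("user/video-model")  # hypothetical repo id
#     pipe = pipe.to(torch.device("cuda"))
#
#     positive_text = pipe.encode_cropped_prompt_77tokens("a cat running in the grass")
#     negative_text = pipe.encode_cropped_prompt_77tokens("")
#
#     # frames: (b, c, t, h, w) pixel tensor; a dummy 16-frame clip for illustration
#     frames = torch.zeros(1, 3, 16, 320, 512, dtype=pipe.vae.dtype, device=pipe.vae.device)
#     clip_embed = pipe.encode_clip_vision(frames)
#     image_cond = pipe.image_projection(clip_embed)            # assumed projection call
#     latents, hidden_states = pipe.encode_latents(frames)      # latents reused as concat_cond
#
#     results = pipe(
#         batch_size=1,
#         steps=50,
#         guidance_scale=5.0,
#         positive_text_cond=positive_text,
#         negative_text_cond=negative_text,
#         positive_image_cond=image_cond,
#         negative_image_cond=torch.zeros_like(image_cond),
#         concat_cond=latents,
#     )
#     video = pipe.decode_latents(results, hidden_states)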