from __future__ import annotations import gc import pathlib import sys import tempfile import os import gradio as gr import imageio import PIL.Image import torch from diffusers.utils.import_utils import is_xformers_available from einops import rearrange from huggingface_hub import ModelCard from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, PNDMScheduler, ControlNetModel, PriorTransformer, UnCLIPScheduler from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer from omegaconf import OmegaConf from typing import Any, Callable, Dict, List, Optional, Union, Tuple sys.path.append('Make-A-Protagonist') from makeaprotagonist.models.unet import UNet3DConditionModel from makeaprotagonist.pipelines.pipeline_stable_unclip_controlavideo import MakeAProtagonistStableUnCLIPPipeline, MultiControlNetModel from makeaprotagonist.dataset.dataset import MakeAProtagonistDataset from makeaprotagonist.util import save_videos_grid, ddim_inversion_unclip, ddim_inversion_prior from experts.grounded_sam_mask_out import mask_out_reference_image import ipdb class InferencePipeline: def __init__(self, hf_token: str | None = None): self.hf_token = hf_token self.pipe = None self.device = torch.device( 'cuda:0' if torch.cuda.is_available() else 'cpu') self.model_id = None self.conditions = None self.masks = None self.ddim_inv_latent = None self.train_dataset, self.sample_indices = None, None def clear(self) -> None: self.model_id = None del self.pipe self.pipe = None self.conditions = None self.masks = None self.ddim_inv_latent = None self.train_dataset, self.sample_indices = None, None torch.cuda.empty_cache() gc.collect() @staticmethod def check_if_model_is_local(model_id: str) -> bool: return pathlib.Path(model_id).exists() @staticmethod def get_model_card(model_id: str, hf_token: str | None = None) -> ModelCard: if InferencePipeline.check_if_model_is_local(model_id): card_path = (pathlib.Path(model_id) / 'README.md').as_posix() else: card_path = model_id return ModelCard.load(card_path, token=hf_token) @staticmethod def get_base_model_info(model_id: str, hf_token: str | None = None) -> str: card = InferencePipeline.get_model_card(model_id, hf_token) return card.data.base_model @torch.no_grad() def load_pipe(self, model_id: str, n_steps, seed) -> None: if model_id == self.model_id: return self.conditions, self.masks, self.ddim_inv_latent, self.train_dataset, self.sample_indices base_model_id = self.get_base_model_info(model_id, self.hf_token) pretrained_model_path = 'stabilityai/stable-diffusion-2-1-unclip-small' # image encoding components feature_extractor = CLIPImageProcessor.from_pretrained(pretrained_model_path, subfolder="feature_extractor") image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_path, subfolder="image_encoder") # image noising components image_normalizer = StableUnCLIPImageNormalizer.from_pretrained(pretrained_model_path, subfolder="image_normalizer", torch_dtype=torch.float16,) image_noising_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="image_noising_scheduler") # regular denoising components tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16,) vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae", torch_dtype=torch.float16,) self.ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler') self.ddim_inv_scheduler.set_timesteps(n_steps) prior_model_id = "kakaobrain/karlo-v1-alpha" data_type = torch.float16 prior = PriorTransformer.from_pretrained(prior_model_id, subfolder="prior", torch_dtype=data_type) prior_text_model_id = "openai/clip-vit-large-patch14" prior_tokenizer = CLIPTokenizer.from_pretrained(prior_text_model_id) prior_text_model = CLIPTextModelWithProjection.from_pretrained(prior_text_model_id, torch_dtype=data_type) prior_scheduler = UnCLIPScheduler.from_pretrained(prior_model_id, subfolder="prior_scheduler") prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) controlnet_model_id = ['controlnet-2-1-unclip-small-openposefull', 'controlnet-2-1-unclip-small-depth'] controlnet = MultiControlNetModel( [ControlNetModel.from_pretrained('Make-A-Protagonist/controlnet-2-1-unclip-small', subfolder=subfolder_id, torch_dtype=torch.float16) for subfolder_id in controlnet_model_id] ) unet = UNet3DConditionModel.from_pretrained( model_id, subfolder='unet', torch_dtype=torch.float16, use_auth_token=self.hf_token) # Freeze vae and text_encoder and adapter vae.requires_grad_(False) text_encoder.requires_grad_(False) ## freeze image embed image_encoder.requires_grad_(False) unet.requires_grad_(False) ## freeze controlnet controlnet.requires_grad_(False) ## freeze prior prior.requires_grad_(False) prior_text_model.requires_grad_(False) config_file = os.path.join('Make-A-Protagonist/configs', model_id.split('/')[-1] + '.yaml') self.cfg = OmegaConf.load(config_file) # def source_parsing(self, n_steps): # ipdb.set_trace() train_dataset = MakeAProtagonistDataset(**self.cfg) train_dataset.preprocess_img_embedding(feature_extractor, image_encoder) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=1, num_workers=0, ) image_encoder.to(dtype=data_type) pipe = MakeAProtagonistStableUnCLIPPipeline( prior_tokenizer=prior_tokenizer, prior_text_encoder=prior_text_model, prior=prior, prior_scheduler=prior_scheduler, feature_extractor=feature_extractor, image_encoder=image_encoder, image_normalizer=image_normalizer, image_noising_scheduler=image_noising_scheduler, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") ) pipe = pipe.to(self.device) if is_xformers_available(): pipe.unet.enable_xformers_memory_efficient_attention() pipe.controlnet.enable_xformers_memory_efficient_attention() self.pipe = pipe self.model_id = model_id # type: ignore self.vae = vae # self.feature_extractor = feature_extractor # self.image_encoder = image_encoder ## ddim inverse for source video batch = next(iter(train_dataloader)) weight_dtype = torch.float16 pixel_values = batch["pixel_values"].to(weight_dtype).to(self.device) video_length = pixel_values.shape[1] pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w") latents = self.vae.encode(pixel_values).latent_dist.sample() latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length) latents = latents * self.vae.config.scaling_factor # ControlNet # ipdb.set_trace() conditions = [_condition.to(weight_dtype).to(self.device) for _, _condition in batch["conditions"].items()] # b f c h w masks = batch["masks"].to(weight_dtype).to(self.device) # b,f,1,h,w emb_dim = train_dataset.img_embeddings[0].size(0) key_frame_embed = torch.zeros((1, emb_dim)).to(device=latents.device, dtype=latents.dtype) ## this is dim 0 # ipdb.set_trace() ddim_inv_latent = ddim_inversion_unclip( self.pipe, self.ddim_inv_scheduler, video_latent=latents, num_inv_steps=n_steps, prompt="", image_embed=key_frame_embed, noise_level=0, seed=seed)[-1].to(weight_dtype) self.conditions = conditions self.masks = masks self.ddim_inv_latent = ddim_inv_latent self.train_dataset = train_dataset self.sample_indices = batch["sample_indices"][0] return conditions, masks, ddim_inv_latent, train_dataset, batch["sample_indices"][0] def run( self, model_id: str, prompt: str, video_length: int, fps: int, seed: int, n_steps: int, guidance_scale: float, ref_image: PIL.Image.Image, ref_pro_prompt: str, noise_level: int, start_step: int, control_pose: float, control_depth: float, source_pro: int = 0, # 0 or 1 source_bg: int = 0, ) -> PIL.Image.Image: if not torch.cuda.is_available(): raise gr.Error('CUDA is not available.') torch.cuda.empty_cache() conditions, masks, ddim_inv_latent, _, _ = self.load_pipe(model_id, n_steps, seed) ## conditions [1,F,3,H,W] ## masks [1,F,1,H,W] ## ddim_inv_latent [1,4,F,H,W] ## NOTE this is to deal with video length conditions = [_condition[:,:video_length] for _condition in conditions] masks = masks[:, :video_length] ddim_inv_latent = ddim_inv_latent[:,:,:video_length] generator = torch.Generator(device=self.device).manual_seed(seed) ## TODO mask out reference image # ipdb.set_trace() ref_image = mask_out_reference_image(ref_image, ref_pro_prompt) controlnet_conditioning_scale = [control_pose, control_depth] prior_denoised_embeds = None image_embed = None if source_bg: ## using source background and changing the protagonist prior_denoised_embeds = self.train_dataset.img_embeddings[0][None].to(device=ddim_inv_latent.device, dtype=ddim_inv_latent.dtype) # 1, 768 for UnCLIP-small if source_pro: # using source protagonist and changing the background sample_indices = self.sample_indices image_embed = [self.train_dataset.img_embeddings[idx] for idx in sample_indices] image_embed = torch.stack(image_embed, dim=0).to(device=ddim_inv_latent.device, dtype=ddim_inv_latent.dtype) # F, 768 for UnCLIP-small # F,C image_embed = image_embed[:video_length] ref_image = None # ipdb.set_trace() out = self.pipe( image=ref_image, prompt=prompt, control_image=conditions, video_length=video_length, width=768, height=768, num_inference_steps=n_steps, guidance_scale=guidance_scale, generator=generator, ## ddim inversion latents=ddim_inv_latent, ## ref image embeds noise_level=noise_level, ## controlnet controlnet_conditioning_scale=controlnet_conditioning_scale, ## mask masks=masks, mask_mode='all', mask_latent_fuse_mode = 'all', start_step=start_step, ## edit bg and pro prior_latents=None, image_embeds=image_embed, # keep pro prior_denoised_embeds=prior_denoised_embeds # keep bg ) frames = rearrange(out.videos[0], 'c t h w -> t h w c') frames = (frames * 255).to(torch.uint8).numpy() out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) writer = imageio.get_writer(out_file.name, fps=fps) for frame in frames: writer.append_data(frame) writer.close() return out_file.name