import argparse
import os
from datetime import datetime
from pathlib import Path
from typing import List

import av
import numpy as np
import torch
import torchvision
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from einops import repeat
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection

from configs.prompts.test_cases import TestCasesDict
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.util import get_fps, read_frames, save_videos_grid


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="YAML config with checkpoint paths and test cases")
    parser.add_argument("-W", type=int, default=512, help="Output width")
    parser.add_argument("-H", type=int, default=784, help="Output height")
    parser.add_argument("-L", type=int, default=24, help="Number of frames to generate")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--cfg", type=float, default=3.5, help="Classifier-free guidance scale")
    parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps")
    parser.add_argument("--fps", type=int, help="Output fps; defaults to the pose video's fps")
    args = parser.parse_args()

    return args


def main():
    args = parse_args()

    config = OmegaConf.load(args.config)

    # Select the weight dtype for all modules.
    if config.weight_dtype == "fp16":
        weight_dtype = torch.float16
    else:
        weight_dtype = torch.float32

    # Load the individual pipeline components.
    vae = AutoencoderKL.from_pretrained(
        config.pretrained_vae_path,
    ).to("cuda", dtype=weight_dtype)

    reference_unet = UNet2DConditionModel.from_pretrained(
        config.pretrained_base_model_path,
        subfolder="unet",
    ).to(dtype=weight_dtype, device="cuda")

    inference_config_path = config.inference_config
    infer_config = OmegaConf.load(inference_config_path)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.pretrained_base_model_path,
        config.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    ).to(dtype=weight_dtype, device="cuda")

    pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
        dtype=weight_dtype, device="cuda"
    )

    image_enc = CLIPVisionModelWithProjection.from_pretrained(
        config.image_encoder_path
    ).to(dtype=weight_dtype, device="cuda")

    sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
    scheduler = DDIMScheduler(**sched_kwargs)

    generator = torch.manual_seed(args.seed)

    width, height = args.W, args.H

    # Load pretrained weights.
    denoising_unet.load_state_dict(
        torch.load(config.denoising_unet_path, map_location="cpu"),
        strict=False,
    )
    reference_unet.load_state_dict(
        torch.load(config.reference_unet_path, map_location="cpu"),
    )
    pose_guider.load_state_dict(
        torch.load(config.pose_guider_path, map_location="cpu"),
    )

    # Assemble the pose-to-video pipeline.
    pipe = Pose2VideoPipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to("cuda", dtype=weight_dtype)

    # Output directory: output/<date>/<time>--seed_<seed>-<W>x<H>
    date_str = datetime.now().strftime("%Y%m%d")
    time_str = datetime.now().strftime("%H%M")
    save_dir_name = f"{time_str}--seed_{args.seed}-{args.W}x{args.H}"

    save_dir = Path(f"output/{date_str}/{save_dir_name}")
    save_dir.mkdir(exist_ok=True, parents=True)

    for ref_image_path in config["test_cases"].keys():
        # Each ref_image may correspond to multiple actions (pose videos).
        for pose_video_path in config["test_cases"][ref_image_path]:
            ref_name = Path(ref_image_path).stem
            pose_name = Path(pose_video_path).stem.replace("_kps", "")

            ref_image_pil = Image.open(ref_image_path).convert("RGB")
            # Collect the first L pose frames both as PIL images (pipeline input)
            # and as tensors (for the saved side-by-side grid).
            pose_list = []
            pose_tensor_list = []
            pose_images = read_frames(pose_video_path)
            src_fps = get_fps(pose_video_path)
            print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")
            pose_transform = transforms.Compose(
                [transforms.Resize((height, width)), transforms.ToTensor()]
            )
            for pose_image_pil in pose_images[: args.L]:
                pose_tensor_list.append(pose_transform(pose_image_pil))
                pose_list.append(pose_image_pil)

            # Repeat the reference image along the time axis so it can be tiled
            # next to the pose and generated videos.
            ref_image_tensor = pose_transform(ref_image_pil)  # (c, h, w)
            ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0)  # (1, c, 1, h, w)
            ref_image_tensor = repeat(
                ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=args.L
            )

            pose_tensor = torch.stack(pose_tensor_list, dim=0)  # (f, c, h, w)
            pose_tensor = pose_tensor.transpose(0, 1)  # (c, f, h, w)
            pose_tensor = pose_tensor.unsqueeze(0)  # (1, c, f, h, w)

            video = pipe(
                ref_image_pil,
                pose_list,
                width,
                height,
                args.L,
                args.steps,
                args.cfg,
                generator=generator,
            ).videos

            # Save reference image, pose video, and generated video stacked in one grid.
            video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)
            save_videos_grid(
                video,
                f"{save_dir}/{ref_name}_{pose_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}.mp4",
                n_rows=3,
                fps=src_fps if args.fps is None else args.fps,
            )


if __name__ == "__main__":
    main()
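# ---------------------------------------------------------------------------
# Illustrative config sketch (not part of the original script): the YAML file
# passed via --config is expected to provide the keys read above. The paths
# below are placeholders; substitute your own checkpoint locations. The
# "test_cases" mapping pairs each reference image with one or more pose
# (keypoint) videos.
#
#   pretrained_base_model_path: "./pretrained_weights/stable-diffusion-v1-5"
#   pretrained_vae_path: "./pretrained_weights/sd-vae-ft-mse"
#   image_encoder_path: "./pretrained_weights/image_encoder"
#   denoising_unet_path: "./pretrained_weights/denoising_unet.pth"
#   reference_unet_path: "./pretrained_weights/reference_unet.pth"
#   pose_guider_path: "./pretrained_weights/pose_guider.pth"
#   motion_module_path: "./pretrained_weights/motion_module.pth"
#   inference_config: "./configs/inference/inference.yaml"
#   weight_dtype: "fp16"
#   test_cases:
#     "./inputs/ref_image.png":
#       - "./inputs/dance_kps.mp4"
#
# Example invocation (assuming the script and config live at these paths):
#   python scripts/pose2vid.py --config ./configs/prompts/animation.yaml -W 512 -H 784 -L 24
# ---------------------------------------------------------------------------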