zejunyang committed on
Commit c7a4aba
1 Parent(s): 2de857a
Files changed (1)
  1. src/create_modules.py +96 -0
src/create_modules.py ADDED
@@ -0,0 +1,96 @@
+ import os
+ import ffmpeg
+ from datetime import datetime
+ from pathlib import Path
+ import numpy as np
+ import cv2
+ import torch
+ from scipy.spatial.transform import Rotation as R
+ from scipy.interpolate import interp1d
+
+ from diffusers import AutoencoderKL, DDIMScheduler
+ from einops import repeat
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from torchvision import transforms
+ from transformers import CLIPVisionModelWithProjection
+
+
+ from src.models.pose_guider import PoseGuider
+ from src.models.unet_2d_condition import UNet2DConditionModel
+ from src.models.unet_3d import UNet3DConditionModel
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
+ from src.utils.util import save_videos_grid
+
+ from src.audio_models.model import Audio2MeshModel
+ from src.utils.audio_util import prepare_audio_feature
+ from src.utils.mp_utils import LMKExtractor
+ from src.utils.draw_util import FaceMeshVisualizer
+ from src.utils.pose_util import project_points
+
+
+ lmk_extractor = LMKExtractor()
+ vis = FaceMeshVisualizer(forehead_edge=False)
+
+ config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
+
+ if config.weight_dtype == "fp16":
+     weight_dtype = torch.float16
+ else:
+     weight_dtype = torch.float32
+
+ audio_infer_config = OmegaConf.load(config.audio_inference_config)
+ # prepare model
+ a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
+ a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
+ a2m_model.cuda().eval()
+
+ vae = AutoencoderKL.from_pretrained(
+     config.pretrained_vae_path,
+ ).to("cuda", dtype=weight_dtype)
+
+ reference_unet = UNet2DConditionModel.from_pretrained(
+     config.pretrained_base_model_path,
+     subfolder="unet",
+ ).to(dtype=weight_dtype, device="cuda")
+
+ inference_config_path = config.inference_config
+ infer_config = OmegaConf.load(inference_config_path)
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
+     config.pretrained_base_model_path,
+     config.motion_module_path,
+     subfolder="unet",
+     unet_additional_kwargs=infer_config.unet_additional_kwargs,
+ ).to(dtype=weight_dtype, device="cuda")
+
+
+ pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype)  # use_ca=True: pose guider with cross attention
+
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
+     config.image_encoder_path
+ ).to(dtype=weight_dtype, device="cuda")
+
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
+ scheduler = DDIMScheduler(**sched_kwargs)
+
+ # load pretrained weights
+ denoising_unet.load_state_dict(
+     torch.load(config.denoising_unet_path, map_location="cpu"),
+     strict=False,
+ )
+ reference_unet.load_state_dict(
+     torch.load(config.reference_unet_path, map_location="cpu"),
+ )
+ pose_guider.load_state_dict(
+     torch.load(config.pose_guider_path, map_location="cpu"),
+ )
+
+ pipe = Pose2VideoPipeline(
+     vae=vae,
+     image_encoder=image_enc,
+     reference_unet=reference_unet,
+     denoising_unet=denoising_unet,
+     pose_guider=pose_guider,
+     scheduler=scheduler,
+ )
+ pipe = pipe.to("cuda", dtype=weight_dtype)
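Because everything above runs at module import time, a minimal sketch of how another script might reuse these objects is simply to import the module once (the names and entry point below are an assumption for illustration, not part of this commit; the repo's configs and checkpoints must be present and a CUDA device available):

# hypothetical inference script, run from the repository root
# importing the module executes the setup above exactly once
from src.create_modules import pipe, a2m_model, lmk_extractor, vis, weight_dtype

# at this point pipe, a2m_model, lmk_extractor and vis are initialized on the GPU
# and can be passed to whatever audio-to-video inference logic the caller implements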