X-Portrait / core /test_xportrait.py
fffiloni's picture
Migrated from GitHub
bfed184 verified
# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
import os
import argparse
import numpy as np
# torch
import torch
from ema_pytorch import EMA
from einops import rearrange
import cv2
# utils
from utils.utils import set_seed, count_param, print_peak_memory
# model
import imageio
from model_lib.ControlNet.cldm.model import create_model
import copy
import glob
import imageio
from skimage.transform import resize
from skimage import img_as_ubyte
import face_alignment
import sys
from decord import VideoReader
from decord import cpu, gpu
# Leading (major) component of the installed torch version, kept as a string.
TORCH_VERSION = torch.__version__.partition(".")[0]
# dtype used for the 16-bit autocast path.
FP16_DTYPE = torch.float16
print(f"TORCH_VERSION={TORCH_VERSION} FP16_DTYPE={FP16_DTYPE}")
def extract_local_feature_from_single_img(img, fa, remove_local=False, real_tocrop=None, target_res = 512):
    """Build an image that keeps (or blanks) the eye and mouth regions of a frame.

    Facial landmarks are detected on a 256x256 copy of ``img``. Three square
    windows (two eyes and the mouth) of side ``target_res // 4`` are then
    either copied onto a black canvas or zeroed out of the image.

    Args:
        img: 3-D tensor indexed as (channel, row, column). Values are
            presumably in [0, 1], since the result is mapped with ``* 2 - 1``
            -- TODO confirm with callers.
        fa: Landmark detector exposing ``get_landmarks_from_image``.
        remove_local: If True, return the image with the three windows set to
            zero; otherwise return only the three windows on a zero canvas.
        real_tocrop: Optional tensor with the same layout as ``img``. When
            given, pixels are taken from it instead of ``img`` and the output
            is NOT rescaled with ``* 2 - 1``.
        target_res: Resolution the landmark coordinates are scaled to. Assumes
            the image being cropped is ``target_res`` pixels on each side --
            TODO confirm.

    Returns:
        A (channel, row, column) tensor on ``img.device``. When no face is
        found, a constant tensor shaped like ``img`` is returned (all ``-1``
        if ``real_tocrop`` is None, all ``0`` otherwise). Both paths return a
        single tensor, so results can be stacked by the caller.
    """
    device = img.device
    pred = img.permute([1, 2, 0]).detach().cpu().numpy()

    # Landmark detection runs on a fixed 256x256 uint8 copy of the frame.
    pred_lmks = img_as_ubyte(resize(pred, (256, 256)))

    try:
        lmks = fa.get_landmarks_from_image(pred_lmks, return_landmark_score=False)[0]
    except Exception:
        # Covers a detector that returns None / an empty result as well as
        # detector failures, without swallowing KeyboardInterrupt/SystemExit.
        print ('undetected faces!!')
        fallback = torch.zeros_like(img)
        if real_tocrop is None:
            fallback = fallback * 2 - 1.
        return fallback

    halfedge = 32
    scale = target_res / 256

    def _region_center(points):
        # Mean landmark position, kept `halfedge` pixels away from the border
        # of the 256-px detection image, then rescaled to `target_res`.
        center = np.clip(np.round(np.mean(points, axis=0)), halfedge, 255 - halfedge)
        return (center * scale).astype(np.int32)

    # NOTE(review): these slices look 1-based relative to the common 68-point
    # layout (eyes 36-41 / 42-47, mouth 48-67). They are kept unchanged because
    # existing checkpoints may depend on them -- confirm before touching.
    left_eye_center = _region_center(lmks[43:48])
    right_eye_center = _region_center(lmks[37:42])
    mouth_center = _region_center(lmks[49:68])

    if real_tocrop is not None:
        pred = real_tocrop.permute([1, 2, 0]).detach().cpu().numpy()

    # 32 px in the 256-px detection image equals target_res // 8 after scaling,
    # so every window stays inside the image thanks to the clipping above.
    half_size = target_res // 8
    windows = [
        (
            slice(center[1] - half_size, center[1] + half_size),
            slice(center[0] - half_size, center[0] + half_size),
        )
        for center in (left_eye_center, right_eye_center, mouth_center)
    ]

    if remove_local:
        # Work on a copy: `pred` can share memory with the caller's CPU tensor,
        # and blanking it in place would corrupt the input.
        local_viz = pred.copy()
        for window in windows:
            local_viz[window] = 0
    else:
        local_viz = np.zeros_like(pred)
        for window in windows:
            local_viz[window] = pred[window]

    local_viz = torch.from_numpy(local_viz).to(device)
    local_viz = local_viz.permute([2, 0, 1])
    if real_tocrop is None:
        local_viz = local_viz * 2 - 1.
    return local_viz
def find_best_frame_byheadpose_fa(source_image, driving_video, fa):
    """Return the index of the driving frame whose landmarks are closest to the source.

    Landmarks are detected on 256x256 copies of the source image and of every
    driving frame; the score of a frame is the summed absolute difference
    between its landmark coordinates and the source's.

    Args:
        source_image: Source image array accepted by ``skimage.transform.resize``.
        driving_video: Sequence of frame arrays.
        fa: Landmark detector exposing ``get_landmarks_from_image``.

    Returns:
        Index of the lowest-scoring frame (the first one on ties). Images in
        which no face is found are scored as if all landmarks were at the
        origin. Returns 0 when the source landmark set is empty or when
        ``driving_video`` has no frames.
    """
    src_input = img_as_ubyte(resize(source_image, (256, 256)))
    try:
        src_pose_array = fa.get_landmarks_from_image(src_input, return_landmark_score=False)[0]
    except Exception:
        # Detector returned nothing usable (or failed); fall back to zeros.
        print ('undetected faces in the source image!!')
        src_pose_array = np.zeros((68,2))
    if len(src_pose_array) == 0:
        return 0
    # Converted once; it is the same for every driving frame.
    src_pose_array = np.array(src_pose_array)

    min_diff = 1e8
    best_frame = 0
    for i, raw_frame in enumerate(driving_video):
        frame = img_as_ubyte(resize(raw_frame, (256, 256)))
        try:
            drv_pose_array = fa.get_landmarks_from_image(frame, return_landmark_score=False)[0]
        except Exception:
            print ('undetected faces in the %d-th driving image!!'%i)
            drv_pose_array = np.zeros((68,2))
        diff = np.sum(np.abs(src_pose_array - np.array(drv_pose_array)))
        if diff < min_diff:
            best_frame = i
            min_diff = diff
    return best_frame
def adjust_driving_video_to_src_image(source_image, driving_video, fa, nm_res, nmd_res, best_frame=-1):
    """Warp every driving frame so its face box lines up with the source face box.

    The landmark bounding boxes of the source image and of one reference
    driving frame are detected at 256x256. The affine map taking the driving
    box onto the source box is then applied to every driving frame, at both
    ``nm_res`` and ``nmd_res``.

    Args:
        source_image: Source image array (only the first three channels are used).
        driving_video: Sequence of driving frame arrays.
        fa: Landmark detector exposing ``get_landmarks_from_image``.
        nm_res: Side length of the low-resolution output frames.
        nmd_res: Side length of the high-resolution output frames.
        best_frame: Reference frame selector. ``-2`` skips alignment and only
            resizes; any other negative value searches for the best matching
            frame; a non-negative value is used as the frame index directly.

    Returns:
        Two lists ``(frames_at_nm_res, frames_at_nmd_res)``. If a face cannot
        be found in the source or in the reference frame, the frames are only
        resized.

    Raises:
        ValueError: If ``best_frame`` is not a valid index into ``driving_video``.
    """
    landmark_res = 256  # resolution at which landmarks are detected

    def _resize_only():
        return (
            [resize(frame, (nm_res, nm_res)) for frame in driving_video],
            [resize(frame, (nmd_res, nmd_res)) for frame in driving_video],
        )

    if best_frame == -2:
        return _resize_only()

    # The index must be checked against the number of driving frames; the
    # source image's length is its height and says nothing about the video.
    if best_frame >= len(driving_video):
        raise ValueError(
            "please specify one frame in driving video of which the pose match best with the pose of source image"
        )

    src = img_as_ubyte(resize(source_image[..., :3], (landmark_res, landmark_res)))

    if best_frame < 0:
        best_frame = find_best_frame_byheadpose_fa(src, driving_video, fa)

    print ('Best Frame: %d' % best_frame)
    driving = img_as_ubyte(resize(driving_video[best_frame], (landmark_res, landmark_res)))

    src_lmks = fa.get_landmarks_from_image(src, return_landmark_score=False)
    drv_lmks = fa.get_landmarks_from_image(driving, return_landmark_score=False)

    # No detection (None or an empty list) on either side: nothing to align to.
    if not src_lmks or not drv_lmks:
        return _resize_only()

    def _box_triangle(lmks):
        # Three corners (top-left, top-right, bottom-left) of the landmark
        # bounding box; three point pairs fully determine an affine transform.
        center = np.mean(lmks, axis=0)
        half_extent = (np.max(lmks, axis=0) - np.min(lmks, axis=0)) * 0.5
        return np.array([
            [center[0] - half_extent[0], center[1] - half_extent[1]],
            [center[0] + half_extent[0], center[1] - half_extent[1]],
            [center[0] - half_extent[0], center[1] + half_extent[1]],
        ]).astype(np.float32)

    src_point = _box_triangle(src_lmks[0])
    dst_point = _box_triangle(drv_lmks[0])

    # Landmarks live in the 256-px detection frame; rescale them to each output
    # resolution (factors 1 and 2 for the usual 256 / 512 pair). The transforms
    # do not depend on the frame, so they are computed once.
    ld_scale = nm_res / landmark_res
    hd_scale = nmd_res / landmark_res
    warp_ld = cv2.getAffineTransform(
        (dst_point * ld_scale).astype(np.float32), (src_point * ld_scale).astype(np.float32)
    )
    warp_hd = cv2.getAffineTransform(
        (dst_point * hd_scale).astype(np.float32), (src_point * hd_scale).astype(np.float32)
    )

    adjusted_driving_video = []
    adjusted_driving_video_hd = []
    for frame in driving_video:
        frame_ld = resize(frame, (nm_res, nm_res))
        frame_hd = resize(frame, (nmd_res, nmd_res))
        adjusted_driving_video.append(cv2.warpAffine(frame_ld, warp_ld, (nm_res, nm_res)))
        adjusted_driving_video_hd.append(cv2.warpAffine(frame_hd, warp_hd, (nmd_res, nmd_res)))

    return adjusted_driving_video, adjusted_driving_video_hd
def x_portrait_data_prep(source_image_path, driving_video_path, device, best_frame_id=0, start_idx = 0, num_frames=0, skip=1, output_local=False, more_source_image_pattern="", target_resolution = 512):
    """Load a source image and driving frames and pack them into an inference batch.

    Args:
        source_image_path: Path of the source image.
        driving_video_path: Path of the driving input. A path containing
            ``.mp4`` is read frame by frame; anything else is read as a single
            image and treated as a one-frame video with ``fps`` 1.
        device: Device the returned tensors are moved to.
        best_frame_id: Forwarded to ``adjust_driving_video_to_src_image`` as
            its ``best_frame`` argument.
        start_idx: Index of the first driving frame to use.
        num_frames: Number of frames to keep; 0 keeps everything from
            ``start_idx`` onwards.
        skip: Stride applied inside the selected frame range.
        output_local: If True, also compute one local-feature image per frame.
        more_source_image_pattern: Optional glob pattern of extra source images.
        target_resolution: Side length of the high-resolution tensors.

    Returns:
        dict with ``fps``, ``sources`` (F, 1, C, H, W) and ``conditions``
        (F, 1, C, H, W); plus ``more_sources`` (F, K, C, H, W) when extra
        source images were found and ``local`` (F, C, H, W) when
        ``output_local`` is True. Image tensors are mapped with ``x * 2 - 1``
        (i.e. to [-1, 1], assuming the resized images are in [0, 1]).

    Raises:
        ValueError: If the frame selection is empty.
    """
    source_image = imageio.imread(source_image_path)
    if '.mp4' in driving_video_path:
        reader = imageio.get_reader(driving_video_path)
        try:
            fps = reader.get_meta_data()['fps']
            driving_video = []
            try:
                for im in reader:
                    driving_video.append(im)
            except RuntimeError as err:
                # Best effort: keep the frames decoded so far, but say so
                # instead of truncating the video silently.
                print('stopped reading {} after {} frames: {}'.format(driving_video_path, len(driving_video), err))
        finally:
            # Release the reader even if metadata or decoding fails.
            reader.close()
    else:
        driving_video = [imageio.imread(driving_video_path)[...,:3]]
        fps = 1

    nmd_res = target_resolution
    nm_res = 256
    source_image_hd = resize(source_image, (nmd_res, nmd_res))[..., :3]

    def _to_batched_chw(image_hwc):
        # (H, W, C) array -> (1, C, H, W) float32 tensor
        return torch.tensor(image_hwc[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)

    more_sources_hd = None
    if more_source_image_pattern:
        # Sorted so the order of the extra sources does not depend on the filesystem.
        more_source_paths = sorted(glob.glob(more_source_image_pattern))
        if more_source_paths:
            more_sources_hd = []
            for more_source_path in more_source_paths:
                more_source_image = imageio.imread(more_source_path)
                more_source_image_hd = resize(more_source_image, (nmd_res, nmd_res))[..., :3]
                more_sources_hd.append(_to_batched_chw(more_source_image_hd).to(device))
            more_sources_hd = torch.stack(more_sources_hd, dim = 1)
        else:
            # torch.stack cannot take an empty list; treat "no match" as "no extra sources".
            print('no additional source images match {}'.format(more_source_image_pattern))

    # NOTE(review): the landmark detector is always created on 'cuda',
    # independently of `device` -- confirm this is intended.
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=True, device='cuda')

    driving_video, driving_video_hd = adjust_driving_video_to_src_image(source_image, driving_video, fa, nm_res, nmd_res, best_frame_id)

    if num_frames == 0:
        end_idx = len(driving_video)
    else:
        num_frames = min(len(driving_video), num_frames)
        end_idx = start_idx + num_frames * skip

    driving_video = driving_video[start_idx:end_idx][::skip]
    driving_video_hd = driving_video_hd[start_idx:end_idx][::skip]
    num_frames = len(driving_video)
    if num_frames == 0:
        raise ValueError(
            'no driving frames selected from {} (start_idx={}, skip={})'.format(driving_video_path, start_idx, skip)
        )

    with torch.no_grad():
        real_source_hd = _to_batched_chw(source_image_hd).to(device)
        driving_hd = torch.tensor(np.array(driving_video_hd).astype(np.float32)).permute(0, 3, 1, 2).to(device)

        local_features = []
        if output_local:
            for frame_idx in range(num_frames):
                local_feature_img = extract_local_feature_from_single_img(driving_hd[frame_idx], fa,target_res=nmd_res)
                if isinstance(local_feature_img, tuple):
                    # The extractor's no-face path may hand back (tensor, box);
                    # only the tensor can be stacked below.
                    local_feature_img = local_feature_img[0]
                local_features.append(local_feature_img)

        batch_data = {}
        batch_data['fps'] = fps
        real_source_hd = real_source_hd * 2 - 1
        # The single source image is repeated once per driving frame.
        batch_data['sources'] = real_source_hd[:, None, :, :, :].repeat([num_frames, 1, 1, 1, 1])
        if more_sources_hd is not None:
            more_sources_hd = more_sources_hd * 2 - 1
            batch_data['more_sources'] = more_sources_hd.repeat([num_frames, 1, 1, 1, 1])

        # (F, C, H, W) -> (F, 1, C, H, W), rescaled like the other image tensors.
        batch_data['conditions'] = driving_hd.unsqueeze(1) * 2 - 1.
        if output_local:
            batch_data['local'] = torch.stack(local_features, dim = 0)

    return batch_data
# You can now use the modified state_dict without the deleted keys
def load_state_dict(model, ckpt_path, reinit_hint_block=False, strict=True, map_location="cpu"):
    """Load the weights stored at ``ckpt_path`` into ``model``.

    Args:
        model: Module receiving the weights through its ``load_state_dict``.
        ckpt_path: Checkpoint file readable by ``torch.load``.
        reinit_hint_block: If True, every entry whose name starts with
            ``control_model.input_hint_block`` is removed before loading.
        strict: Forwarded to ``model.load_state_dict``.
        map_location: Forwarded to ``torch.load``.

    Returns:
        None.
    """
    print(f"Loading model state dict from {ckpt_path} ...")
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=map_location)
    # The weights may be nested under a 'state_dict' entry; otherwise the
    # loaded mapping itself is used.
    weights = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    if reinit_hint_block:
        print("Ignoring hint block parameters from checkpoint!")
        # Collect first, then delete in place, so the mapping object (and any
        # attribute attached to it) stays the one that was loaded.
        hint_keys = [entry for entry in weights if entry.startswith("control_model.input_hint_block")]
        for entry in hint_keys:
            del weights[entry]

    model.load_state_dict(weights, strict=strict)
    del checkpoint, weights
def get_cond_control(args, batch_data, control_type, device, start, end, model=None, batch_size=None, train=True, key=0):
    """Collect the conditioning inputs for frames ``start:end`` of a batch.

    Only ``control_type == "appearance_pose_local_mm"`` is handled.

    Args:
        args: Unused here.
        batch_data: dict providing ``sources``, ``conditions``, ``local`` and,
            optionally, ``more_sources``.
        control_type: Name of the conditioning scheme.
        device: Device the encoded latents are moved to.
        start, end: Frame range to slice out of every batch entry.
        model: Object providing ``encode_first_stage`` and
            ``get_first_stage_encoding``.
        batch_size, train: Unused here.
        key: Index into the second axis of ``batch_data['sources']``.

    Returns:
        ``[source_latents, conditions, local, more_latents]`` where
        ``more_latents`` holds one single-element list per extra source image
        (empty when ``more_sources`` is absent).

    Raises:
        NotImplementedError: For any other ``control_type``.
    """
    control_type = copy.deepcopy(control_type)
    if control_type != "appearance_pose_local_mm":
        raise NotImplementedError(f"cond_type={control_type} not supported!")

    vae_bs = 16  # frames pushed through the first-stage encoder per call
    frames = slice(start, end)

    src = batch_data['sources'][frames, key].cuda()
    c_cat_list = batch_data['conditions'][frames].cuda()

    # Encode the source frames chunk by chunk, in order.
    latent_chunks = [
        model.get_first_stage_encoding(model.encode_first_stage(src[offset:offset + vae_bs]))
        for offset in range(0, end - start, vae_bs)
    ]
    cond_img_cat = torch.cat(latent_chunks, dim=0)

    p_local = batch_data['local'][frames].cuda()
    print ('Total frames:{}'.format(cond_img_cat.shape))

    more_cond_imgs = []
    if 'more_sources' in batch_data:
        extra_sources = batch_data['more_sources']
        for idx in range(extra_sources.shape[1]):
            extra_latent = model.get_first_stage_encoding(model.encode_first_stage(extra_sources[frames, idx]))
            more_cond_imgs.append([extra_latent.to(device)])

    return [cond_img_cat.to(device), c_cat_list, p_local, more_cond_imgs]
def visualize_mm(args, name, batch_data, infer_model, nSample, local_image_dir, num_mix=4, preset_output_name=''):
    """Sample a video for one prepared batch and write it as an mp4.

    Args:
        args: Namespace read for ``uc_scale``, ``control_type``,
            ``control_mode``, ``initial_facevid2vid_results``, ``device``,
            ``wonoise``, ``use_fp16``, ``num_drivings``, ``ddim_steps`` and
            ``eta``.
        name: Prefix of the generated file name.
        batch_data: dict with ``sources``, ``conditions``, ``local``, ``fps``,
            ``video_name`` and ``source_name``.
        infer_model: Diffusion model used for encoding, sampling and decoding.
        nSample: Number of frames to generate.
        local_image_dir: Output directory (created if missing).
        num_mix: Forwarded to the sampler as ``num_overlap``.
        preset_output_name: Optional output file name; the text before its
            first ``.`` is used and ``.mp4`` is appended.

    Returns:
        None. The video is written under ``local_image_dir``.
    """
    driving_video_name = os.path.basename(batch_data['video_name']).split('.')[0]
    source_name = os.path.basename(batch_data['source_name']).split('.')[0]

    # makedirs also handles nested paths and a directory created concurrently.
    os.makedirs(local_image_dir, exist_ok=True)

    uc_scale = args.uc_scale
    if preset_output_name:
        preset_output_name = preset_output_name.split('.')[0]+'.mp4'
        output_path = f"{local_image_dir}/{preset_output_name}"
    else:
        output_path = f"{local_image_dir}/{name}_{args.control_type}_uc{uc_scale}_{source_name}_by_{driving_video_name}_mix{num_mix}.mp4"

    infer_model.eval()

    _, _, ch, h, w = batch_data['sources'].shape
    vae_bs = 16  # frames per first-stage encode/decode call

    # Images the initial latent is derived from: either frames of a previously
    # rendered video, or the (repeated) source image.
    if args.initial_facevid2vid_results:
        facevid2vid = []
        facevid2vid_results = VideoReader(args.initial_facevid2vid_results, ctx=cpu(0))
        # Only the first nSample frames are used, so only those are decoded.
        for frame_id in range(min(len(facevid2vid_results), nSample)):
            # Resized to the source resolution (cv2 takes (width, height)).
            frame = cv2.resize(facevid2vid_results[frame_id].asnumpy(), (w, h)) / 255
            facevid2vid.append(torch.from_numpy(frame * 2 - 1).permute(2,0,1))
        cond = torch.stack(facevid2vid)[:nSample].float().to(args.device)
    else:
        cond = batch_data['sources'][:nSample].reshape([-1, ch, h, w])

    pre_noise = []
    for i in range(0, nSample, vae_bs):
        pre_noise.append(infer_model.get_first_stage_encoding(infer_model.encode_first_stage(cond[i:i+vae_bs])))
    pre_noise = torch.cat(pre_noise, dim=0)
    # Forward-diffuse the latents to timestep 999 (presumably the last step of
    # the schedule -- confirm against the model config).
    pre_noise = infer_model.q_sample(x_start = pre_noise, t = torch.tensor([999]).to(pre_noise.device))

    text = ["" for _ in range(nSample)]
    all_c_cat = get_cond_control(args, batch_data, args.control_type, args.device, start=0, end=nSample, model=infer_model, train=False)
    cond_img_cat = [all_c_cat[0]]
    pose_cond_list = [rearrange(all_c_cat[1], "b f c h w -> (b f) c h w")]
    local_pose_cond_list = [all_c_cat[2]]
    has_more_sources = len(all_c_cat) > 3 and len(all_c_cat[3]) > 0

    c_cross = infer_model.get_learned_conditioning(text)[:nSample]
    uc_cross = infer_model.get_unconditional_conditioning(nSample)

    # Conditional branch.
    c = {"c_crossattn": [c_cross], "image_control": cond_img_cat}
    if "appearance_pose" in args.control_type:
        c['c_concat'] = pose_cond_list
    if "appearance_pose_local" in args.control_type:
        c["local_c_concat"] = local_pose_cond_list
        if has_more_sources:
            c['more_image_control'] = all_c_cat[3]

    # Unconditional branch: image control is dropped in "controlnet_important"
    # mode, and the pose/local conditions are replaced by zeros.
    if args.control_mode == "controlnet_important":
        uc = {"c_crossattn": [uc_cross]}
    else:
        uc = {"c_crossattn": [uc_cross], "image_control":cond_img_cat}
    if "appearance_pose" in args.control_type:
        uc['c_concat'] = [torch.zeros_like(pose_cond_list[0])]
    if "appearance_pose_local" in args.control_type:
        uc["local_c_concat"] = [torch.zeros_like(local_pose_cond_list[0])]
        if has_more_sources:
            uc['more_image_control'] = all_c_cat[3]

    wonoise = bool(args.wonoise)
    c['wonoise'] = wonoise
    uc['wonoise'] = wonoise

    noise = pre_noise.to(c_cross.device)

    gene_img_list = []
    with torch.cuda.amp.autocast(enabled=args.use_fp16, dtype=FP16_DTYPE):
        infer_model.to(args.device)
        infer_model.eval()

        gene_img, _ = infer_model.sample_log(cond=c,
                                batch_size=args.num_drivings, ddim=True,
                                ddim_steps=args.ddim_steps, eta=args.eta,
                                unconditional_guidance_scale=uc_scale,
                                unconditional_conditioning=uc,
                                inpaint=None,
                                x_T=noise,
                                num_overlap=num_mix,
                                )

        for i in range(0, nSample, vae_bs):
            gene_img_part = infer_model.decode_first_stage( gene_img[i:i+vae_bs] )
            gene_img_list.append(gene_img_part.float().clamp(-1, 1).cpu())

    # Named out_* so the conditioning dict `c` above is not shadowed.
    _, out_c, out_h, out_w = gene_img_list[0].shape

    cond_image = batch_data["conditions"].reshape([-1, out_c, out_h, out_w])[:nSample].cpu()
    l_cond_image = batch_data["local"].reshape([-1, out_c, out_h, out_w])[:nSample].cpu()
    orig_image = batch_data["sources"][:nSample, 0].cpu()

    # Four groups of nSample frames (generated, driving, local, source),
    # mapped from [-1, 1] to [0, 1] ...
    output_img = torch.cat(gene_img_list + [cond_image.cpu()]+[l_cond_image.cpu()]+[orig_image.cpu()]).float().clamp(-1,1).add(1).mul(0.5)

    # ... then laid out side by side: one row of num_cols images per frame.
    num_cols = 4
    output_img = output_img.reshape([num_cols, 1, nSample, out_c, out_h, out_w]).permute([1, 0, 2, 3, 4,5])
    output_img = output_img.permute([2, 3, 0, 4, 1, 5]).reshape([-1, out_c, out_h, num_cols * out_w])
    output_img = torch.permute(output_img, [0, 2, 3, 1])

    output_img = output_img.data.cpu().numpy()
    output_img = img_as_ubyte(output_img)
    # Only the first column (the generated frames) is written to disk.
    imageio.mimsave(output_path, output_img[:, :, :out_w], fps=batch_data['fps'], quality=10, pixelformat='yuv420p', codec='libx264')
def main(args):
    """Run inference for every driving video matched by ``args.driving_video``.

    Builds the model described by ``args.model_config``, loads the checkpoint
    at ``args.resume_dir`` and writes one output video per driving video into
    ``args.output_dir``.

    Args:
        args: Parsed command-line namespace. ``world_size``, ``local_rank``,
            ``rank`` and ``device`` are overwritten here.

    Raises:
        SystemExit: With status 1 when ``args.resume_dir`` is not set.
    """
    # ******************************
    # single-process runtime setup
    # ******************************
    args.world_size = 1
    args.local_rank = 0
    args.rank = 0
    args.device = torch.device("cuda", args.local_rank)

    # set seed for reproducibility
    set_seed(args.seed)

    # A checkpoint is mandatory; fail before the model is built, and with a
    # non-zero exit status so callers can detect the error.
    if args.resume_dir is None:
        print('please provide the correct resume_dir!')
        sys.exit(1)

    # ******************************
    # create model
    # ******************************
    model = create_model(args.model_config).cpu()
    model.sd_locked = args.sd_locked
    model.only_mid_control = args.only_mid_control
    model.to(args.local_rank)
    os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank == 0:
        print('Total base parameters {:.02f}M'.format(count_param([model])))
    if args.ema_rate is not None and args.ema_rate > 0 and args.rank == 0:
        print(f"Creating EMA model at ema_rate={args.ema_rate}")
        # NOTE(review): model_ema is never read after this point.
        model_ema = EMA(model, beta=args.ema_rate, update_after_step=0, update_every=1)
    else:
        model_ema = None

    # ******************************
    # load pre-trained weights
    # ******************************
    if args.local_rank == 0:
        load_state_dict(model, args.resume_dir, strict=False)

    # ******************************
    # optional compilation
    # ******************************
    if args.compile and TORCH_VERSION == "2":
        model = torch.compile(model)

    torch.cuda.set_device(args.local_rank)
    print_peak_memory("Max memory allocated after creating DDP", args.local_rank)
    # Unwrap a possible wrapper that exposes the real model as `.module`.
    infer_model = model.module if hasattr(model, "module") else model

    with torch.no_grad():
        # Sorted so the processing order does not depend on the filesystem.
        driving_videos = sorted(glob.glob(args.driving_video))
        if not driving_videos:
            print('no driving video matches {}'.format(args.driving_video))
        for driving_video in driving_videos:
            print ('working on {}'.format(os.path.basename(driving_video)))
            infer_batch_data = x_portrait_data_prep(
                args.source_image, driving_video, args.device, args.best_frame,
                start_idx = args.start_idx, num_frames = args.out_frames, skip=args.skip,
                output_local=True,
                # Forward the extra-source pattern so the CLI option takes
                # effect; namespaces without it keep the old behaviour.
                more_source_image_pattern=getattr(args, 'more_source_image_pattern', ""),
            )
            infer_batch_data['video_name'] = os.path.basename(driving_video)
            infer_batch_data['source_name'] = args.source_image
            nSample = infer_batch_data['sources'].shape[0]
            visualize_mm(args, "inference", infer_batch_data, infer_model, nSample=nSample, local_image_dir=args.output_dir, num_mix=args.num_mix)
if __name__ == "__main__":

    def str2bool(arg):
        """Parse a boolean flag value.

        Accepts any integer literal (0 is False, everything else True), as
        before, and additionally the words true/false, yes/no, t/f, y/n
        (case-insensitive).
        """
        text = str(arg).strip().lower()
        if text in ("true", "t", "yes", "y"):
            return True
        if text in ("false", "f", "no", "n"):
            return False
        try:
            return bool(int(text))
        except ValueError:
            raise argparse.ArgumentTypeError(
                "expected an integer or true/false, got {!r}".format(arg)
            ) from None

    parser = argparse.ArgumentParser(description='X-Portrait inference')
    ## Model
    parser.add_argument('--model_config', type=str, default="model_lib/ControlNet/models/cldm_v15_video_appearance.yaml",
                        help="The path of model config file")
    parser.add_argument('--reinit_hint_block', action='store_true', default=False,
                        help="Re-initialize hint blocks for channel mis-match")
    parser.add_argument('--sd_locked', type =str2bool, default=True,
                        help='Freeze parameters in original stable-diffusion decoder')
    parser.add_argument('--only_mid_control', type =str2bool, default=False,
                        help='Only control middle blocks')
    parser.add_argument('--control_type', type=str, default="appearance_pose_local_mm",
                        help='The type of conditioning')
    parser.add_argument("--control_mode", type=str, default="controlnet_important",
                        help="Set controlnet is more important or balance.")
    # NOTE: store_false with default=True -- the option is ON unless the flag is given.
    parser.add_argument('--wonoise', action='store_false', default=True,
                        help='Use with referenceonly, remove adding noise on reference image '
                             '(enabled by default; passing this flag disables it)')

    ## Runtime
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--world_size", type=int, default=1)
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed for initialization')
    # NOTE: store_false with default=True -- fp16 is ON unless the flag is given.
    parser.add_argument('--use_fp16', action='store_false', default=True,
                        help='Whether to use 16-bit (mixed) precision autocast instead of 32-bit '
                             '(enabled by default; passing this flag disables it)')
    parser.add_argument('--compile', type=str2bool, default=False,
                        help='compile model (for torch 2)')
    parser.add_argument('--eta', type = float, default = 0.0,
                        help='eta during DDIM Sampling')
    parser.add_argument('--ema_rate', type = float, default = 0,
                        help='rate for ema')

    ## inference
    parser.add_argument("--initial_facevid2vid_results", type=str, default=None,
                        help="facevid2vid results for noise initialization")
    parser.add_argument('--ddim_steps', type = int, default = 1,
                        help='denoising steps')
    parser.add_argument('--uc_scale', type = int, default = 5,
                        help='classifier-free guidance scale')
    parser.add_argument("--num_drivings", type = int, default = 16,
                        help="Number of driving images in a single sequence of video.")
    parser.add_argument("--output_dir", type=str, default=None, required=True,
                        help="The output directory where the generated videos will be written.")
    parser.add_argument("--resume_dir", type=str, default=None,
                        help="Path of the model checkpoint to load (required to run).")
    parser.add_argument("--source_image", type=str, default="",
                        help="The source image for neural motion.")
    parser.add_argument("--more_source_image_pattern", type=str, default="",
                        help="Glob pattern matching additional source images.")
    parser.add_argument("--driving_video", type=str, default="",
                        help="Path or glob pattern of the driving video(s).")
    parser.add_argument('--best_frame', type=int, default=0,
                        help='index of the driving frame matching the source pose best; '
                             '-2 disables alignment, any other negative value searches automatically')
    parser.add_argument('--start_idx', type=int, default=0,
                        help='starting frame index')
    parser.add_argument('--skip', type=int, default=1,
                        help='stride between used driving frames')
    parser.add_argument('--num_mix', type=int, default=4,
                        help='num overlapping frames')
    parser.add_argument('--out_frames', type=int, default=0,
                        help='num frames (0 uses all frames from start_idx on)')

    args = parser.parse_args()
    main(args)