# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright 2023 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG), acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Authors: paper author.
# Special Acknowledgement: Wojciech Zielonka and Justus Thies
# Contact: [email protected]
import glob
import os
import os.path as osp
from functools import reduce
from pathlib import Path

import cv2
import face_alignment
import mmcv
import numpy as np
from loguru import logger
from tqdm import tqdm
from SHOW.utils.video import images_to_video
from torchvision.transforms.functional import gaussian_blur
from pytorch3d.transforms import axis_angle_to_matrix
from pytorch3d.renderer import RasterizationSettings, PointLights, MeshRenderer, MeshRasterizer, TexturesVertex, SoftPhongShader, look_at_view_transform, PerspectiveCameras
from pytorch3d.io import load_obj
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.backends.cudnn
import torch.nn.functional as F
import SHOW
from SHOW.utils import default_timers
from SHOW.datasets import op_base
from SHOW.detector.face_detector import FaceDetector
from SHOW.load_models import load_smplx_model, load_vposer_model
from SHOW.save_results import save_one_results
from SHOW.load_models import load_save_pkl
from SHOW.flame.FLAME import FLAMETex
from SHOW.smplx_dataset import ImagesDataset
from SHOW.renderer import Renderer
from SHOW.load_assets import load_assets
from SHOW.loggers.logger import setup_logger
from SHOW.save_tracker import save_tracker
from SHOW.utils import is_valid_json
from configs.cfg_ins import condor_cfg
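
# Stage 2 overview: starting from the stage-1 SMPL-X fit, this script refines
# the face-related parameters (FLAME expression, jaw and eye poses, albedo
# texture and spherical-harmonic lighting) window by window with a
# differentiable renderer, writes the refined values back into the saved
# parameter dict, and renders the final results to images and videos.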
@logger.catch
def SHOW_stage2(*args, **kwargs):
machine_info = SHOW.get_machine_info()
    import pprint
    logger.info('machine_info:')
    pprint.pprint(machine_info)
loggers = kwargs.get('loggers', None)
tracker_cfg = SHOW.from_rela_path(__file__,
'./configs/mmcv_tracker_config.py')
tracker_cfg.update(**kwargs)
tracker_cfg.merge_from_dict(condor_cfg)
if tracker_cfg.get('over_write_cfg', None):
tracker_cfg.update(tracker_cfg.over_write_cfg)
mmcv.dump(tracker_cfg, tracker_cfg.tracker_cfg_path)
try:
gpu_mem = machine_info['gpu_info']['gpu_Total']
import platform
if platform.system() == 'Linux':
            # heuristic: ~50 frames fit in 80 GiB of GPU memory
            # (gpu_Total is presumably reported in MiB)
            tracker_cfg.bs_at_a_time = int(50.0 * gpu_mem / (80.0 * 1024))
logger.warning(f'bs_at_a_time: {tracker_cfg.bs_at_a_time}')
    except Exception:
import traceback
traceback.print_exc()
Path(tracker_cfg.mica_save_path).mkdir(exist_ok=True, parents=True)
Path(tracker_cfg.mica_org_out_path).mkdir(exist_ok=True, parents=True)
iters = tracker_cfg.iters
sampling = tracker_cfg.sampling
device = tracker_cfg.device
tracker_cfg.dtype = dtype = SHOW.str_to_torch_dtype(tracker_cfg.dtype)
face_ider = SHOW.build_ider(tracker_cfg.ider_cfg)
img_folder = tracker_cfg.img_folder
    # pick a deterministic template frame (os.listdir order is arbitrary)
    template_im = sorted(os.listdir(img_folder))[0]
template_im = os.path.join(img_folder, template_im)
assets = load_assets(
tracker_cfg,
face_ider=face_ider,
template_im=template_im,
)
if assets is None:
return
setup_logger(tracker_cfg.mica_all_dir, filename='mica.log', mode='o')
if not Path(tracker_cfg.ours_pkl_file_path).exists():
        logger.warning(
            f'ours_pkl_file_path does not exist: {tracker_cfg.ours_pkl_file_path}')
return False
if not is_valid_json(tracker_cfg.final_losses_json_path):
logger.warning(
f'final_losses_json_path not valid: {tracker_cfg.final_losses_json_path}'
)
return False
with default_timers['build_vars_stage']:
person_face_emb = assets.person_face_emb
face_detector_mediapipe = FaceDetector('google', device=device)
face_detector = face_alignment.FaceAlignment(
face_alignment.LandmarksType._2D, device=device)
body_model = load_smplx_model(dtype=dtype, **tracker_cfg.smplx_cfg)
body_params_dict = load_save_pkl(tracker_cfg.ours_pkl_file_path,
device)
width = body_params_dict['width']
height = body_params_dict['height']
center = body_params_dict['center']
camera_transl = body_params_dict['camera_transl']
focal_length = body_params_dict['focal_length']
total_batch_size = body_params_dict['batch_size']
opt_bs = tracker_cfg.bs_at_a_time
opt_iters = total_batch_size // opt_bs
st_et_list = []
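        # the sequence is optimized in fixed windows of bs_at_a_time frames;
        # e.g. total_batch_size=60 with opt_bs=30 gives (0, 30) and (30, 59):
        # the last window's end is clamped to total_batch_size - 1, so the
        # final frame falls outside every slice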
for i in range(opt_iters):
st = i * opt_bs
et = (i + 1) * opt_bs
if et > total_batch_size - 1:
et = total_batch_size - 1
st_et_list.append((st, et))
op = op_base()
smplx2flame_idx = assets.smplx2flame_idx
mesh_file = Path(__file__).parent.joinpath(
'../data/head_template_mesh.obj')
diff_renderer = Renderer(torch.Tensor([[512, 512]]),
obj_filename=mesh_file)
flame_faces = load_obj(mesh_file)[1]
flametex = FLAMETex(tracker_cfg.flame_cfg).to(device)
mesh_rasterizer = MeshRasterizer(
raster_settings=RasterizationSettings(image_size=[512, 512],
faces_per_pixel=1,
cull_backfaces=True,
perspective_correct=True))
debug_renderer = MeshRenderer(
rasterizer=mesh_rasterizer,
shader=SoftPhongShader(device=device,
lights=PointLights(
device=device,
location=((0.0, 0.0, -5.0), ),
ambient_color=((0.5, 0.5, 0.5), ),
diffuse_color=((0.5, 0.5, 0.5), ))))
pre_frame_exp = None
for opt_idx, (start_frame, end_frame) in enumerate(st_et_list):
if assets.person_face_emb is not None:
mica_part_file_path = f'w_mica_part_{start_frame}_{end_frame}_{opt_idx}_{opt_iters}.pkl'
mica_part_pkl_path = os.path.join(tracker_cfg.mica_all_dir,
mica_part_file_path)
if Path(mica_part_pkl_path).exists():
                    logger.info(
                        f'mica_part_pkl_path exists, skipping: {mica_part_pkl_path}'
                    )
pre_con = mmcv.load(mica_part_pkl_path)
pre_frame_exp = pre_con['expression'][-1]
pre_frame_exp = torch.Tensor(pre_frame_exp).to(device)
continue
opt_bs = end_frame - start_frame
            com_tex = torch.zeros(1, 150).to(device)
            com_sh = torch.zeros(1, 9, 3).to(device)
use_shared_tex = 1
if not use_shared_tex:
opt_bs_tex = nn.Parameter(com_tex).expand(opt_bs, -1).detach()
else:
opt_bs_tex = nn.Parameter(com_tex).expand(1, -1).detach()
opt_bs_sh = nn.Parameter(com_sh).expand(opt_bs, -1, -1).detach()
            logger.info(f'input frame batch size for this window: {opt_bs}')
with default_timers['load_dataset_stage']:
debug = 0
if debug: opt_bs = 30
dataset = ImagesDataset(
tracker_cfg,
start_frame=start_frame,
face_ider=face_ider,
person_face_emb=person_face_emb,
face_detector_mediapipe=face_detector_mediapipe,
face_detector=face_detector)
dataloader = DataLoader(dataset,
batch_size=opt_bs,
num_workers=0,
shuffle=False,
pin_memory=True,
drop_last=False)
iterator = iter(dataloader)
batch = next(iterator)
if not debug:
batch = SHOW.utils.to_cuda(batch)
valid_bool = batch['is_person_deted'].bool()
valid_bs = valid_bool.count_nonzero()
                logger.info(f'valid input frames in this window: {valid_bs}')
logger.info(f'valid_bool: {valid_bool}')
if valid_bs == 0:
logger.warning('valid bs == 0, skipping')
open(mica_part_pkl_path + '.empty', 'a').close()
continue
bbox = batch['bbox']
images = batch['cropped_image']
landmarks = batch['cropped_lmk']
h = batch['h']
w = batch['w']
py = batch['py']
px = batch['px']
diff_renderer.masking.set_bs(valid_bs)
diff_renderer = diff_renderer.to(device)
debug = 0
report_wandb = 0
use_opt_pose = 1
            save_training_img = 0
observe_idx_list = [4, 8]
with default_timers['optimize_stage']:
model_output = None
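                # get_pose_opt pulls just two joints out of the 21-joint SMPL-X
                # body pose for optimization; rows 12-1 and 15-1 are presumably
                # the neck and head (body_pose excludes the pelvis, so kinematic
                # joint j sits at row j-1)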
def get_pose_opt(start_frame, end_frame):
tmp = body_params_dict['body_pose_axis'][
start_frame:end_frame, ...].clone().detach()
tmp = tmp.reshape(tmp.shape[0], -1, 3)
return torch.stack([tmp[:, 12 - 1, :],
tmp[:, 15 - 1, :]],
dim=1)
def clone_params_color(start_frame, end_frame):
opt_var_clone_detach = [
{
'params': [
nn.Parameter(
body_params_dict['expression']
[start_frame:end_frame].clone().detach())
],
'lr':
0.025,
'name': ['exp']
},
{
'params': [
nn.Parameter(body_params_dict['leye_pose']
[start_frame:end_frame].clone(
).clone().detach())
],
'lr':
0.001,
'name': ['leyes']
},
{
'params': [
nn.Parameter(
body_params_dict['reye_pose']
[start_frame:end_frame].clone().detach())
],
'lr':
0.001,
'name': ['reyes']
},
{
'params': [
nn.Parameter(
body_params_dict['jaw_pose']
[start_frame:end_frame].clone().detach())
],
'lr':
0.001,
'name': ['jaw']
},
{
'params':
[nn.Parameter(opt_bs_sh.clone().detach())],
'lr': 0.01,
'name': ['sh']
},
{
'params':
[nn.Parameter(opt_bs_tex.clone().detach())],
'lr': 0.005,
'name': ['tex']
},
]
if use_opt_pose:
opt_var_clone_detach.append({
'params': [
nn.Parameter(
get_pose_opt(start_frame, end_frame))
],
'lr':
0.005,
'name': ['body_pose']
})
return opt_var_clone_detach
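                # each window rebuilds its optimizer from detached clones of the
                # stage-1 parameters, so no gradient state leaks across windows;
                # learning rates are set per group above (0.025 expression,
                # 0.001 eye/jaw poses, 0.01 lighting SH, 0.005 texture and pose)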
                save_training_img_dir = tracker_cfg.mica_process_path + f'_{start_frame}_{end_frame}'
                if save_training_img:
                    Path(save_training_img_dir).mkdir(parents=True,
                                                      exist_ok=True)
with tqdm(total=iters * 3,
position=0,
leave=True,
bar_format="{percentage:3.0f}%|{bar}{r_bar}{desc}"
) as pbar:
for k, scale in enumerate(sampling):
size = [int(512 * scale), int(512 * scale)]
img = F.interpolate(images.float().clone(),
size,
mode='bilinear',
align_corners=False)
if k > 0:
img = gaussian_blur(img, [9, 9]).detach()
flipped = torch.flip(img, [2, 3])
flipped = flipped[valid_bool.bool(), ...]
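                        # one pyramid level per entry of `sampling`: the 512x512
                        # crops are rescaled to `size`, blurred at every level
                        # after the first, and flipped along H and W (presumably
                        # to match the renderer's image-coordinate convention)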
best_loss = np.inf
prev_loss = np.inf
xb_min, xb_max, yb_min, yb_max = bbox.values()
box_w = xb_max - xb_min
box_h = yb_max - yb_min
box_w = box_w.int()
box_h = box_h.int()
image_size = size
diff_renderer.rasterizer.reset()
diff_renderer.set_size(image_size)
debug_renderer.rasterizer.raster_settings.image_size = size
image_lmks = landmarks * size[0] / 512
image_lmks = image_lmks[valid_bool.bool(), ...]
optimizer = torch.optim.Adam(
clone_params_color(start_frame, end_frame))
params = optimizer.param_groups
get_param = SHOW.utils.get_param
cur_tex = get_param('tex', params)
cur_sh = get_param('sh', params)
cur_exp = get_param('exp', params)
cur_leyes = get_param('leyes', params)
cur_reyes = get_param('reyes', params)
cur_jaw = get_param('jaw', params)
if use_opt_pose:
two_opt = get_param('body_pose', params)
frame_pose = body_params_dict['body_pose_axis'][
start_frame:end_frame]
bs = frame_pose.shape[0]
frame_pose = frame_pose.reshape(bs, -1, 3)
cur_pose = torch.cat(
[
frame_pose[:, :11, :],
two_opt[:, 0:1], #11
frame_pose[:, 12:14, :],
two_opt[:, 1:2], #14
frame_pose[:, 15:, :]
],
dim=1).reshape(bs, 1, -1)
else:
frame_pose = body_params_dict['body_pose_axis'][
start_frame:end_frame]
bs = frame_pose.shape[0]
cur_pose = frame_pose.reshape(bs, 1, -1)
cur_transl = body_params_dict['transl'][
start_frame:end_frame]
cur_global_orient = body_params_dict['global_orient'][
start_frame:end_frame]
cur_left_hand_pose = body_params_dict[
'left_hand_pose'][start_frame:end_frame]
cur_right_hand_pose = body_params_dict[
'right_hand_pose'][start_frame:end_frame]
R = torch.Tensor([[[-1, 0, 0], [0, -1, 0], [0, 0, 1]]])
bs_image_size = torch.Tensor(image_size).repeat(
opt_bs, 1).to(device)
bs_camera_transl = (camera_transl).repeat(opt_bs,
1).to(device)
                        # bs_center: (opt_bs, 2)
bs_center = torch.Tensor(center).repeat(opt_bs,
1).to(device)
bs_box_min = torch.stack([xb_min, yb_min],
dim=-1).to(device)
bs_R = R.repeat(opt_bs, 1, 1).to(device)
s_w = size[0] / box_w
s_h = size[1] / box_h
s_1to2 = torch.stack([s_w, s_h], dim=1)
s_1to2 = s_1to2.to(device)
bs_pp = (bs_center - bs_box_min) * (s_1to2)
bs_pp[:, 0] = (bs_pp[:, 0] + px) * 512 / (w + 2 * px)
bs_pp[:, 1] = (bs_pp[:, 1] + py) * 512 / (h + 2 * py)
s_1to2[:, 0] = s_1to2[:, 0] * 512 / (w + 2 * px)
s_1to2[:, 1] = s_1to2[:, 1] * 512 / (h + 2 * py)
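                        # the intrinsics are now mapped into the padded 512x512
                        # crop space: principal point shifted by the crop origin
                        # and rescaled, focal length rescaled per axis; together
                        # with the 180-degree z-rotation R this defines the
                        # screen-space cameras below (in_ndc=False)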
cam_cfg = dict(
principal_point=bs_pp,
focal_length=focal_length * s_1to2,
R=bs_R,
T=bs_camera_transl,
image_size=bs_image_size,
)
for key, val in cam_cfg.items():
cam_cfg[key] = val[valid_bool.bool(), ...]
cameras = PerspectiveCameras(**cam_cfg,
device=device,
in_ndc=False)
if True:
for p in range(iters): # p=0
if (p + 1) % 32 == 0:
diff_renderer.rasterizer.reset()
losses = {}
model_output = body_model(
return_verts=True,
jaw_pose=cur_jaw,
leye_pose=cur_leyes,
reye_pose=cur_reyes,
expression=cur_exp,
betas=body_params_dict['betas'],
transl=cur_transl,
body_pose=cur_pose,
global_orient=cur_global_orient,
left_hand_pose=cur_left_hand_pose,
right_hand_pose=cur_right_hand_pose,
)
vertices = model_output.vertices[:,
smplx2flame_idx
.long(), :]
vertices = vertices[valid_bool.bool(), ...]
lmk68_all = model_output.joints[:, 67:67 + 51 +
17, :]
lmk68 = lmk68_all[valid_bool.bool(), ...]
proj_lmks = cameras.transform_points_screen(
lmk68)[:, :, :2]
proj_lmks = torch.cat([
proj_lmks[:, -17:, :],
proj_lmks[:, :-17, :]
],
dim=1)
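                                # joints 67:135 appear to hold 51 inner-face
                                # landmarks followed by 17 contour points; the
                                # concat moves the contour to the front to match
                                # the usual 68-point ordering (17 + 51)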
I = torch.eye(3)[None].to(device)
if pre_frame_exp is not None and start_frame != 0:
losses['pre_exp'] = 0.001 * torch.sum(
(pre_frame_exp - cur_exp[0])**2)
if False:
linear_rot_left = (axis_angle_to_matrix(
cur_leyes[valid_bool.bool(), ...]))
linear_rot_right = (axis_angle_to_matrix(
cur_reyes[valid_bool.bool(), ...]))
losses['eyes_sym_reg'] = torch.sum(
(linear_rot_right - linear_rot_left)**
2) / opt_bs
losses['eyes_left_reg'] = torch.sum(
(I - linear_rot_left)**2) / opt_bs
losses['eyes_right_reg'] = torch.sum(
(I - linear_rot_right)**2) / opt_bs
w_lmks = tracker_cfg.w_lmks
losses['lmk'] = SHOW.utils.lmk_loss(
proj_lmks, image_lmks,
image_size) * w_lmks * 8.0
                                losses[
                                    'lmk_mouth'] = SHOW.utils.mouth_loss(  # mouth lmks (49..68)
                                        proj_lmks, image_lmks,
                                        image_size) * w_lmks * 4.0 * 4
losses['lmk_oval'] = SHOW.utils.lmk_loss(
proj_lmks[:, :17, ...], image_lmks[:, :17,
...],
image_size) * w_lmks
losses['jaw_reg'] = torch.sum(
(I - axis_angle_to_matrix(
cur_jaw[valid_bool.bool(), ...]))**
2) * 16.0 / opt_bs
losses['exp_reg'] = torch.sum(
cur_exp[valid_bool.bool(),
...]**2) * 0.01 / opt_bs
if use_shared_tex:
losses['tex_reg'] = torch.sum(cur_tex**
2) * 0.02
else:
losses['tex_reg'] = torch.sum(
cur_tex[valid_bool.bool(),
...]**2) * 0.02 / opt_bs
                                def temporal_smooth_loss(o_w, i_w, gmof, param):
                                    # second-order finite difference over frames:
                                    # penalizes acceleration of the parameter track
                                    assert param.shape[0] > 2, \
                                        'batch size must be > 2 for temporal smoothing'
                                    return (o_w**2) * (gmof(
                                        i_w * (param[2:, ...] + param[:-2, ...] -
                                               2 * param[1:-1, ...]))).mean()

                                def square(x):
                                    return x.pow(2)
                                if cur_exp.shape[0] > 2:
                                    losses['loss_sexp'] = temporal_smooth_loss(
                                        1.0, 2.0, square, cur_exp)
                                    losses['loss_sjaw'] = temporal_smooth_loss(
                                        1.0, 2.0, square, cur_jaw)
def k_fun(k):
return tracker_cfg.w_pho * 32.0 if k > 0 else tracker_cfg.w_pho
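                                # the photometric weight ramps up (x32) on every
                                # pyramid level after the first, presumably once
                                # the landmark terms have coarsely aligned the face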
albedos = flametex(cur_tex) / 255.
if use_shared_tex:
albedos = albedos.expand(
valid_bs, -1, -1, -1)
else:
albedos = albedos[valid_bool.bool(), ...]
ops = diff_renderer(
vertices, albedos,
cur_sh[valid_bool.bool(), ...], cameras)
grid = ops['position_images'].permute(
0, 2, 3, 1)[:, :, :, :2]
sampled_image = F.grid_sample(
flipped, grid, align_corners=False)
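                                # ops['position_images'] holds per-pixel sampling
                                # coordinates of the rendered mesh; grid-sampling
                                # the flipped input there gives a photo reference
                                # aligned with the render for the masked 'pho' loss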
ops_mask = SHOW.utils.parse_mask(ops)
tmp_img = ops['images']
losses['pho'] = SHOW.utils.pixel_loss(
tmp_img, sampled_image,
ops_mask) * k_fun(k)
all_loss = 0.
for key in losses.keys():
all_loss = all_loss + losses[key]
losses['all_loss'] = all_loss
log_str = SHOW.print_dict_losses(losses)
if report_wandb:
if globals().get('wandb', None) is None:
os.environ[
'WANDB_API_KEY'] = 'xxx'
os.environ['WANDB_NAME'] = 'tracker'
import wandb
wandb.init(
reinit=True,
resume='allow',
project='tracker',
)
globals()['wandb'] = wandb
if globals().get('wandb',
None) is not None:
globals()['wandb'].log(losses)
                                if save_training_img:
def save_callback(frame, final_views):
cur_idx = (frame + opt_bs * opt_idx)
if cur_idx in observe_idx_list:
                                            observe_idx_frame_dir = os.path.join(
                                                save_training_img_dir,
                                                f'{cur_idx:03d}')
Path(observe_idx_frame_dir).mkdir(
parents=True, exist_ok=True)
cv2.imwrite(
os.path.join(
observe_idx_frame_dir,
f'{k}_{p}.jpg'),
final_views)
save_tracker(
img,
valid_bool,
valid_bs,
ops,
vertices,
cameras,
image_lmks,
proj_lmks,
flame_faces,
mesh_rasterizer,
debug_renderer,
save_callback,
)
if loggers is not None:
loggers.log_bs(losses)
                                if torch.isnan(all_loss).sum():
                                    # guard: loggers may be None if none was passed in
                                    if loggers is not None:
                                        loggers.alert(
                                            title='Nan error',
                                            msg=
                                            f'tracker nan in: {tracker_cfg.ours_output_folder}'
                                        )
                                    open(
                                        tracker_cfg.ours_output_folder +
                                        '/mica_opt_nan.info', 'a').close()
                                    break
else:
pbar.set_description(log_str)
pbar.update(1)
optimizer.zero_grad()
all_loss.backward()
optimizer.step()
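                                # track the best iterate: cache texture/SH for the
                                # next window and write the refined face parameters
                                # back into body_params_dict in place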
if all_loss.item() < best_loss:
best_loss = all_loss.item()
opt_bs_tex = cur_tex.clone().detach()
opt_bs_sh = cur_sh.clone().detach()
body_params_dict['expression'][
start_frame:end_frame] = cur_exp.clone(
).detach()
body_params_dict['leye_pose'][
start_frame:
end_frame] = cur_leyes.clone().detach(
)
body_params_dict['reye_pose'][
start_frame:
end_frame] = cur_reyes.clone().detach(
)
body_params_dict['jaw_pose'][
start_frame:end_frame] = cur_jaw.clone(
).detach()
body_params_dict['body_pose_axis'][
start_frame:
end_frame] = cur_pose.clone().detach(
).squeeze()
with default_timers['saving_stage']:
            if save_training_img:
                for idx in observe_idx_list:
                    observe_idx_frame_dir = os.path.join(
                        save_training_img_dir, f'{idx:03d}')
Path(observe_idx_frame_dir).mkdir(parents=True,
exist_ok=True)
if not SHOW.is_empty_dir(observe_idx_frame_dir):
images_to_video(
input_folder=observe_idx_frame_dir,
output_path=observe_idx_frame_dir + '.mp4',
img_format=None,
fps=30,
)
dict_to_save = dict(
expression=body_params_dict['expression']
[start_frame:end_frame].clone().detach().cpu().numpy(),
leye_pose=body_params_dict['leye_pose']
[start_frame:end_frame].clone().detach().cpu().numpy(),
reye_pose=body_params_dict['reye_pose']
[start_frame:end_frame].clone().detach().cpu().numpy(),
jaw_pose=body_params_dict['jaw_pose']
[start_frame:end_frame].clone().detach().cpu().numpy(),
body_pose_axis=body_params_dict['body_pose_axis']
[start_frame:end_frame].clone().detach().cpu().numpy(),
tex=opt_bs_tex.clone().detach().cpu().numpy(),
sh=opt_bs_sh.clone().detach().cpu().numpy(),
)
mmcv.dump(dict_to_save, mica_part_pkl_path)
logger.info(f'mica pkl part path: {mica_part_pkl_path}')
pre_frame_exp = dict_to_save['expression'][-1]
pre_frame_exp = torch.Tensor(pre_frame_exp).to(device)
vertices_ = model_output.vertices.clone().detach().cpu()
logger.info(
f'mica render to origin path: {tracker_cfg.mica_org_out_path}')
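                # pyrender chooses its GL backend from PYOPENGL_PLATFORM at
                # import time, so EGL (headless) must be selected before
                # `import pyrender` on Linux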
import platform
if platform.system() == "Linux":
os.environ['PYOPENGL_PLATFORM'] = 'egl'
else:
                    if 'PYOPENGL_PLATFORM' in os.environ:
                        del os.environ['PYOPENGL_PLATFORM']
import pyrender
input_renderer = pyrender.OffscreenRenderer(viewport_width=width,
viewport_height=height,
point_size=1.0)
for idx in tqdm(range(vertices_.shape[0]),
                    desc='saving final pyrender images'):
cur_idx = idx + start_frame + 1
input_img = SHOW.find_full_impath_by_name(
root=tracker_cfg.img_folder, name=f'{cur_idx:06d}')
output_name = os.path.join(
tracker_cfg.mica_org_out_path,
f"{cur_idx:06}.{tracker_cfg.output_img_ext}")
camera_pose = op.get_smplx_to_pyrender_K(camera_transl)
meta_data = dict(
input_img=input_img,
output_name=output_name,
)
save_one_results(
vertices_[idx],
body_model.faces,
img_size=(height, width),
center=center,
focal_length=[focal_length, focal_length],
camera_pose=camera_pose,
meta_data=meta_data,
input_renderer=input_renderer,
)
input_renderer.delete()
if tracker_cfg.save_final_vis:
def save_callback(frame, final_views):
cur_idx = (frame + opt_bs * opt_idx)
if loggers is not None:
loggers.log_image(f"final_mica_img/{cur_idx:03d}",
final_views / 255.0)
cv2.imwrite(
os.path.join(tracker_cfg.mica_save_path,
f'{cur_idx:03d}.jpg'), final_views)
                    save_tracker(
                        img,
                        valid_bool,
                        valid_bs,
                        ops,
                        vertices,
                        cameras,
                        image_lmks,
                        proj_lmks,
                        flame_faces,
                        mesh_rasterizer,
                        debug_renderer,
                        save_callback,
                    )
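    # finally, merge the per-window MICA refinements back into a copy of the
    # stage-1 pkl (read downstream via tracker_cfg.mica_merge_pkl) and render
    # the image folders to videos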
load_data = mmcv.load(tracker_cfg.ours_pkl_file_path)[0]
load_data = SHOW.replace_mica_exp(tracker_cfg.mica_all_dir, load_data)
mmcv.dump([load_data], tracker_cfg.mica_merge_pkl)
if not Path(tracker_cfg.mica_org_out_video).exists():
if not SHOW.is_empty_dir(tracker_cfg.mica_org_out_path):
images_to_video(
input_folder=tracker_cfg.mica_org_out_path,
output_path=tracker_cfg.mica_org_out_video,
img_format=None,
fps=30,
)
if not Path(tracker_cfg.mica_grid_video).exists():
if not SHOW.is_empty_dir(tracker_cfg.mica_save_path):
images_to_video(
input_folder=tracker_cfg.mica_save_path,
output_path=tracker_cfg.mica_grid_video,
img_format=None,
fps=30,
)