|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pathlib import Path |
|
import numpy as np |
|
from tqdm import tqdm |
|
import mmcv |
|
import numpy as np |
|
import os |
|
from tqdm import tqdm |
|
import cv2 |
|
import os.path |
|
from functools import reduce |
|
from pathlib import Path |
|
from loguru import logger |
|
import face_alignment |
|
import mmcv |
|
from pathlib import Path |
|
import numpy as np |
|
from tqdm import tqdm |
|
import mmcv |
|
import numpy as np |
|
import os |
|
import os.path as osp |
|
from tqdm import tqdm |
|
import cv2 |
|
import glob |
|
import os.path |
|
from functools import reduce |
|
from pathlib import Path |
|
from loguru import logger |
|
import face_alignment |
|
import mmcv |
|
|
|
from SHOW.utils.video import images_to_video |
|
from torchvision.transforms.functional import gaussian_blur |
|
from pytorch3d.transforms import axis_angle_to_matrix |
|
from pytorch3d.renderer import RasterizationSettings, PointLights, MeshRenderer, MeshRasterizer, TexturesVertex, SoftPhongShader, look_at_view_transform, PerspectiveCameras |
|
from pytorch3d.transforms import axis_angle_to_matrix |
|
from pytorch3d.io import load_obj |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import DataLoader |
|
import torch.backends.cudnn |
|
import torch.nn.functional as F |
|
|
|
import SHOW |
|
from SHOW.utils import default_timers |
|
from SHOW.datasets import op_base |
|
from SHOW.detector.face_detector import FaceDetector |
|
from SHOW.load_models import load_smplx_model, load_vposer_model |
|
from SHOW.save_results import save_one_results |
|
from SHOW.load_models import load_save_pkl |
|
from SHOW.flame.FLAME import FLAMETex |
|
from SHOW.smplx_dataset import ImagesDataset |
|
from SHOW.renderer import Renderer |
|
from SHOW.load_assets import load_assets |
|
from SHOW.loggers.logger import setup_logger |
|
from SHOW.save_tracker import save_tracker |
|
from SHOW.utils import is_valid_json |
|
from configs.cfg_ins import condor_cfg |
|
|
|
|
|
def _split_frame_range(total_frames, chunk_size):
    """Split ``[0, total_frames)`` into consecutive half-open ``(start, end)`` chunks.

    Every frame is covered exactly once; only the last chunk may be shorter
    than ``chunk_size``.

    Args:
        total_frames: number of frames to cover (values <= 0 yield no chunks).
        chunk_size: desired frames per chunk; values below 1 are treated as 1.

    Returns:
        List of ``(start, end)`` int tuples, suitable for ``x[start:end]``.
    """
    total_frames = max(0, int(total_frames))
    chunk_size = max(1, int(chunk_size))
    return [(start, min(start + chunk_size, total_frames))
            for start in range(0, total_frames, chunk_size)]


def _square(x):
    """Element-wise square; the penalty used by the temporal smoothness loss."""
    return x.pow(2)


def _temporal_smooth_loss(o_w, i_w, penalty, param):
    """Second-order temporal smoothness loss along the first (frame) axis.

    Penalises the discrete second difference ``p[t+1] + p[t-1] - 2 * p[t]``
    between consecutive frames.

    Args:
        o_w: outer weight (applied squared).
        i_w: inner weight applied before ``penalty``.
        penalty: element-wise penalty function.
        param: tensor whose first axis indexes frames.

    Raises:
        ValueError: if ``param`` has fewer than three frames.
    """
    if param.shape[0] <= 2:
        raise ValueError('optimize batchsize must > 2 to enable temporary smooth')
    second_diff = param[2:, ...] + param[:-2, ...] - 2 * param[1:-1, ...]
    return (o_w ** 2) * penalty(i_w * second_diff).mean()


@logger.catch
def SHOW_stage2(*args, **kwargs):
    """Stage 2 of SHOW: per-chunk refinement of face parameters.

    Loads the stage-1 SMPL-X fit (``tracker_cfg.ours_pkl_file_path``) and splits
    its frames into chunks of ``tracker_cfg.bs_at_a_time`` frames. For every
    chunk and every scale in ``tracker_cfg.sampling``, it jointly optimises:

    * expression, eye and jaw poses;
    * optionally two body-pose joints (indices 11 and 14, presumably neck/head);
    * a FLAME texture code;
    * per-frame lighting coefficients (``sh``, 9x3).

    The objective combines 2D face-landmark losses, a photometric loss and
    regularisers.

    Each chunk's result is dumped to ``tracker_cfg.mica_all_dir``; an existing
    chunk file is reused (when a reference face embedding exists), so an
    interrupted run resumes. The chunk is also rendered over the input frames.
    Finally, all chunks are merged into ``tracker_cfg.mica_merge_pkl`` and
    result videos are written.

    Args:
        *args: ignored.
        **kwargs: overrides merged into the tracker config. ``loggers``
            (optional) is an experiment logger exposing ``log_bs``, ``alert``
            and ``log_image``.

    Returns:
        ``False`` when a required input is missing or invalid, else ``None``.
    """
    import platform

    machine_info = SHOW.get_machine_info()
    logger.info(f'machine_info: {machine_info}')

    loggers = kwargs.get('loggers', None)

    tracker_cfg = SHOW.from_rela_path(__file__, './configs/mmcv_tracker_config.py')
    tracker_cfg.update(**kwargs)
    tracker_cfg.merge_from_dict(condor_cfg)
    if tracker_cfg.get('over_write_cfg', None):
        tracker_cfg.update(tracker_cfg.over_write_cfg)

    mmcv.dump(tracker_cfg, tracker_cfg.tracker_cfg_path)

    # Best effort: on Linux, scale the chunk size with total GPU memory
    # (50 frames per 80*1024 units of ``gpu_Total``, presumably MB).
    # Any failure keeps the configured value.
    try:
        gpu_mem = machine_info['gpu_info']['gpu_Total']
        if platform.system() == 'Linux':
            # Clamp to >= 1: a small GPU would otherwise round down to 0.
            tracker_cfg.bs_at_a_time = max(1, int(50.0 * gpu_mem / (80.0 * 1024)))
            logger.warning(f'bs_at_a_time: {tracker_cfg.bs_at_a_time}')
    except Exception:
        logger.exception('could not derive bs_at_a_time from GPU memory; '
                         'keeping the configured value')

    Path(tracker_cfg.mica_save_path).mkdir(exist_ok=True, parents=True)
    Path(tracker_cfg.mica_org_out_path).mkdir(exist_ok=True, parents=True)

    iters = tracker_cfg.iters
    sampling = tracker_cfg.sampling
    device = tracker_cfg.device
    tracker_cfg.dtype = dtype = SHOW.str_to_torch_dtype(tracker_cfg.dtype)

    face_ider = SHOW.build_ider(tracker_cfg.ider_cfg)
    img_folder = tracker_cfg.img_folder
    # Sort so the template image is deterministic (os.listdir order is arbitrary).
    img_names = sorted(os.listdir(img_folder))
    if not img_names:
        logger.warning(f'img_folder is empty: {img_folder}')
        return False
    template_im = os.path.join(img_folder, img_names[0])

    assets = load_assets(
        tracker_cfg,
        face_ider=face_ider,
        template_im=template_im,
    )
    if assets is None:
        return False

    setup_logger(tracker_cfg.mica_all_dir, filename='mica.log', mode='o')

    if not Path(tracker_cfg.ours_pkl_file_path).exists():
        logger.warning(
            f'ours_pkl_file_path not exists: {tracker_cfg.ours_pkl_file_path}')
        return False

    if not is_valid_json(tracker_cfg.final_losses_json_path):
        logger.warning(
            f'final_losses_json_path not valid: {tracker_cfg.final_losses_json_path}')
        return False

    with default_timers['build_vars_stage']:
        # Reuse the face-ID model built above instead of loading it twice.
        person_face_emb = assets.person_face_emb
        face_detector_mediapipe = FaceDetector('google', device=device)
        face_detector = face_alignment.FaceAlignment(
            face_alignment.LandmarksType._2D, device=device)

        body_model = load_smplx_model(dtype=dtype, **tracker_cfg.smplx_cfg)
        body_params_dict = load_save_pkl(tracker_cfg.ours_pkl_file_path, device)

        width = body_params_dict['width']
        height = body_params_dict['height']
        center = body_params_dict['center']
        camera_transl = body_params_dict['camera_transl']
        focal_length = body_params_dict['focal_length']
        total_batch_size = body_params_dict['batch_size']

        # Cover every frame, including a shorter trailing chunk.
        st_et_list = _split_frame_range(total_batch_size, tracker_cfg.bs_at_a_time)
        opt_iters = len(st_et_list)

        op = op_base()
        smplx2flame_idx = assets.smplx2flame_idx
        flame_vert_idx = smplx2flame_idx.long()

        mesh_file = Path(__file__).parent.joinpath('../data/head_template_mesh.obj')
        diff_renderer = Renderer(torch.Tensor([[512, 512]]), obj_filename=mesh_file)
        flame_faces = load_obj(mesh_file)[1]
        flametex = FLAMETex(tracker_cfg.flame_cfg).to(device)

        mesh_rasterizer = MeshRasterizer(
            raster_settings=RasterizationSettings(image_size=[512, 512],
                                                  faces_per_pixel=1,
                                                  cull_backfaces=True,
                                                  perspective_correct=True))
        debug_renderer = MeshRenderer(
            rasterizer=mesh_rasterizer,
            shader=SoftPhongShader(device=device,
                                   lights=PointLights(
                                       device=device,
                                       location=((0.0, 0.0, -5.0), ),
                                       ambient_color=((0.5, 0.5, 0.5), ),
                                       diffuse_color=((0.5, 0.5, 0.5), ))))

    # Developer toggles (constants for this run).
    use_shared_tex = True    # one texture code shared by all frames of a chunk
    use_opt_pose = True      # also optimise body-pose joints 11 and 14
    use_eye_reg = False      # eye-rotation regularisers (disabled)
    save_traing_img = False  # dump per-iteration renders for observe_idx_list
    report_wandb = False
    observe_idx_list = [4, 8]

    eye3 = torch.eye(3)[None].to(device)
    get_param = SHOW.utils.get_param

    def get_pose_opt(start_frame, end_frame):
        """Return a detached copy of body-pose joints 11 and 14, shape (bs, 2, 3)."""
        tmp = body_params_dict['body_pose_axis'][start_frame:end_frame, ...].clone().detach()
        tmp = tmp.reshape(tmp.shape[0], -1, 3)
        return torch.stack([tmp[:, 12 - 1, :], tmp[:, 15 - 1, :]], dim=1)

    def compose_pose(start_frame, end_frame, two_opt=None):
        """Body pose for the chunk as (bs, 1, J*3).

        When ``two_opt`` is given, joints 11 and 14 come from the optimised
        parameter. Rebuilding this on every use is required: ``torch.cat``
        copies, so a tensor built once would never see optimizer updates.
        """
        frame_pose = body_params_dict['body_pose_axis'][start_frame:end_frame]
        bs = frame_pose.shape[0]
        if two_opt is None:
            return frame_pose.reshape(bs, 1, -1)
        frame_pose = frame_pose.reshape(bs, -1, 3)
        return torch.cat([
            frame_pose[:, :11, :],
            two_opt[:, 0:1],
            frame_pose[:, 12:14, :],
            two_opt[:, 1:2],
            frame_pose[:, 15:, :],
        ], dim=1).reshape(bs, 1, -1)

    def clone_params_color(start_frame, end_frame):
        """Fresh Adam param groups, initialised from the current best values.

        Reads ``opt_bs_sh`` / ``opt_bs_tex`` from the enclosing scope at call time.
        """
        def group(tensor, lr, name):
            return {'params': [nn.Parameter(tensor.clone().detach())],
                    'lr': lr,
                    'name': [name]}

        groups = [
            group(body_params_dict['expression'][start_frame:end_frame], 0.025, 'exp'),
            group(body_params_dict['leye_pose'][start_frame:end_frame], 0.001, 'leyes'),
            group(body_params_dict['reye_pose'][start_frame:end_frame], 0.001, 'reyes'),
            group(body_params_dict['jaw_pose'][start_frame:end_frame], 0.001, 'jaw'),
            group(opt_bs_sh, 0.01, 'sh'),
            group(opt_bs_tex, 0.005, 'tex'),
        ]
        if use_opt_pose:
            groups.append(group(get_pose_opt(start_frame, end_frame), 0.005, 'body_pose'))
        return groups

    pre_frame_exp = None
    for opt_idx, (start_frame, end_frame) in enumerate(st_et_list):
        # Always defined: needed for the '.empty' marker and the final dump,
        # even when no reference face embedding is available.
        mica_part_file_path = f'w_mica_part_{start_frame}_{end_frame}_{opt_idx}_{opt_iters}.pkl'
        mica_part_pkl_path = os.path.join(tracker_cfg.mica_all_dir, mica_part_file_path)

        if assets.person_face_emb is not None and Path(mica_part_pkl_path).exists():
            logger.info(f'mica_part_pkl_path exists,skipping: {mica_part_pkl_path}')
            pre_con = mmcv.load(mica_part_pkl_path)
            # Carry the last expression forward for the next chunk's 'pre_exp' loss.
            pre_frame_exp = torch.Tensor(pre_con['expression'][-1]).to(device)
            continue

        opt_bs = end_frame - start_frame

        com_tex = torch.zeros(1, 150, device=device)
        com_sh = torch.zeros(1, 9, 3, device=device)
        opt_bs_tex = com_tex.expand(1 if use_shared_tex else opt_bs, -1)
        opt_bs_sh = com_sh.expand(opt_bs, -1, -1)

        logger.info(f'origin input data frame batchsize:{opt_bs}')

        with default_timers['load_dataset_stage']:
            dataset = ImagesDataset(
                tracker_cfg,
                start_frame=start_frame,
                face_ider=face_ider,
                person_face_emb=person_face_emb,
                face_detector_mediapipe=face_detector_mediapipe,
                face_detector=face_detector)
            dataloader = DataLoader(dataset,
                                    batch_size=opt_bs,
                                    num_workers=0,
                                    shuffle=False,
                                    pin_memory=True,
                                    drop_last=False)
            batch = SHOW.utils.to_cuda(next(iter(dataloader)))
            valid_bool = batch['is_person_deted'].bool()
            valid_bs = valid_bool.count_nonzero()
            logger.info(f'valid input data frame batchsize:{valid_bs}')
            logger.info(f'valid_bool: {valid_bool}')

            if valid_bs == 0:
                logger.warning('valid bs == 0, skipping')
                Path(mica_part_pkl_path + '.empty').touch()
                continue

            bbox = batch['bbox']
            images = batch['cropped_image']
            landmarks = batch['cropped_lmk']
            h = batch['h']
            w = batch['w']
            py = batch['py']
            px = batch['px']

            diff_renderer.masking.set_bs(valid_bs)
            diff_renderer = diff_renderer.to(device)

        with default_timers['optimize_stage']:
            model_output = None
            save_traing_img_dir = tracker_cfg.mica_process_path + f'_{start_frame}_{end_frame}'
            if save_traing_img:
                Path(save_traing_img_dir).mkdir(parents=True, exist_ok=True)

            with tqdm(total=iters * len(sampling),
                      position=0,
                      leave=True,
                      bar_format="{percentage:3.0f}%|{bar}{r_bar}{desc}") as pbar:
                for k, scale in enumerate(sampling):
                    size = [int(512 * scale), int(512 * scale)]
                    img = F.interpolate(images.float().clone(),
                                        size,
                                        mode='bilinear',
                                        align_corners=False)
                    if k > 0:
                        img = gaussian_blur(img, [9, 9]).detach()
                    flipped = torch.flip(img, [2, 3])[valid_bool, ...]

                    best_loss = np.inf

                    xb_min, xb_max, yb_min, yb_max = bbox.values()
                    box_w = (xb_max - xb_min).int()
                    box_h = (yb_max - yb_min).int()

                    diff_renderer.rasterizer.reset()
                    diff_renderer.set_size(size)
                    debug_renderer.rasterizer.raster_settings.image_size = size

                    # Presumably landmarks are in 512-px crop coordinates; rescale to this level.
                    image_lmks = (landmarks * size[0] / 512)[valid_bool, ...]

                    optimizer = torch.optim.Adam(clone_params_color(start_frame, end_frame))
                    params = optimizer.param_groups
                    cur_tex = get_param('tex', params)
                    cur_sh = get_param('sh', params)
                    cur_exp = get_param('exp', params)
                    cur_leyes = get_param('leyes', params)
                    cur_reyes = get_param('reyes', params)
                    cur_jaw = get_param('jaw', params)
                    two_opt = get_param('body_pose', params) if use_opt_pose else None

                    cur_transl = body_params_dict['transl'][start_frame:end_frame]
                    cur_global_orient = body_params_dict['global_orient'][start_frame:end_frame]
                    cur_left_hand_pose = body_params_dict['left_hand_pose'][start_frame:end_frame]
                    cur_right_hand_pose = body_params_dict['right_hand_pose'][start_frame:end_frame]

                    # Camera: the principal point and focal length appear to be mapped
                    # from the full image into the padded 512x512 face crop
                    # (box scale, then padding px/py).
                    R = torch.Tensor([[[-1, 0, 0], [0, -1, 0], [0, 0, 1]]])
                    bs_image_size = torch.Tensor(size).repeat(opt_bs, 1).to(device)
                    bs_camera_transl = camera_transl.repeat(opt_bs, 1).to(device)
                    bs_center = torch.Tensor(center).repeat(opt_bs, 1).to(device)
                    bs_box_min = torch.stack([xb_min, yb_min], dim=-1).to(device)
                    bs_R = R.repeat(opt_bs, 1, 1).to(device)

                    s_w = size[0] / box_w
                    s_h = size[1] / box_h
                    s_1to2 = torch.stack([s_w, s_h], dim=1).to(device)
                    bs_pp = (bs_center - bs_box_min) * s_1to2

                    bs_pp[:, 0] = (bs_pp[:, 0] + px) * 512 / (w + 2 * px)
                    bs_pp[:, 1] = (bs_pp[:, 1] + py) * 512 / (h + 2 * py)
                    s_1to2[:, 0] = s_1to2[:, 0] * 512 / (w + 2 * px)
                    s_1to2[:, 1] = s_1to2[:, 1] * 512 / (h + 2 * py)

                    cam_cfg = dict(
                        principal_point=bs_pp,
                        focal_length=focal_length * s_1to2,
                        R=bs_R,
                        T=bs_camera_transl,
                        image_size=bs_image_size,
                    )
                    cam_cfg = {key: val[valid_bool, ...] for key, val in cam_cfg.items()}
                    cameras = PerspectiveCameras(**cam_cfg, device=device, in_ndc=False)

                    w_pho = tracker_cfg.w_pho * 32.0 if k > 0 else tracker_cfg.w_pho
                    w_lmks = tracker_cfg.w_lmks

                    for p in range(iters):
                        if (p + 1) % 32 == 0:
                            diff_renderer.rasterizer.reset()
                        losses = {}

                        # Rebuilt every iteration so optimizer updates to two_opt take effect.
                        cur_pose = compose_pose(start_frame, end_frame, two_opt)

                        model_output = body_model(
                            return_verts=True,
                            jaw_pose=cur_jaw,
                            leye_pose=cur_leyes,
                            reye_pose=cur_reyes,
                            expression=cur_exp,
                            betas=body_params_dict['betas'],
                            transl=cur_transl,
                            body_pose=cur_pose,
                            global_orient=cur_global_orient,
                            left_hand_pose=cur_left_hand_pose,
                            right_hand_pose=cur_right_hand_pose,
                        )

                        vertices = model_output.vertices[:, flame_vert_idx, :][valid_bool, ...]
                        lmk68 = model_output.joints[:, 67:67 + 51 + 17, :][valid_bool, ...]
                        proj_lmks = cameras.transform_points_screen(lmk68)[:, :, :2]
                        # Move the last 17 points to the front (presumably the face
                        # contour, to match the 68-point landmark ordering).
                        proj_lmks = torch.cat([proj_lmks[:, -17:, :], proj_lmks[:, :-17, :]], dim=1)

                        if pre_frame_exp is not None and start_frame != 0:
                            # Temporal continuity with the previous chunk's last frame.
                            losses['pre_exp'] = 0.001 * torch.sum((pre_frame_exp - cur_exp[0]) ** 2)

                        if use_eye_reg:
                            linear_rot_left = axis_angle_to_matrix(cur_leyes[valid_bool, ...])
                            linear_rot_right = axis_angle_to_matrix(cur_reyes[valid_bool, ...])
                            losses['eyes_sym_reg'] = torch.sum((linear_rot_right - linear_rot_left) ** 2) / opt_bs
                            losses['eyes_left_reg'] = torch.sum((eye3 - linear_rot_left) ** 2) / opt_bs
                            losses['eyes_right_reg'] = torch.sum((eye3 - linear_rot_right) ** 2) / opt_bs

                        losses['lmk'] = SHOW.utils.lmk_loss(
                            proj_lmks, image_lmks, size) * w_lmks * 8.0
                        losses['lmk_mount'] = SHOW.utils.mouth_loss(
                            proj_lmks, image_lmks, size) * w_lmks * 4.0 * 4
                        losses['lmk_oval'] = SHOW.utils.lmk_loss(
                            proj_lmks[:, :17, ...], image_lmks[:, :17, ...], size) * w_lmks

                        losses['jaw_reg'] = torch.sum(
                            (eye3 - axis_angle_to_matrix(cur_jaw[valid_bool, ...])) ** 2) * 16.0 / opt_bs
                        losses['exp_reg'] = torch.sum(cur_exp[valid_bool, ...] ** 2) * 0.01 / opt_bs

                        if use_shared_tex:
                            losses['tex_reg'] = torch.sum(cur_tex ** 2) * 0.02
                        else:
                            losses['tex_reg'] = torch.sum(cur_tex[valid_bool, ...] ** 2) * 0.02 / opt_bs

                        if cur_exp.shape[0] > 2:
                            losses['loss_sexp'] = _temporal_smooth_loss(1.0, 2.0, _square, cur_exp)
                            losses['loss_sjaw'] = _temporal_smooth_loss(1.0, 2.0, _square, cur_jaw)

                        albedos = flametex(cur_tex) / 255.
                        if use_shared_tex:
                            albedos = albedos.expand(valid_bs, -1, -1, -1)
                        else:
                            albedos = albedos[valid_bool, ...]

                        ops = diff_renderer(vertices, albedos, cur_sh[valid_bool, ...], cameras)
                        grid = ops['position_images'].permute(0, 2, 3, 1)[:, :, :, :2]
                        sampled_image = F.grid_sample(flipped, grid, align_corners=False)
                        ops_mask = SHOW.utils.parse_mask(ops)
                        losses['pho'] = SHOW.utils.pixel_loss(
                            ops['images'], sampled_image, ops_mask) * w_pho

                        all_loss = 0.
                        for loss_val in losses.values():
                            all_loss = all_loss + loss_val
                        losses['all_loss'] = all_loss

                        log_str = SHOW.print_dict_losses(losses)

                        if report_wandb:
                            if globals().get('wandb', None) is None:
                                # Placeholder key; never overrides a key already in the env.
                                os.environ.setdefault('WANDB_API_KEY', 'xxx')
                                os.environ['WANDB_NAME'] = 'tracker'
                                import wandb
                                wandb.init(reinit=True, resume='allow', project='tracker')
                                globals()['wandb'] = wandb
                            globals()['wandb'].log(losses)

                        if save_traing_img:
                            def train_save_callback(frame, final_views):
                                cur_idx = start_frame + frame
                                if cur_idx in observe_idx_list:
                                    observe_idx_frame_dir = os.path.join(
                                        save_traing_img_dir, f'{cur_idx:03d}')
                                    Path(observe_idx_frame_dir).mkdir(parents=True, exist_ok=True)
                                    cv2.imwrite(
                                        os.path.join(observe_idx_frame_dir, f'{k}_{p}.jpg'),
                                        final_views)

                            save_tracker(img, valid_bool, valid_bs, ops, vertices, cameras,
                                         image_lmks, proj_lmks, flame_faces, mesh_rasterizer,
                                         debug_renderer, train_save_callback)

                        if loggers is not None:
                            loggers.log_bs(losses)
                        pbar.set_description(log_str)
                        pbar.update(1)

                        # A NaN loss would poison every parameter on backward;
                        # stop this scale and keep the best snapshot so far.
                        if torch.isnan(all_loss).any():
                            if loggers is not None:
                                loggers.alert(
                                    title='Nan error',
                                    msg=f'tracker nan in: {tracker_cfg.ours_output_folder}')
                            open(tracker_cfg.ours_output_folder + '/mica_opt_nan.info', 'a').close()
                            break

                        optimizer.zero_grad()
                        all_loss.backward()
                        optimizer.step()

                        if all_loss.item() < best_loss:
                            best_loss = all_loss.item()
                            opt_bs_tex = cur_tex.clone().detach()
                            opt_bs_sh = cur_sh.clone().detach()
                            body_params_dict['expression'][start_frame:end_frame] = cur_exp.clone().detach()
                            body_params_dict['leye_pose'][start_frame:end_frame] = cur_leyes.clone().detach()
                            body_params_dict['reye_pose'][start_frame:end_frame] = cur_reyes.clone().detach()
                            body_params_dict['jaw_pose'][start_frame:end_frame] = cur_jaw.clone().detach()
                            body_params_dict['body_pose_axis'][start_frame:end_frame] = compose_pose(
                                start_frame, end_frame, two_opt).detach().clone().squeeze(1)

        with default_timers['saving_stage']:
            if save_traing_img:
                for idx in observe_idx_list:
                    observe_idx_frame_dir = os.path.join(save_traing_img_dir, f'{idx:03d}')
                    Path(observe_idx_frame_dir).mkdir(parents=True, exist_ok=True)
                    if not SHOW.is_empty_dir(observe_idx_frame_dir):
                        images_to_video(
                            input_folder=observe_idx_frame_dir,
                            output_path=observe_idx_frame_dir + '.mp4',
                            img_format=None,
                            fps=30,
                        )

            def to_numpy(tensor):
                return tensor.clone().detach().cpu().numpy()

            dict_to_save = {
                key: to_numpy(body_params_dict[key][start_frame:end_frame])
                for key in ('expression', 'leye_pose', 'reye_pose', 'jaw_pose', 'body_pose_axis')
            }
            dict_to_save['tex'] = to_numpy(opt_bs_tex)
            dict_to_save['sh'] = to_numpy(opt_bs_sh)
            mmcv.dump(dict_to_save, mica_part_pkl_path)
            logger.info(f'mica pkl part path: {mica_part_pkl_path}')

            pre_frame_exp = torch.Tensor(dict_to_save['expression'][-1]).to(device)
            vertices_ = model_output.vertices.clone().detach().cpu()
            logger.info(f'mica render to origin path: {tracker_cfg.mica_org_out_path}')

            if platform.system() == "Linux":
                os.environ['PYOPENGL_PLATFORM'] = 'egl'
            else:
                os.environ.pop('PYOPENGL_PLATFORM', None)

            import pyrender
            input_renderer = pyrender.OffscreenRenderer(viewport_width=width,
                                                        viewport_height=height,
                                                        point_size=1.0)
            try:
                for idx in tqdm(range(vertices_.shape[0]),
                                desc='saving ours final pyrender images'):
                    # Input images are looked up by 1-based, zero-padded frame number.
                    cur_idx = idx + start_frame + 1
                    input_img = SHOW.find_full_impath_by_name(
                        root=tracker_cfg.img_folder, name=f'{cur_idx:06d}')
                    output_name = os.path.join(
                        tracker_cfg.mica_org_out_path,
                        f"{cur_idx:06}.{tracker_cfg.output_img_ext}")
                    camera_pose = op.get_smplx_to_pyrender_K(camera_transl)
                    meta_data = dict(
                        input_img=input_img,
                        output_name=output_name,
                    )
                    save_one_results(
                        vertices_[idx],
                        body_model.faces,
                        img_size=(height, width),
                        center=center,
                        focal_length=[focal_length, focal_length],
                        camera_pose=camera_pose,
                        meta_data=meta_data,
                        input_renderer=input_renderer,
                    )
            finally:
                input_renderer.delete()

            if tracker_cfg.save_final_vis:
                def save_callback(frame, final_views):
                    cur_idx = start_frame + frame
                    if loggers is not None:
                        loggers.log_image(f"final_mica_img/{cur_idx:03d}", final_views / 255.0)
                    cv2.imwrite(
                        os.path.join(tracker_cfg.mica_save_path, f'{cur_idx:03d}.jpg'),
                        final_views)

                save_tracker(img, valid_bool, valid_bs, ops, vertices, cameras,
                             image_lmks, proj_lmks, flame_faces, mesh_rasterizer,
                             debug_renderer, save_callback)

    # Merge all chunk results into the stage-1 data and write the merged pkl.
    load_data = mmcv.load(tracker_cfg.ours_pkl_file_path)[0]
    load_data = SHOW.replace_mica_exp(tracker_cfg.mica_all_dir, load_data)
    mmcv.dump([load_data], tracker_cfg.mica_merge_pkl)

    if not Path(tracker_cfg.mica_org_out_video).exists():
        if not SHOW.is_empty_dir(tracker_cfg.mica_org_out_path):
            images_to_video(
                input_folder=tracker_cfg.mica_org_out_path,
                output_path=tracker_cfg.mica_org_out_video,
                img_format=None,
                fps=30,
            )
    if not Path(tracker_cfg.mica_grid_video).exists():
        if not SHOW.is_empty_dir(tracker_cfg.mica_save_path):
            images_to_video(
                input_folder=tracker_cfg.mica_save_path,
                output_path=tracker_cfg.mica_grid_video,
                img_format=None,
                fps=30,
            )
|
|