# GVHMR / app/demo.py
# Last change: commit 2e7c76e ("update") by IsshikiHugh
import cv2
import torch
import pytorch_lightning as pl
import numpy as np
import argparse
from hmr4d.utils.pylogger import Log
import hydra
from hydra import initialize_config_module, compose
from pathlib import Path
from pytorch3d.transforms import quaternion_to_matrix
from hmr4d.configs import register_store_gvhmr
from hmr4d.utils.video_io_utils import (
get_video_lwh,
read_video_np,
save_video,
merge_videos_horizontal,
get_writer,
get_video_reader,
)
from hmr4d.utils.vis.cv2_utils import draw_bbx_xyxy_on_image_batch, draw_coco17_skeleton_batch
from hmr4d.utils.preproc import Tracker, Extractor, VitPoseExtractor, SLAMModel
from hmr4d.utils.geo.hmr_cam import get_bbx_xys_from_xyxy, estimate_K, convert_K_to_K4, create_camera_sensor
from hmr4d.utils.geo_transform import compute_cam_angvel
from hmr4d.model.gvhmr.gvhmr_pl_demo import DemoPL
from hmr4d.utils.net_utils import detach_to_cpu, to_cuda
from hmr4d.utils.smplx_utils import make_smplx
from hmr4d.utils.vis.renderer import Renderer, get_global_cameras_static, get_ground_params_from_points
from tqdm import tqdm
from hmr4d.utils.geo_transform import apply_T_on_points, compute_T_ayfz2ay
from einops import einsum, rearrange
# Constant Rate Factor passed to the video writer (lower = higher quality, bigger file).
# NOTE(review): assuming the writer encodes with libx264: CRF 0 is truly lossless,
# ~17-18 is "visually lossless", and every +6 roughly halves the mp4 size.
CRF = 23
def parse_args_to_cfg():
    """Parse CLI arguments, compose the hydra ``demo`` config, and stage the input video.

    Side effects:
        - creates ``cfg.output_dir`` and ``cfg.preprocess_dir``;
        - re-encodes the input video to ``cfg.video_path`` (30 fps, ``CRF``) unless a
          file with the same frame count already exists there.

    Returns:
        The composed hydra config.

    Raises:
        FileNotFoundError: if ``--video`` does not point to an existing path.
    """
    # Put all args to cfg
    parser = argparse.ArgumentParser()
    parser.add_argument("--video", type=str, default="inputs/demo/dance_3.mp4")
    parser.add_argument("--output_root", type=str, default=None, help="by default to outputs/demo")
    parser.add_argument("-s", "--static_cam", action="store_true", help="If true, skip DPVO")
    parser.add_argument("--verbose", action="store_true", help="If true, draw intermediate results")
    args = parser.parse_args()

    # Input
    video_path = Path(args.video)
    # Explicit check rather than `assert`, which is stripped when Python runs with -O.
    if not video_path.exists():
        raise FileNotFoundError(f"Video not found at {video_path}")
    length, width, height = get_video_lwh(video_path)
    Log.info(f"[Input]: {video_path}")
    Log.info(f"(L, W, H) = ({length}, {width}, {height})")

    # Cfg
    with initialize_config_module(version_base="1.3", config_module="hmr4d.configs"):
        overrides = [
            f"video_name={video_path.stem}",
            f"static_cam={args.static_cam}",
            f"verbose={args.verbose}",
        ]
        # Allow to change output root
        if args.output_root is not None:
            overrides.append(f"output_root={args.output_root}")
        register_store_gvhmr()
        cfg = compose(config_name="demo", overrides=overrides)

    # Output
    Log.info(f"[Output Dir]: {cfg.output_dir}")
    Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.preprocess_dir).mkdir(parents=True, exist_ok=True)

    # Copy raw-input-video to video_path. A frame-count mismatch means the existing
    # copy is stale or truncated, so it is re-encoded.
    Log.info(f"[Copy Video] {video_path} -> {cfg.video_path}")
    if not Path(cfg.video_path).exists() or get_video_lwh(cfg.video_path)[0] != length:
        reader = get_video_reader(video_path)
        writer = get_writer(cfg.video_path, fps=30, crf=CRF)
        try:
            for img in tqdm(reader, total=length, desc="Copy"):
                writer.write_frame(img)
        finally:
            # Release the encoder and decoder even if a frame fails to copy.
            writer.close()
            reader.close()
    return cfg
@torch.no_grad()
def run_preprocess(cfg, progress=None):
    """Run, or reuse from cache, the preprocessing stages and save their outputs to disk.

    Stages (each is skipped when its output file under ``cfg.paths`` already exists):
    YOLOv8 person tracking -> ViTPose keypoints -> HMR2 ViT features -> DPVO camera
    trajectory (the DPVO stage runs only when ``cfg.static_cam`` is false).

    Args:
        cfg: composed demo config; reads ``video_path``, ``paths``, ``static_cam``
            and ``verbose``.
        progress: optional callable invoked as ``progress(fraction, description)`` at
            the start of each stage (e.g. a UI progress reporter). ``None`` (default)
            disables progress reporting.
    """
    if progress is None:
        # No reporter supplied (e.g. plain CLI use): make reporting a no-op.
        def progress(*_args, **_kwargs):
            return None

    Log.info("[Preprocess] Start!")
    tic = Log.time()
    video_path = cfg.video_path
    paths = cfg.paths
    static_cam = cfg.static_cam
    verbose = cfg.verbose

    # Get bbx tracking result
    progress(0, "[Preprocess] YoloV8 Tracking")
    if not Path(paths.bbx).exists():
        tracker = Tracker()
        bbx_xyxy = tracker.get_one_track(video_path).float()  # (L, 4)
        bbx_xys = get_bbx_xys_from_xyxy(bbx_xyxy, base_enlarge=1.2).float()  # (L, 3) apply aspect ratio and enlarge
        torch.save({"bbx_xyxy": bbx_xyxy, "bbx_xys": bbx_xys}, paths.bbx)
        del tracker
    else:
        bbx_xys = torch.load(paths.bbx)["bbx_xys"]
        Log.info(f"[Preprocess] bbx (xyxy, xys) from {paths.bbx}")
    if verbose:
        video = read_video_np(video_path)
        bbx_xyxy = torch.load(paths.bbx)["bbx_xyxy"]
        video_overlay = draw_bbx_xyxy_on_image_batch(bbx_xyxy, video)
        save_video(video_overlay, cfg.paths.bbx_xyxy_video_overlay)

    # Get VitPose
    progress(1 / 4, "[Preprocess] ViTPose")
    if not Path(paths.vitpose).exists():
        vitpose_extractor = VitPoseExtractor()
        vitpose = vitpose_extractor.extract(video_path, bbx_xys)
        torch.save(vitpose, paths.vitpose)
        del vitpose_extractor
    else:
        vitpose = torch.load(paths.vitpose)
        Log.info(f"[Preprocess] vitpose from {paths.vitpose}")
    if verbose:
        video = read_video_np(video_path)
        video_overlay = draw_coco17_skeleton_batch(video, vitpose, 0.5)
        save_video(video_overlay, paths.vitpose_video_overlay)

    # Get vit features
    progress(2 / 4, "[Preprocess] HMR2 Feature")
    if not Path(paths.vit_features).exists():
        extractor = Extractor()
        vit_features = extractor.extract_video_features(video_path, bbx_xys)
        torch.save(vit_features, paths.vit_features)
        del extractor
    else:
        Log.info(f"[Preprocess] vit_features from {paths.vit_features}")

    # Get DPVO results
    progress(3 / 4, "[Preprocess] DPVO")
    if not static_cam:  # use slam to get cam rotation
        if not Path(paths.slam).exists():
            length, width, height = get_video_lwh(cfg.video_path)
            K_fullimg = estimate_K(width, height)
            intrinsics = convert_K_to_K4(K_fullimg)
            slam = SLAMModel(video_path, width, height, intrinsics, buffer=4000, resize=0.5)
            bar = tqdm(total=length, desc="DPVO")
            try:
                # slam.track() returns a falsy value once the video is exhausted.
                while slam.track():
                    bar.update()
            finally:
                bar.close()
            slam_results = slam.process()  # (L, 7), numpy
            torch.save(slam_results, paths.slam)
        else:
            Log.info(f"[Preprocess] slam results from {paths.slam}")

    Log.info(f"[Preprocess] End. Time elapsed: {Log.time()-tic:.2f}s")
def load_data_dict(cfg):
    """Build the model-input dict from preprocessing results stored on disk.

    The per-frame camera rotation is the identity for a static camera. Otherwise it is
    derived from the saved DPVO trajectory: columns [6, 3, 4, 5] are converted with
    ``quaternion_to_matrix`` and the result is transposed (``.mT``).

    Returns:
        dict with keys ``length``, ``bbx_xys``, ``kp2d``, ``K_fullimg``,
        ``cam_angvel`` and ``f_imgseq`` (in that insertion order).
    """
    paths = cfg.paths
    num_frames, img_w, img_h = get_video_lwh(cfg.video_path)

    if not cfg.static_cam:
        slam_traj = torch.load(paths.slam)
        # NOTE(review): pytorch3d expects real-part-first quaternions, so column 6 is
        # presumably w and columns 3..5 x, y, z — confirm against the SLAM output.
        quat_wxyz = torch.from_numpy(slam_traj[:, [6, 3, 4, 5]])
        R_w2c = quaternion_to_matrix(quat_wxyz).mT
    else:
        R_w2c = torch.eye(3).repeat(num_frames, 1, 1)

    intrinsics = estimate_K(img_w, img_h).repeat(num_frames, 1, 1)
    # Alternative intrinsics kept for reference:
    # create_camera_sensor(img_w, img_h, 26)[2].repeat(num_frames, 1, 1)

    bbx_xys = torch.load(paths.bbx)["bbx_xys"]
    kp2d = torch.load(paths.vitpose)
    cam_angvel = compute_cam_angvel(R_w2c)
    f_imgseq = torch.load(paths.vit_features)

    return {
        "length": torch.tensor(num_frames),
        "bbx_xys": bbx_xys,
        "kp2d": kp2d,
        "K_fullimg": intrinsics,
        "cam_angvel": cam_angvel,
        "f_imgseq": f_imgseq,
    }
def render_incam(cfg, pred, smpl_utils):
    """Overlay the predicted SMPL mesh on the input video (camera-space rendering).

    Writes ``cfg.paths.incam_video`` and returns early if that file already exists.
    If rendering fails midway, the partially written video is deleted so the next run
    does not mistake it for a finished render.

    Args:
        cfg: demo config; reads ``paths.incam_video`` and ``video_path``.
        pred: prediction dict; uses ``pred["smpl_params_incam"]`` and
            ``pred["K_fullimg"][0]`` (the first frame's intrinsics for all frames).
        smpl_utils: dict with ``"smplx"`` (callable body model), ``"smplx2smpl"``
            (matrix multiplied onto each frame's SMPL-X vertices, presumably to get
            SMPL vertices) and ``"faces_smpl"`` (mesh faces for the renderer).
    """
    incam_video_path = Path(cfg.paths.incam_video)
    if incam_video_path.exists():
        Log.info(f"[Render Incam] Video already exists at {incam_video_path}")
        return

    smplx = smpl_utils["smplx"]
    smplx2smpl = smpl_utils["smplx2smpl"]
    faces_smpl = smpl_utils["faces_smpl"]

    # smpl
    smplx_out = smplx(**to_cuda(pred["smpl_params_incam"]))
    verts_incam = torch.stack([torch.matmul(smplx2smpl, v_) for v_ in smplx_out.vertices])

    # -- rendering code -- #
    video_path = cfg.video_path
    length, width, height = get_video_lwh(video_path)
    K = pred["K_fullimg"][0]
    renderer = Renderer(width, height, device="cuda", faces=faces_smpl, K=K)

    # -- render mesh -- #
    reader = get_video_reader(video_path)  # (F, H, W, 3), uint8, numpy
    writer = get_writer(incam_video_path, fps=30, crf=CRF)
    finished = False
    try:
        for i, img_raw in tqdm(enumerate(reader), total=length, desc="Rendering Incam"):
            img = renderer.render_mesh(verts_incam[i].cuda(), img_raw, [0.8, 0.8, 0.8])
            writer.write_frame(img)
        finished = True
    finally:
        writer.close()
        reader.close()
        if not finished and incam_video_path.exists():
            # The existence check above is used as a cache; never leave a truncated video behind.
            incam_video_path.unlink()
def render_global(cfg, pred, smpl_utils):
    """Render the predicted world-space motion on a ground plane from a static camera.

    Writes ``cfg.paths.global_video`` and returns early if that file already exists.
    Before rendering, the vertex sequence is translated so that joint 0 of the first
    frame sits over the XZ origin and the lowest vertex over all frames lies at y=0.
    It is then transformed by ``compute_T_ayfz2ay(..., inverse=True)`` evaluated on
    the first frame's joints. If rendering fails midway, the partially written video
    is deleted.

    Args:
        cfg: demo config; reads ``paths.global_video`` and ``video_path``.
        pred: prediction dict; uses ``pred["smpl_params_global"]``.
        smpl_utils: dict with ``"smplx"``, ``"smplx2smpl"``, ``"faces_smpl"`` and
            ``"J_regressor"`` (matrix contracted with vertices via einsum to get joints).
    """
    global_video_path = Path(cfg.paths.global_video)
    if global_video_path.exists():
        Log.info(f"[Render Global] Video already exists at {global_video_path}")
        return

    debug_cam = False  # set True to render only the first (up to) 8 frames
    smplx = smpl_utils["smplx"]
    smplx2smpl = smpl_utils["smplx2smpl"]
    faces_smpl = smpl_utils["faces_smpl"]
    J_regressor = smpl_utils["J_regressor"]

    # smpl
    smplx_out = smplx(**to_cuda(pred["smpl_params_global"]))
    pred_ay_verts = torch.stack([torch.matmul(smplx2smpl, v_) for v_ in smplx_out.vertices])

    def move_to_start_point_face_z(verts):
        "XZ to origin, Start from the ground, Face-Z"
        # position
        verts = verts.clone()  # (L, V, 3)
        offset = einsum(J_regressor, verts[0], "j v, v i -> j i")[0]  # (3)
        offset[1] = verts[:, :, [1]].min()
        verts = verts - offset
        # face direction
        T_ay2ayfz = compute_T_ayfz2ay(einsum(J_regressor, verts[[0]], "j v, l v i -> l j i"), inverse=True)
        verts = apply_T_on_points(verts, T_ay2ayfz)
        return verts

    verts_glob = move_to_start_point_face_z(pred_ay_verts)
    joints_glob = einsum(J_regressor, verts_glob, "j v, l v i -> l j i")  # (L, J, 3)
    global_R, global_T, global_lights = get_global_cameras_static(
        verts_glob.cpu(),
        beta=2.0,
        cam_height_degree=20,
        target_center_height=1.0,
    )

    # -- rendering code -- #
    video_path = cfg.video_path
    length, width, height = get_video_lwh(video_path)
    _, _, K = create_camera_sensor(width, height, 24)  # render as 24mm lens
    renderer = Renderer(width, height, device="cuda", faces=faces_smpl, K=K)
    # renderer = Renderer(width, height, device="cuda", faces=faces_smpl, K=K, bin_size=0)

    # -- render mesh -- #
    scale, cx, cz = get_ground_params_from_points(joints_glob[:, 0], verts_glob)
    renderer.set_ground(scale * 1.5, cx, cz)
    color = torch.ones(3).float().cuda() * 0.8

    # Clamp so debug mode cannot index past the end of short sequences.
    render_length = min(length, 8) if debug_cam else length
    writer = get_writer(global_video_path, fps=30, crf=CRF)
    finished = False
    try:
        for i in tqdm(range(render_length), desc="Rendering Global"):
            cameras = renderer.create_camera(global_R[i], global_T[i])
            img = renderer.render_with_ground(verts_glob[[i]], color[None], cameras, global_lights)
            writer.write_frame(img)
        finished = True
    finally:
        writer.close()
        if not finished and global_video_path.exists():
            # The existence check above is used as a cache; never leave a truncated video behind.
            global_video_path.unlink()
if __name__ == "__main__":
    cfg = parse_args_to_cfg()
    paths = cfg.paths
    Log.info(f"[GPU]: {torch.cuda.get_device_name()}")
    Log.info(f'[GPU]: {torch.cuda.get_device_properties("cuda")}')

    # ===== Preprocess and save to disk ===== #
    # run_preprocess calls progress(fraction, description) per stage; the CLI has no
    # progress UI (stages are already logged), so hand it a no-op.
    run_preprocess(cfg, lambda *_args, **_kwargs: None)
    data = load_data_dict(cfg)

    # ===== HMR4D ===== #
    if not Path(paths.hmr4d_results).exists():
        Log.info("[HMR4D] Predicting")
        model: DemoPL = hydra.utils.instantiate(cfg.model, _recursive_=False)
        model.load_pretrained_model(cfg.ckpt_path)
        model = model.eval().cuda()
        tic = Log.sync_time()
        pred = model.predict(data, static_cam=cfg.static_cam)
        pred = detach_to_cpu(pred)
        data_time = data["length"] / 30
        Log.info(f"[HMR4D] Elapsed: {Log.sync_time() - tic:.2f}s for data-length={data_time:.1f}s")
        torch.save(pred, paths.hmr4d_results)
    else:
        # Cached results: load them, since the renderers below need `pred`.
        pred = torch.load(paths.hmr4d_results)

    # ===== Render ===== #
    # The render functions take the prediction and body-model helpers explicitly.
    # NOTE(review): model names and asset paths follow the upstream GVHMR demo script;
    # confirm they match this checkout.
    smpl_utils = {
        "smplx": make_smplx("supermotion").cuda(),
        "smplx2smpl": torch.load("hmr4d/utils/body_model/smplx2smpl_sparse.pt").cuda(),
        "faces_smpl": make_smplx("smpl").faces,
        "J_regressor": torch.load("hmr4d/utils/body_model/smpl_neutral_J_regressor.pt").cuda(),
    }
    render_incam(cfg, pred, smpl_utils)
    render_global(cfg, pred, smpl_utils)
    if not Path(paths.incam_global_horiz_video).exists():
        Log.info("[Merge Videos]")
        merge_videos_horizontal([paths.incam_video, paths.global_video], paths.incam_global_horiz_video)