# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]
import warnings
import logging

warnings.filterwarnings("ignore")
logging.getLogger("lightning").setLevel(logging.ERROR)
logging.getLogger("trimesh").setLevel(logging.ERROR)

import torch
import argparse
import os

from termcolor import colored
from tqdm.auto import tqdm

from apps.Normal import Normal
from apps.IFGeo import IFGeo
from lib.common.config import cfg
from lib.common.BNI import BNI
from lib.common.BNI_utils import save_normal_tensor
from lib.dataset.EvalDataset import EvalDataset
from lib.dataset.Evaluator import Evaluator
from lib.dataset.mesh_util import *
from lib.common.voxelize import VoxelGrid

torch.backends.cudnn.benchmark = True

# set True to profile the full run with cProfile (stats are dumped at the end)
speed_analysis = False

if __name__ == "__main__":

    if speed_analysis:
        import cProfile
        import pstats
        profiler = cProfile.Profile()
        profiler.enable()

    # loading cfg file
    parser = argparse.ArgumentParser()
    parser.add_argument("-gpu", "--gpu_device", type=int, default=0)
    parser.add_argument("-ifnet", action="store_true")
    parser.add_argument("-cfg", "--config", type=str, default="./configs/econ.yaml")
    args = parser.parse_args()

    # cfg read and merge
    cfg.merge_from_file(args.config)
    device = torch.device("cuda:0")

    cfg_test_list = [
        "dataset.rotation_num", 3, "bni.use_smpl", ["hand"], "bni.use_ifnet", args.ifnet,
        "bni.cut_intersection", True,
    ]

    # # if w/ RenderPeople+CAPE
    # cfg_test_list += ["dataset.types", ["cape", "renderpeople"], "dataset.scales", [100.0, 1.0]]

    # if only w/ CAPE
    cfg_test_list += ["dataset.types", ["cape"], "dataset.scales", [100.0]]

    cfg.merge_from_list(cfg_test_list)
    cfg.freeze()

    # load normal model
    normal_net = Normal.load_from_checkpoint(
        cfg=cfg, checkpoint_path=cfg.normal_path, map_location=device, strict=False
    )
    normal_net = normal_net.to(device)
    normal_net.netG.eval()
    print(
        colored(
            f"Resume Normal Estimator from {Format.start} {cfg.normal_path} {Format.end}", "green"
        )
    )

    # SMPLX object
    SMPLX_object = SMPLX()
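
    # evaluation dataset, metric evaluator and export directory
    # (results are grouped by the chosen completion variant: IF-Net+ or SMPL-X)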
    dataset = EvalDataset(cfg=cfg, device=device)
    evaluator = Evaluator(device=device)
    export_dir = osp.join(cfg.results_path, cfg.name, "IF-Net+" if cfg.bni.use_ifnet else "SMPL-X")

    print(colored(f"Dataset Size: {len(dataset)}", "green"))

    if cfg.bni.use_ifnet:
        # load IFGeo model
        ifnet = IFGeo.load_from_checkpoint(
            cfg=cfg, checkpoint_path=cfg.ifnet_path, map_location=device, strict=False
        )
        ifnet = ifnet.to(device)
        ifnet.netG.eval()
        print(colored(f"Resume IF-Net+ from {Format.start} {cfg.ifnet_path} {Format.end}", "green"))
        print(colored(f"Complete with {Format.start} IF-Nets+ (Implicit) {Format.end}", "green"))
    else:
        print(colored(f"Complete with {Format.start} SMPL-X (Explicit) {Format.end}", "green"))
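
    # loop over every (subject, rotation) sample, reconstruct it if needed,
    # then evaluate the result against the ground-truth scan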
    pbar = tqdm(dataset)
    benchmark = {}

    for data in pbar:
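        # add a batch dimension and move all tensors to the GPU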
        for key in data.keys():
            if torch.is_tensor(data[key]):
                data[key] = data[key].unsqueeze(0).to(device)

        is_smplx = "smplx_path" in data.keys()

        # filenames and makedirs
        current_name = f"{data['dataset']}-{data['subject']}-{data['rotation']:03d}"
        current_dir = osp.join(export_dir, data['dataset'], data['subject'])
        os.makedirs(current_dir, exist_ok=True)
        final_path = osp.join(current_dir, f"{current_name}_final.obj")
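
        # reconstruct only if the final mesh has not been exported in a previous run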
        if not osp.exists(final_path):

            in_tensor = data.copy()

            batch_smpl_verts = in_tensor["smpl_verts"].detach()
            batch_smpl_verts *= torch.tensor([1.0, -1.0, 1.0]).to(device)
            batch_smpl_faces = in_tensor["smpl_faces"].detach()
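
            # render front/back depth maps from the (Y-flipped) SMPL body mesh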
            in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth(
                batch_smpl_verts, batch_smpl_faces
            )
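
            # predict front/back clothed-body normal maps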
            with torch.no_grad():
                in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor)
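
            # SMPL body mesh plus per-region copies (side, face, hand, full body) used during assembly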
            smpl_mesh = trimesh.Trimesh(
                batch_smpl_verts.cpu().numpy()[0],
                batch_smpl_faces.cpu().numpy()[0]
            )

            side_mesh = smpl_mesh.copy()
            face_mesh = smpl_mesh.copy()
            hand_mesh = smpl_mesh.copy()
            smplx_mesh = smpl_mesh.copy()

            # save normals, depths and masks
            BNI_dict = save_normal_tensor(
                in_tensor,
                0,
                osp.join(current_dir, "BNI/param_dict"),
                cfg.bni.thickness if data['dataset'] == 'renderpeople' else 0.0,
            )

            # BNI process
            BNI_object = BNI(
                dir_path=osp.join(current_dir, "BNI"),
                name=current_name,
                BNI_dict=BNI_dict,
                cfg=cfg.bni,
                device=device
            )
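
            # extract front/back partial surfaces via bilateral normal integration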
            BNI_object.extract_surface(False)

            if is_smplx:
                side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)

            if cfg.bni.use_ifnet:

                # mesh completion via IF-net
                in_tensor.update(
                    dataset.depth_to_voxel(
                        {
                            "depth_F": BNI_object.F_depth.unsqueeze(0).to(device),
                            "depth_B": BNI_object.B_depth.unsqueeze(0).to(device)
                        }
                    )
                )
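
                # voxelize the SMPL side mesh into an occupancy grid as an additional IF-Net input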
                occupancies = VoxelGrid.from_mesh(
                    side_mesh, cfg.vol_res, loc=[0] * 3, scale=2.0
                ).data.transpose(2, 1, 0)
                occupancies = np.flip(occupancies, axis=1)

                in_tensor["body_voxels"] = torch.tensor(
                    occupancies.copy()
                ).float().unsqueeze(0).to(device)
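
                # run the implicit reconstruction, then clean and remesh the completed body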
                with torch.no_grad():
                    sdf = ifnet.reconEngine(netG=ifnet.netG, batch=in_tensor)
                    verts_IF, faces_IF = ifnet.reconEngine.export_mesh(sdf)

                if ifnet.clean_mesh_flag:
                    verts_IF, faces_IF = clean_mesh(verts_IF, faces_IF)

                side_mesh_path = osp.join(current_dir, f"{current_name}_IF.obj")
                side_mesh = remesh_laplacian(trimesh.Trimesh(verts_IF, faces_IF), side_mesh_path)
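
            # assemble the final mesh from parts: optional SMPL-X hands,
            # the BNI front/back surfaces, and the (SMPL-X or IF-Net) side mesh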
            full_lst = []

            if "hand" in cfg.bni.use_smpl:

                # only hands
                if is_smplx:
                    hand_mesh = apply_vertex_mask(hand_mesh, SMPLX_object.smplx_mano_vertex_mask)
                else:
                    hand_mesh = apply_vertex_mask(hand_mesh, SMPLX_object.smpl_mano_vertex_mask)

                # remove hand neighbor triangles
                BNI_object.F_B_trimesh = part_removal(
                    BNI_object.F_B_trimesh,
                    hand_mesh,
                    cfg.bni.hand_thres,
                    device,
                    smplx_mesh,
                    region="hand"
                )
                side_mesh = part_removal(
                    side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand"
                )
                # hand_mesh.export(osp.join(current_dir, f"{current_name}_hands.obj"))
                full_lst += [hand_mesh]

            full_lst += [BNI_object.F_B_trimesh]

            # initial side_mesh could be SMPLX or IF-net
            side_mesh = part_removal(
                side_mesh, sum(full_lst), 2e-2, device, smplx_mesh, region="", clean=False
            )

            full_lst += [side_mesh]
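
            # fuse all parts, optionally via Poisson surface reconstruction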
            if cfg.bni.use_poisson:
                final_mesh = poisson(
                    sum(full_lst),
                    final_path,
                    cfg.bni.poisson_depth,
                )
            else:
                final_mesh = sum(full_lst)
                final_mesh.export(final_path)
        else:
            final_mesh = trimesh.load(final_path)

        # evaluation
        metric_path = osp.join(export_dir, "metric.npy")

        if osp.exists(metric_path):
            benchmark = np.load(metric_path, allow_pickle=True).item()
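
        # compute metrics only for subjects that are not yet in the cached benchmark dict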
        if (
            benchmark == {} or data["dataset"] not in benchmark.keys() or
            f"{data['subject']}-{data['rotation']}" not in benchmark[data["dataset"]]["subject"]
        ):

            result_eval = {
                "verts_gt": data["verts"][0],
                "faces_gt": data["faces"][0],
                "verts_pr": final_mesh.vertices,
                "faces_pr": final_mesh.faces,
                "calib": data["calib"][0],
            }

            evaluator.set_mesh(result_eval, scale=False)
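
            # chamfer distance, point-to-surface distance and normal consistency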
            chamfer, p2s = evaluator.calculate_chamfer_p2s(num_samples=1000)
            nc = evaluator.calculate_normal_consist(osp.join(current_dir, f"{current_name}_nc.png"))

            if data["dataset"] not in benchmark.keys():
                benchmark[data["dataset"]] = {
                    "chamfer": [chamfer.item()],
                    "p2s": [p2s.item()],
                    "nc": [nc.item()],
                    "subject": [f"{data['subject']}-{data['rotation']}"],
                    "total": 1,
                }
            else:
                benchmark[data["dataset"]]["chamfer"] += [chamfer.item()]
                benchmark[data["dataset"]]["p2s"] += [p2s.item()]
                benchmark[data["dataset"]]["nc"] += [nc.item()]
                benchmark[data["dataset"]]["subject"] += [f"{data['subject']}-{data['rotation']}"]
                benchmark[data["dataset"]]["total"] += 1

            np.save(metric_path, benchmark, allow_pickle=True)

        else:
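            # metrics for this subject were computed in a previous run; reuse the cached values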
            subject_idx = benchmark[data["dataset"]]["subject"].index(
                f"{data['subject']}-{data['rotation']}"
            )
            chamfer = torch.tensor(benchmark[data["dataset"]]["chamfer"][subject_idx])
            p2s = torch.tensor(benchmark[data["dataset"]]["p2s"][subject_idx])
            nc = torch.tensor(benchmark[data["dataset"]]["nc"][subject_idx])

        pbar.set_description(
            f"{current_name} | {chamfer.item():.3f} | {p2s.item():.3f} | {nc.item():.4f}"
        )
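
    # report per-dataset averages over all evaluated subjects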
    for dataset in benchmark.keys():
        for metric in ["chamfer", "p2s", "nc"]:
            print(
                f"{dataset}-{metric}: {sum(benchmark[dataset][metric])/benchmark[dataset]['total']:.4f}"
            )

    if cfg.bni.use_ifnet:
        print(colored("Finish evaluating on ECON_IF", "green"))
    else:
        print(colored("Finish evaluating on ECON_EX", "green"))
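
    # dump and print cProfile statistics when speed analysis is enabled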
    if speed_analysis:
        profiler.disable()
        profiler.dump_stats(osp.join(export_dir, "econ.stats"))
        stats = pstats.Stats(osp.join(export_dir, "econ.stats"))
        stats.sort_stats("cumtime").print_stats(10)