diff --git a/README.md b/README.md
index 997b626d8f415f03e66d10c10f815290af50b344..70b48651eb55091f6c13dcae46cc44356ee8b939 100644
--- a/README.md
+++ b/README.md
@@ -103,20 +103,23 @@ python -m apps.avatarizer -n {filename}
 ### Some adjustable parameters in _config/econ.yaml_
-- `use_ifnet: True`
-  - True: use IF-Nets+ for mesh completion ( $\text{ECON}_\text{IF}$ - Better quality)
-  - False: use SMPL-X for mesh completion ( $\text{ECON}_\text{EX}$ - Faster speed)
+- `use_ifnet: False`
+  - True: use IF-Nets+ for mesh completion ( $\text{ECON}_\text{IF}$ - Better quality, **~2min / img**)
+  - False: use SMPL-X for mesh completion ( $\text{ECON}_\text{EX}$ - Faster speed, **~1.5min / img**)
 - `use_smpl: ["hand", "face"]`
   - [ ]: don't use either hands or face parts from SMPL-X
   - ["hand"]: only use the **visible** hands from SMPL-X
   - ["hand", "face"]: use both **visible** hands and face from SMPL-X
 - `thickness: 2cm`
   - could be increased accordingly in case final reconstruction **xx_full.obj** looks flat
+- `k: 4`
+  - could be reduced accordingly in case the surface of **xx_full.obj** has discontinuous artifacts
 - `hps_type: PIXIE`
   - "pixie": more accurate for face and hands
   - "pymafx": more robust for challenging poses
-- `k: 4`
-  - could be reduced accordingly in case the surface of **xx_full.obj** has discontinous artifacts
+- `texture_src: image`
+  - "image": directly map the aligned pixels onto the final mesh
+  - "SD": use Stable Diffusion to generate the full texture (TODO)
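The parameters listed above all sit under the `bni:` block of `configs/econ.yaml` (see the `lib/common/config.py` defaults and the `configs/econ.yaml` hunk later in this diff). As a quick reference only — a sketch, not the complete shipped config, and the exact set of keys present in `econ.yaml` may differ — the block with the README's suggested values would look roughly like:

```yaml
bni:
  use_ifnet: False            # True: ECON_IF (better quality), False: ECON_EX (faster)
  use_smpl: ["hand", "face"]  # which visible SMPL-X parts to paste into the final mesh
  thickness: 0.02             # 2cm; increase if xx_full.obj looks flat
  k: 4                        # reduce if xx_full.obj shows discontinuous artifacts
  hand_thres: 4e-2            # defaults quoted from lib/common/config.py
  face_thres: 6e-2
  hps_type: "pixie"           # or "pymafx" for challenging poses
  texture_src: "image"        # or "SD" (Stable Diffusion texturing, still TODO)
```

Note that with `use_ifnet: False` the IF-Nets+ checkpoint is never loaded (it is only resumed inside the `if cfg.bni.use_ifnet:` branch of `apps/infer.py`), which is what makes the $\text{ECON}_\text{EX}$ path faster.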
@@ -160,7 +163,6 @@ Here are some great resources we benefit from: - [BiNI](https://github.com/hoshino042/bilateral_normal_integration) for Bilateral Normal Integration - [MonoPortDataset](https://github.com/Project-Splinter/MonoPortDataset) for Data Processing, [MonoPort](https://github.com/Project-Splinter/MonoPort) for fast implicit surface query - [rembg](https://github.com/danielgatis/rembg) for Human Segmentation -- [pypoisson](https://github.com/mmolero/pypoisson) for poisson reconstruction - [MediaPipe](https://google.github.io/mediapipe/getting_started/python.html) for full-body landmark estimation - [PyTorch-NICP](https://github.com/wuhaozhe/pytorch-nicp) for non-rigid registration - [smplx](https://github.com/vchoutas/smplx), [PyMAF-X](https://www.liuyebin.com/pymaf-x/), [PIXIE](https://github.com/YadiraF/PIXIE) for Human Pose & Shape Estimation diff --git a/apps/IFGeo.py b/apps/IFGeo.py index 966462dd93c3aedd84567b392f0ff9f876321bfc..8cb033d8d3fbd597ac80526d3d0c691451975685 100644 --- a/apps/IFGeo.py +++ b/apps/IFGeo.py @@ -24,7 +24,6 @@ torch.backends.cudnn.benchmark = True class IFGeo(pl.LightningModule): - def __init__(self, cfg): super(IFGeo, self).__init__() @@ -44,14 +43,15 @@ class IFGeo(pl.LightningModule): from lib.net.IFGeoNet_nobody import IFGeoNet self.netG = IFGeoNet(cfg) - - self.resolutions = (np.logspace( - start=5, - stop=np.log2(self.mcube_res), - base=2, - num=int(np.log2(self.mcube_res) - 4), - endpoint=True, - ) + 1.0) + self.resolutions = ( + np.logspace( + start=5, + stop=np.log2(self.mcube_res), + base=2, + num=int(np.log2(self.mcube_res) - 4), + endpoint=True, + ) + 1.0 + ) self.resolutions = self.resolutions.astype(np.int16).tolist() @@ -82,9 +82,9 @@ class IFGeo(pl.LightningModule): if self.cfg.optim == "Adadelta": - optimizer_G = torch.optim.Adadelta(optim_params_G, - lr=self.lr_G, - weight_decay=weight_decay) + optimizer_G = torch.optim.Adadelta( + optim_params_G, lr=self.lr_G, weight_decay=weight_decay + ) elif self.cfg.optim == "Adam": @@ -103,20 +103,14 @@ class IFGeo(pl.LightningModule): raise NotImplementedError # set scheduler - scheduler_G = torch.optim.lr_scheduler.MultiStepLR(optimizer_G, - milestones=self.cfg.schedule, - gamma=self.cfg.gamma) + scheduler_G = torch.optim.lr_scheduler.MultiStepLR( + optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma + ) return [optimizer_G], [scheduler_G] def training_step(self, batch, batch_idx): - # cfg log - if self.cfg.devices == 1: - if not self.cfg.fast_dev and self.global_step == 0: - export_cfg(self.logger, osp.join(self.cfg.results_path, self.cfg.name), self.cfg) - self.logger.experiment.config.update(convert_to_dict(self.cfg)) - self.netG.train() preds_G = self.netG(batch) @@ -127,12 +121,9 @@ class IFGeo(pl.LightningModule): "loss": error_G, } - self.log_dict(metrics_log, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=False, - sync_dist=True) + self.log_dict( + metrics_log, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True + ) return metrics_log @@ -143,12 +134,14 @@ class IFGeo(pl.LightningModule): "train/avgloss": batch_mean(outputs, "loss"), } - self.log_dict(metrics_log, - prog_bar=False, - logger=True, - on_step=False, - on_epoch=True, - rank_zero_only=True) + self.log_dict( + metrics_log, + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True, + rank_zero_only=True + ) def validation_step(self, batch, batch_idx): @@ -162,12 +155,9 @@ class IFGeo(pl.LightningModule): "val/loss": error_G, } - self.log_dict(metrics_log, - prog_bar=True, 
- logger=False, - on_step=True, - on_epoch=False, - sync_dist=True) + self.log_dict( + metrics_log, prog_bar=True, logger=False, on_step=True, on_epoch=False, sync_dist=True + ) return metrics_log @@ -178,9 +168,11 @@ class IFGeo(pl.LightningModule): "val/avgloss": batch_mean(outputs, "val/loss"), } - self.log_dict(metrics_log, - prog_bar=False, - logger=True, - on_step=False, - on_epoch=True, - rank_zero_only=True) + self.log_dict( + metrics_log, + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True, + rank_zero_only=True + ) diff --git a/apps/Normal.py b/apps/Normal.py index a57df041fe1523a04a9ba9e58e0e88e43023ca1b..235c0aef05914ef040f1495843d339c758ebd9f3 100644 --- a/apps/Normal.py +++ b/apps/Normal.py @@ -1,14 +1,12 @@ from lib.net import NormalNet -from lib.common.train_util import convert_to_dict, export_cfg, batch_mean +from lib.common.train_util import batch_mean import torch import numpy as np -import os.path as osp from skimage.transform import resize import pytorch_lightning as pl class Normal(pl.LightningModule): - def __init__(self, cfg): super(Normal, self).__init__() self.cfg = cfg @@ -44,19 +42,19 @@ class Normal(pl.LightningModule): optimizer_N_F = torch.optim.Adam(optim_params_N_F, lr=self.lr_F, betas=(0.5, 0.999)) optimizer_N_B = torch.optim.Adam(optim_params_N_B, lr=self.lr_B, betas=(0.5, 0.999)) - scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(optimizer_N_F, - milestones=self.cfg.schedule, - gamma=self.cfg.gamma) + scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR( + optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma + ) - scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(optimizer_N_B, - milestones=self.cfg.schedule, - gamma=self.cfg.gamma) + scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR( + optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma + ) if 'gan' in self.ALL_losses: optim_params_N_D = [{"params": self.netG.netD.parameters(), "lr": self.lr_D}] optimizer_N_D = torch.optim.Adam(optim_params_N_D, lr=self.lr_D, betas=(0.5, 0.999)) - scheduler_N_D = torch.optim.lr_scheduler.MultiStepLR(optimizer_N_D, - milestones=self.cfg.schedule, - gamma=self.cfg.gamma) + scheduler_N_D = torch.optim.lr_scheduler.MultiStepLR( + optimizer_N_D, milestones=self.cfg.schedule, gamma=self.cfg.gamma + ) self.schedulers = [scheduler_N_F, scheduler_N_B, scheduler_N_D] optims = [optimizer_N_F, optimizer_N_B, optimizer_N_D] @@ -77,19 +75,16 @@ class Normal(pl.LightningModule): ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(1, 2, 0), (height, height), anti_aliasing=True, - )) + ) + ) - self.logger.log_image(key=f"Normal/{dataset}/{idx if not self.overfit else 1}", - images=[(np.concatenate(result_list, axis=1) * 255.0).astype(np.uint8) - ]) + self.logger.log_image( + key=f"Normal/{dataset}/{idx if not self.overfit else 1}", + images=[(np.concatenate(result_list, axis=1) * 255.0).astype(np.uint8)] + ) def training_step(self, batch, batch_idx): - # cfg log - if not self.cfg.fast_dev and self.global_step == 0 and self.cfg.devices == 1: - export_cfg(self.logger, osp.join(self.cfg.results_path, self.cfg.name), self.cfg) - self.logger.experiment.config.update(convert_to_dict(self.cfg)) - self.netG.train() # retrieve the data @@ -125,7 +120,8 @@ class Normal(pl.LightningModule): opt_B.step() if batch_idx > 0 and batch_idx % int( - self.cfg.freq_show_train) == 0 and self.cfg.devices == 1: + self.cfg.freq_show_train + ) == 0 and self.cfg.devices == 1: self.netG.eval() with torch.no_grad(): @@ -142,12 +138,9 @@ class 
Normal(pl.LightningModule): for key in error_dict.keys(): metrics_log["train/loss_" + key] = error_dict[key].item() - self.log_dict(metrics_log, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=False, - sync_dist=True) + self.log_dict( + metrics_log, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True + ) return metrics_log @@ -163,12 +156,14 @@ class Normal(pl.LightningModule): loss_name = key metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key) - self.log_dict(metrics_log, - prog_bar=False, - logger=True, - on_step=False, - on_epoch=True, - rank_zero_only=True) + self.log_dict( + metrics_log, + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True, + rank_zero_only=True + ) def validation_step(self, batch, batch_idx): @@ -212,9 +207,11 @@ class Normal(pl.LightningModule): [stage, loss_name] = key.split("/") metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key) - self.log_dict(metrics_log, - prog_bar=False, - logger=True, - on_step=False, - on_epoch=True, - rank_zero_only=True) + self.log_dict( + metrics_log, + prog_bar=False, + logger=True, + on_step=False, + on_epoch=True, + rank_zero_only=True + ) diff --git a/apps/avatarizer.py b/apps/avatarizer.py index 12c8fe781a78af9a4f586564f7b2826027b3d5f4..a601b1a60a2ee61c21936bb720f06ec87198d8d8 100644 --- a/apps/avatarizer.py +++ b/apps/avatarizer.py @@ -44,7 +44,8 @@ smpl_model = smplx.create( use_pca=False, num_betas=200, num_expression_coeffs=50, - ext='pkl') + ext='pkl' +) smpl_out_lst = [] @@ -62,7 +63,9 @@ for pose_type in ["t-pose", "da-pose", "pose"]: return_full_pose=True, return_joint_transformation=True, return_vertex_transformation=True, - pose_type=pose_type)) + pose_type=pose_type + ) + ) smpl_verts = smpl_out_lst[2].vertices.detach()[0] smpl_tree = cKDTree(smpl_verts.cpu().numpy()) @@ -74,7 +77,8 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da. econ_verts = torch.tensor(econ_obj.vertices).float() rot_mat_t = smpl_out_lst[2].vertex_transformation.detach()[0][idx[:, 0]] homo_coord = torch.ones_like(econ_verts)[..., :1] - econ_cano_verts = torch.inverse(rot_mat_t) @ torch.cat([econ_verts, homo_coord], dim=1).unsqueeze(-1) + econ_cano_verts = torch.inverse(rot_mat_t) @ torch.cat([econ_verts, homo_coord], + dim=1).unsqueeze(-1) econ_cano_verts = econ_cano_verts[:, :3, 0].cpu() econ_cano = trimesh.Trimesh(econ_cano_verts, econ_obj.faces) @@ -84,7 +88,9 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da. econ_da = trimesh.Trimesh(econ_da_verts[:, :3, 0].cpu(), econ_obj.faces) # da-pose for SMPL-X - smpl_da = trimesh.Trimesh(smpl_out_lst[1].vertices.detach()[0], smpl_model.faces, maintain_orders=True, process=False) + smpl_da = trimesh.Trimesh( + smpl_out_lst[1].vertices.detach()[0], smpl_model.faces, maintain_orders=True, process=False + ) smpl_da.export(f"{prefix}_smpl_da.obj") # remove hands from ECON for next registeration @@ -97,7 +103,8 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da. 
# remove SMPL-X hand and face register_mask = ~np.isin( np.arange(smpl_da.vertices.shape[0]), - np.concatenate([smplx_container.smplx_mano_vid, smplx_container.smplx_front_flame_vid])) + np.concatenate([smplx_container.smplx_mano_vid, smplx_container.smplx_front_flame_vid]) + ) register_mask *= ~smplx_container.eyeball_vertex_mask.bool().numpy() smpl_da_body = smpl_da.copy() smpl_da_body.update_faces(register_mask[smpl_da.faces].all(axis=1)) @@ -115,8 +122,13 @@ if not osp.exists(f"{prefix}_econ_da.obj") or not osp.exists(f"{prefix}_smpl_da. # remove over-streched+hand faces from ECON econ_da_body = econ_da.copy() edge_before = np.sqrt( - ((econ_obj.vertices[econ_cano.edges[:, 0]] - econ_obj.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1)) - edge_after = np.sqrt(((econ_da.vertices[econ_cano.edges[:, 0]] - econ_da.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1)) + ((econ_obj.vertices[econ_cano.edges[:, 0]] - + econ_obj.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1) + ) + edge_after = np.sqrt( + ((econ_da.vertices[econ_cano.edges[:, 0]] - + econ_da.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1) + ) edge_diff = edge_after / edge_before.clip(1e-2) streched_mask = np.unique(econ_cano.edges[edge_diff > 6]) mano_mask = ~np.isin(idx[:, 0], smplx_container.smplx_mano_vid) @@ -148,8 +160,9 @@ econ_J_regressor = (smpl_model.J_regressor[:, idx] * knn_weights[None]).sum(axis econ_lbs_weights = (smpl_model.lbs_weights.T[:, idx] * knn_weights[None]).sum(axis=-1).T num_posedirs = smpl_model.posedirs.shape[0] -econ_posedirs = (smpl_model.posedirs.view(num_posedirs, -1, 3)[:, idx, :] * - knn_weights[None, ..., None]).sum(axis=-2).view(num_posedirs, -1).float() +econ_posedirs = ( + smpl_model.posedirs.view(num_posedirs, -1, 3)[:, idx, :] * knn_weights[None, ..., None] +).sum(axis=-2).view(num_posedirs, -1).float() econ_J_regressor /= econ_J_regressor.sum(axis=1, keepdims=True) econ_lbs_weights /= econ_lbs_weights.sum(axis=1, keepdims=True) @@ -157,8 +170,9 @@ econ_lbs_weights /= econ_lbs_weights.sum(axis=1, keepdims=True) # re-compute da-pose rot_mat for ECON rot_mat_da = smpl_out_lst[1].vertex_transformation.detach()[0][idx[:, 0]] econ_da_verts = torch.tensor(econ_da.vertices).float() -econ_cano_verts = torch.inverse(rot_mat_da) @ torch.cat([econ_da_verts, torch.ones_like(econ_da_verts)[..., :1]], - dim=1).unsqueeze(-1) +econ_cano_verts = torch.inverse(rot_mat_da) @ torch.cat( + [econ_da_verts, torch.ones_like(econ_da_verts)[..., :1]], dim=1 +).unsqueeze(-1) econ_cano_verts = econ_cano_verts[:, :3, 0].double() # ---------------------------------------------------- @@ -174,7 +188,8 @@ posed_econ_verts, _ = general_lbs( posedirs=econ_posedirs, J_regressor=econ_J_regressor, parents=smpl_model.parents, - lbs_weights=econ_lbs_weights) + lbs_weights=econ_lbs_weights +) econ_pose = trimesh.Trimesh(posed_econ_verts[0].detach(), econ_da.faces) -econ_pose.export(f"{prefix}_econ_pose.obj") \ No newline at end of file +econ_pose.export(f"{prefix}_econ_pose.obj") diff --git a/apps/infer.py b/apps/infer.py index 1e354b0e202771934a8a81e16c175489185e047a..fe88f0bbf64f1656fcf496ddce277f56384d1d20 100644 --- a/apps/infer.py +++ b/apps/infer.py @@ -34,7 +34,8 @@ from apps.IFGeo import IFGeo from pytorch3d.ops import SubdivideMeshes from lib.common.config import cfg from lib.common.render import query_color -from lib.common.train_util import init_loss, load_normal_networks, load_networks +from lib.common.train_util import init_loss, Format +from lib.common.imutils import blend_rgb_norm from lib.common.BNI import BNI from 
lib.common.BNI_utils import save_normal_tensor from lib.dataset.TestDataset import TestDataset @@ -68,20 +69,25 @@ if __name__ == "__main__": device = torch.device(f"cuda:{args.gpu_device}") # setting for testing on in-the-wild images - cfg_show_list = ["test_gpus", [args.gpu_device], "mcube_res", 512, "clean_mesh", True, "test_mode", True, "batch_size", 1] + cfg_show_list = [ + "test_gpus", [args.gpu_device], "mcube_res", 512, "clean_mesh", True, "test_mode", True, + "batch_size", 1 + ] cfg.merge_from_list(cfg_show_list) cfg.freeze() - # load model - normal_model = Normal(cfg).to(device) - load_normal_networks(normal_model, cfg.normal_path) - normal_model.netG.eval() - - # load IFGeo model - ifnet_model = IFGeo(cfg).to(device) - load_networks(ifnet_model, mlp_path=cfg.ifnet_path) - ifnet_model.netG.eval() + # load normal model + normal_net = Normal.load_from_checkpoint( + cfg=cfg, checkpoint_path=cfg.normal_path, map_location=device, strict=False + ) + normal_net = normal_net.to(device) + normal_net.netG.eval() + print( + colored( + f"Resume Normal Estimator from {Format.start} {cfg.normal_path} {Format.end}", "green" + ) + ) # SMPLX object SMPLX_object = SMPLX() @@ -89,16 +95,24 @@ if __name__ == "__main__": dataset_param = { "image_dir": args.in_dir, "seg_dir": args.seg_dir, - "use_seg": True, # w/ or w/o segmentation - "hps_type": cfg.bni.hps_type, # pymafx/pixie + "use_seg": True, # w/ or w/o segmentation + "hps_type": cfg.bni.hps_type, # pymafx/pixie "vol_res": cfg.vol_res, "single": args.multi, } if cfg.bni.use_ifnet: - print(colored("Use IF-Nets (Implicit)+ for completion", "green")) + # load IFGeo model + ifnet = IFGeo.load_from_checkpoint( + cfg=cfg, checkpoint_path=cfg.ifnet_path, map_location=device, strict=False + ) + ifnet = ifnet.to(device) + ifnet.netG.eval() + + print(colored(f"Resume IF-Net+ from {Format.start} {cfg.ifnet_path} {Format.end}", "green")) + print(colored(f"Complete with {Format.start} IF-Nets+ (Implicit) {Format.end}", "green")) else: - print(colored("Use SMPL-X (Explicit) for completion", "green")) + print(colored(f"Complete with {Format.start} SMPL-X (Explicit) {Format.end}", "green")) dataset = TestDataset(dataset_param, device) @@ -125,13 +139,17 @@ if __name__ == "__main__": # 2. SMPL params (xxx_smpl.npy) # 3. d-BiNI surfaces (xxx_BNI.obj) # 4. seperate face/hand mesh (xxx_hand/face.obj) - # 5. full shape impainted by IF-Nets+, and remeshed shape (xxx_IF_(remesh).obj) + # 5. full shape impainted by IF-Nets+ after remeshing (xxx_IF.obj) # 6. sideded or occluded parts (xxx_side.obj) # 7. 
final reconstructed clothed human (xxx_full.obj) os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True) - in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["img_icon"].to(device), "mask": data["img_mask"].to(device)} + in_tensor = { + "smpl_faces": data["smpl_faces"], + "image": data["img_icon"].to(device), + "mask": data["img_mask"].to(device) + } # The optimizer and variables optimed_pose = data["body_pose"].requires_grad_(True) @@ -139,7 +157,9 @@ if __name__ == "__main__": optimed_betas = data["betas"].requires_grad_(True) optimed_orient = data["global_orient"].requires_grad_(True) - optimizer_smpl = torch.optim.Adam([optimed_pose, optimed_trans, optimed_betas, optimed_orient], lr=1e-2, amsgrad=True) + optimizer_smpl = torch.optim.Adam( + [optimed_pose, optimed_trans, optimed_betas, optimed_orient], lr=1e-2, amsgrad=True + ) scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer_smpl, mode="min", @@ -156,10 +176,12 @@ if __name__ == "__main__": smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj" + # remove this line if you change the loop_smpl and obtain different SMPL-X fits if osp.exists(smpl_path): smpl_verts_lst = [] smpl_faces_lst = [] + for idx in range(N_body): smpl_obj = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_{idx:02d}.obj" @@ -173,10 +195,12 @@ if __name__ == "__main__": batch_smpl_faces = torch.stack(smpl_faces_lst) # render optimized mesh as normal [-1,1] - in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(batch_smpl_verts, batch_smpl_faces) + in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal( + batch_smpl_verts, batch_smpl_faces + ) with torch.no_grad(): - in_tensor["normal_F"], in_tensor["normal_B"] = normal_model.netG(in_tensor) + in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor) in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device) in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]] @@ -194,8 +218,10 @@ if __name__ == "__main__": N_body, N_pose = optimed_pose.shape[:2] # 6d_rot to rot_mat - optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, 6)).view(N_body, 1, 3, 3) - optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, 6)).view(N_body, N_pose, 3, 3) + optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, + 6)).view(N_body, 1, 3, 3) + optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, + 6)).view(N_body, N_pose, 3, 3) smpl_verts, smpl_landmarks, smpl_joints = dataset.smpl_model( shape_params=optimed_betas, @@ -208,11 +234,16 @@ if __name__ == "__main__": ) smpl_verts = (smpl_verts + optimed_trans) * data["scale"] - smpl_joints = (smpl_joints + optimed_trans) * data["scale"] * torch.tensor([1.0, 1.0, -1.0]).to(device) + smpl_joints = (smpl_joints + optimed_trans) * data["scale"] * torch.tensor( + [1.0, 1.0, -1.0] + ).to(device) # landmark errors - smpl_joints_3d = (smpl_joints[:, dataset.smpl_data.smpl_joint_ids_45_pixie, :] + 1.0) * 0.5 - in_tensor["smpl_joint"] = smpl_joints[:, dataset.smpl_data.smpl_joint_ids_24_pixie, :] + smpl_joints_3d = ( + smpl_joints[:, dataset.smpl_data.smpl_joint_ids_45_pixie, :] + 1.0 + ) * 0.5 + in_tensor["smpl_joint"] = smpl_joints[:, + dataset.smpl_data.smpl_joint_ids_24_pixie, :] ghum_lmks = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], :2].to(device) ghum_conf = data["landmark"][:, SMPLX_object.ghum_smpl_pairs[:, 0], -1].to(device) @@ -227,7 +258,7 @@ if __name__ == "__main__": T_mask_F, T_mask_B = 
dataset.render.get_image(type="mask") with torch.no_grad(): - in_tensor["normal_F"], in_tensor["normal_B"] = normal_model.netG(in_tensor) + in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor) diff_F_smpl = torch.abs(in_tensor["T_normal_F"] - in_tensor["normal_F"]) diff_B_smpl = torch.abs(in_tensor["T_normal_B"] - in_tensor["normal_B"]) @@ -249,25 +280,37 @@ if __name__ == "__main__": # BUG: PyTorch3D silhouette renderer generates dilated mask bg_value = in_tensor["T_normal_F"][0, 0, 0, 0] - smpl_arr_fake = torch.cat([in_tensor["T_normal_F"][:, 0].ne(bg_value).float(), in_tensor["T_normal_B"][:, 0].ne(bg_value).float()], - dim=-1) + smpl_arr_fake = torch.cat( + [ + in_tensor["T_normal_F"][:, 0].ne(bg_value).float(), + in_tensor["T_normal_B"][:, 0].ne(bg_value).float() + ], + dim=-1 + ) - body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2]) + body_overlap = (gt_arr * smpl_arr_fake.gt(0.0) + ).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2]) body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1) body_overlap_flag = body_overlap < cfg.body_overlap_thres - losses["normal"]["value"] = (diff_F_smpl * body_overlap_mask[..., :512] + diff_B_smpl * body_overlap_mask[..., 512:]).mean() / 2.0 + losses["normal"]["value"] = ( + diff_F_smpl * body_overlap_mask[..., :512] + + diff_B_smpl * body_overlap_mask[..., 512:] + ).mean() / 2.0 losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag] occluded_idx = torch.where(body_overlap_flag)[0] ghum_conf[occluded_idx] *= ghum_conf[occluded_idx] > 0.95 - losses["joint"]["value"] = (torch.norm(ghum_lmks - smpl_lmks, dim=2) * ghum_conf).mean(dim=1) + losses["joint"]["value"] = (torch.norm(ghum_lmks - smpl_lmks, dim=2) * + ghum_conf).mean(dim=1) # Weighted sum of the losses smpl_loss = 0.0 - pbar_desc = "Body Fitting --- " + pbar_desc = "Body Fitting -- " for k in ["normal", "silhouette", "joint"]: - per_loop_loss = (losses[k]["value"] * torch.tensor(losses[k]["weight"]).to(device)).mean() + per_loop_loss = ( + losses[k]["value"] * torch.tensor(losses[k]["weight"]).to(device) + ).mean() pbar_desc += f"{k}: {per_loop_loss:.3f} | " smpl_loss += per_loop_loss pbar_desc += f"Total: {smpl_loss:.3f}" @@ -279,19 +322,25 @@ if __name__ == "__main__": # save intermediate results / vis_freq and final_step if (i % args.vis_freq == 0) or (i == args.loop_smpl - 1): - per_loop_lst.extend([ - in_tensor["image"], - in_tensor["T_normal_F"], - in_tensor["normal_F"], - diff_S[:, :, :512].unsqueeze(1).repeat(1, 3, 1, 1), - ]) - per_loop_lst.extend([ - in_tensor["image"], - in_tensor["T_normal_B"], - in_tensor["normal_B"], - diff_S[:, :, 512:].unsqueeze(1).repeat(1, 3, 1, 1), - ]) - per_data_lst.append(get_optim_grid_image(per_loop_lst, None, nrow=N_body * 2, type="smpl")) + per_loop_lst.extend( + [ + in_tensor["image"], + in_tensor["T_normal_F"], + in_tensor["normal_F"], + diff_S[:, :, :512].unsqueeze(1).repeat(1, 3, 1, 1), + ] + ) + per_loop_lst.extend( + [ + in_tensor["image"], + in_tensor["T_normal_B"], + in_tensor["normal_B"], + diff_S[:, :, 512:].unsqueeze(1).repeat(1, 3, 1, 1), + ] + ) + per_data_lst.append( + get_optim_grid_image(per_loop_lst, None, nrow=N_body * 2, type="smpl") + ) smpl_loss.backward() optimizer_smpl.step() @@ -304,14 +353,21 @@ if __name__ == "__main__": img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png") torchvision.utils.save_image( torch.cat( - [data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 
0.5, (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5], - dim=3), img_crop_path) + [ + data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5, + (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5 + ], + dim=3 + ), img_crop_path + ) rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data) rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data) img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png") - torchvision.utils.save_image(torch.Tensor([data["img_raw"], rgb_norm_F, rgb_norm_B]).permute(0, 3, 1, 2) / 255., img_overlap_path) + torchvision.utils.save_image( + torch.cat([data["img_raw"], rgb_norm_F, rgb_norm_B], dim=-1) / 255., img_overlap_path + ) smpl_obj_lst = [] @@ -329,15 +385,28 @@ if __name__ == "__main__": if not osp.exists(smpl_obj_path): smpl_obj.export(smpl_obj_path) smpl_info = { - "betas": optimed_betas[idx].detach().cpu().unsqueeze(0), - "body_pose": rotation_matrix_to_angle_axis(optimed_pose_mat[idx].detach()).cpu().unsqueeze(0), - "global_orient": rotation_matrix_to_angle_axis(optimed_orient_mat[idx].detach()).cpu().unsqueeze(0), - "transl": optimed_trans[idx].detach().cpu(), - "expression": data["exp"][idx].cpu().unsqueeze(0), - "jaw_pose": rotation_matrix_to_angle_axis(data["jaw_pose"][idx]).cpu().unsqueeze(0), - "left_hand_pose": rotation_matrix_to_angle_axis(data["left_hand_pose"][idx]).cpu().unsqueeze(0), - "right_hand_pose": rotation_matrix_to_angle_axis(data["right_hand_pose"][idx]).cpu().unsqueeze(0), - "scale": data["scale"][idx].cpu(), + "betas": + optimed_betas[idx].detach().cpu().unsqueeze(0), + "body_pose": + rotation_matrix_to_angle_axis(optimed_pose_mat[idx].detach() + ).cpu().unsqueeze(0), + "global_orient": + rotation_matrix_to_angle_axis(optimed_orient_mat[idx].detach() + ).cpu().unsqueeze(0), + "transl": + optimed_trans[idx].detach().cpu(), + "expression": + data["exp"][idx].cpu().unsqueeze(0), + "jaw_pose": + rotation_matrix_to_angle_axis(data["jaw_pose"][idx]).cpu().unsqueeze(0), + "left_hand_pose": + rotation_matrix_to_angle_axis(data["left_hand_pose"][idx] + ).cpu().unsqueeze(0), + "right_hand_pose": + rotation_matrix_to_angle_axis(data["right_hand_pose"][idx] + ).cpu().unsqueeze(0), + "scale": + data["scale"][idx].cpu(), } np.save( smpl_obj_path.replace(".obj", ".npy"), @@ -359,10 +428,13 @@ if __name__ == "__main__": per_data_lst = [] - batch_smpl_verts = in_tensor["smpl_verts"].detach() * torch.tensor([1.0, -1.0, 1.0], device=device) + batch_smpl_verts = in_tensor["smpl_verts"].detach( + ) * torch.tensor([1.0, -1.0, 1.0], device=device) batch_smpl_faces = in_tensor["smpl_faces"].detach()[:, :, [0, 2, 1]] - in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth(batch_smpl_verts, batch_smpl_faces) + in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth( + batch_smpl_verts, batch_smpl_faces + ) per_loop_lst = [] @@ -389,7 +461,13 @@ if __name__ == "__main__": ) # BNI process - BNI_object = BNI(dir_path=osp.join(args.out_dir, cfg.name, "BNI"), name=data["name"], BNI_dict=BNI_dict, cfg=cfg.bni, device=device) + BNI_object = BNI( + dir_path=osp.join(args.out_dir, cfg.name, "BNI"), + name=data["name"], + BNI_dict=BNI_dict, + cfg=cfg.bni, + device=device + ) BNI_object.extract_surface(False) @@ -406,29 +484,40 @@ if __name__ == "__main__": side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask) # mesh completion via IF-net - in_tensor.update(dataset.depth_to_voxel({"depth_F": BNI_object.F_depth.unsqueeze(0), "depth_B": BNI_object.B_depth.unsqueeze(0)})) + 
in_tensor.update( + dataset.depth_to_voxel( + { + "depth_F": BNI_object.F_depth.unsqueeze(0), + "depth_B": BNI_object.B_depth.unsqueeze(0) + } + ) + ) occupancies = VoxelGrid.from_mesh(side_mesh, cfg.vol_res, loc=[ 0, ] * 3, scale=2.0).data.transpose(2, 1, 0) occupancies = np.flip(occupancies, axis=1) - in_tensor["body_voxels"] = torch.tensor(occupancies.copy()).float().unsqueeze(0).to(device) + in_tensor["body_voxels"] = torch.tensor(occupancies.copy() + ).float().unsqueeze(0).to(device) with torch.no_grad(): - sdf = ifnet_model.reconEngine(netG=ifnet_model.netG, batch=in_tensor) - verts_IF, faces_IF = ifnet_model.reconEngine.export_mesh(sdf) + sdf = ifnet.reconEngine(netG=ifnet.netG, batch=in_tensor) + verts_IF, faces_IF = ifnet.reconEngine.export_mesh(sdf) - if ifnet_model.clean_mesh_flag: + if ifnet.clean_mesh_flag: verts_IF, faces_IF = clean_mesh(verts_IF, faces_IF) side_mesh = trimesh.Trimesh(verts_IF, faces_IF) - side_mesh = remesh(side_mesh, side_mesh_path) + side_mesh = remesh_laplacian(side_mesh, side_mesh_path) else: side_mesh = apply_vertex_mask( side_mesh, - (SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask + SMPLX_object.eyeball_vertex_mask).eq(0).float(), + ( + SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask + + SMPLX_object.eyeball_vertex_mask + ).eq(0).float(), ) #register side_mesh to BNI surfaces @@ -448,7 +537,9 @@ if __name__ == "__main__": # 3. remove eyeball faces # export intermediate meshes - BNI_object.F_B_trimesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj") + BNI_object.F_B_trimesh.export( + f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj" + ) full_lst = [] if "face" in cfg.bni.use_smpl: @@ -458,37 +549,63 @@ if __name__ == "__main__": face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness]) # remove face neighbor triangles - BNI_object.F_B_trimesh = part_removal(BNI_object.F_B_trimesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face") - side_mesh = part_removal(side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face") + BNI_object.F_B_trimesh = part_removal( + BNI_object.F_B_trimesh, + face_mesh, + cfg.bni.face_thres, + device, + smplx_mesh, + region="face" + ) + side_mesh = part_removal( + side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face" + ) face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj") full_lst += [face_mesh] if "hand" in cfg.bni.use_smpl and (True in data['hands_visibility'][idx]): - hand_mask = torch.zeros(SMPLX_object.smplx_verts.shape[0],) + hand_mask = torch.zeros(SMPLX_object.smplx_verts.shape[0], ) if data['hands_visibility'][idx][0]: - hand_mask.index_fill_(0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["left_hand"]), 1.0) + hand_mask.index_fill_( + 0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["left_hand"]), 1.0 + ) if data['hands_visibility'][idx][1]: - hand_mask.index_fill_(0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["right_hand"]), 1.0) + hand_mask.index_fill_( + 0, torch.tensor(SMPLX_object.smplx_mano_vid_dict["right_hand"]), 1.0 + ) # only hands hand_mesh = apply_vertex_mask(hand_mesh, hand_mask) # remove hand neighbor triangles - BNI_object.F_B_trimesh = part_removal(BNI_object.F_B_trimesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand") - side_mesh = part_removal(side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand") + BNI_object.F_B_trimesh = part_removal( + BNI_object.F_B_trimesh, + hand_mesh, + 
cfg.bni.hand_thres, + device, + smplx_mesh, + region="hand" + ) + side_mesh = part_removal( + side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand" + ) hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj") full_lst += [hand_mesh] full_lst += [BNI_object.F_B_trimesh] # initial side_mesh could be SMPLX or IF-net - side_mesh = part_removal(side_mesh, sum(full_lst), 2e-2, device, smplx_mesh, region="", clean=False) + side_mesh = part_removal( + side_mesh, sum(full_lst), 2e-2, device, smplx_mesh, region="", clean=False + ) full_lst += [side_mesh] # # export intermediate meshes - BNI_object.F_B_trimesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj") + BNI_object.F_B_trimesh.export( + f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj" + ) side_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_side.obj") if cfg.bni.use_poisson: @@ -505,15 +622,22 @@ if __name__ == "__main__": rotate_recon_lst = dataset.render.get_image(cam_type="four") per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst) - # coloring the final mesh - final_colors = query_color( - torch.tensor(final_mesh.vertices).float(), - torch.tensor(final_mesh.faces).long(), - in_tensor["image"][idx:idx + 1], - device=device, - ) - final_mesh.visual.vertex_colors = final_colors - final_mesh.export(final_path) + if cfg.bni.texture_src == 'image': + + # coloring the final mesh (front: RGB pixels, back: normal colors) + final_colors = query_color( + torch.tensor(final_mesh.vertices).float(), + torch.tensor(final_mesh.faces).long(), + in_tensor["image"][idx:idx + 1], + device=device, + ) + final_mesh.visual.vertex_colors = final_colors + final_mesh.export(final_path) + + elif cfg.bni.texture_src == 'SD': + + # !TODO: add texture from Stable Diffusion + pass # for video rendering in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float()) diff --git a/apps/multi_render.py b/apps/multi_render.py index 933029cae4f98c2bc6400431dc3eb828701158ef..4088440757ce81137aaad7685d9df4b53b1c1383 100644 --- a/apps/multi_render.py +++ b/apps/multi_render.py @@ -20,6 +20,4 @@ faces_lst = in_tensor["body_faces"] + in_tensor["BNI_faces"] # self-rotated video render.load_meshes(verts_lst, faces_lst) -render.get_rendered_video_multi( - in_tensor, - f"{root}/{args.name}_cloth.mp4") \ No newline at end of file +render.get_rendered_video_multi(in_tensor, f"{root}/{args.name}_cloth.mp4") diff --git a/configs/econ.yaml b/configs/econ.yaml index 548b683af89750e58ce8536f19f99170aeb5e693..ec35721b8ac9cc25fb31a9bc7836d44f3373aeb4 100644 --- a/configs/econ.yaml +++ b/configs/econ.yaml @@ -35,3 +35,4 @@ bni: face_thres: 6e-2 thickness: 0.02 hps_type: "pixie" + texture_src: "SD" diff --git a/docs/installation.md b/docs/installation.md index e53298dec5f0e1cc89b2433f63ed23b9d97606dc..df52326db113c77165fdd254dd3e75aeb5f8a2f3 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -9,12 +9,11 @@ cd ECON ## Environment -- Ubuntu 20 / 18 -- GCC=7 (required by [pypoisson](https://github.com/mmolero/pypoisson/issues/13)) +- Ubuntu 20 / 18, (Windows as well, see [issue#7](https://github.com/YuliangXiu/ECON/issues/7)) - **CUDA=11.4, GPU Memory > 12GB** - Python = 3.8 - PyTorch >= 1.13.0 (official [Get Started](https://pytorch.org/get-started/locally/)) -- CUPY >= 11.3.0 (offcial [Installation](https://docs.cupy.dev/en/stable/install.html#installing-cupy-from-pypi)) +- Cupy >= 11.3.0 (offcial 
[Installation](https://docs.cupy.dev/en/stable/install.html#installing-cupy-from-pypi)) - PyTorch3D (official [INSTALL.md](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md), recommend [install-from-local-clone](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-from-a-local-clone)) ```bash diff --git a/lib/common/BNI.py b/lib/common/BNI.py index df26ffb077682e8525f0bcf727520d181801a24a..1b65777913c2a3573848842f65dd266f213c7f87 100644 --- a/lib/common/BNI.py +++ b/lib/common/BNI.py @@ -1,12 +1,12 @@ -from lib.common.BNI_utils import (verts_inverse_transform, depth_inverse_transform, - double_side_bilateral_normal_integration) +from lib.common.BNI_utils import ( + verts_inverse_transform, depth_inverse_transform, double_side_bilateral_normal_integration +) import torch import trimesh class BNI: - def __init__(self, dir_path, name, BNI_dict, cfg, device): self.scale = 256.0 @@ -64,22 +64,20 @@ class BNI: F_B_verts = torch.cat((F_verts, B_verts), dim=0) F_B_faces = torch.cat( - (bni_result["F_faces"], bni_result["B_faces"] + bni_result["F_faces"].max() + 1), dim=0) + (bni_result["F_faces"], bni_result["B_faces"] + bni_result["F_faces"].max() + 1), dim=0 + ) - self.F_B_trimesh = trimesh.Trimesh(F_B_verts.float(), - F_B_faces.long(), - process=False, - maintain_order=True) + self.F_B_trimesh = trimesh.Trimesh( + F_B_verts.float(), F_B_faces.long(), process=False, maintain_order=True + ) - self.F_trimesh = trimesh.Trimesh(F_verts.float(), - bni_result["F_faces"].long(), - process=False, - maintain_order=True) + self.F_trimesh = trimesh.Trimesh( + F_verts.float(), bni_result["F_faces"].long(), process=False, maintain_order=True + ) - self.B_trimesh = trimesh.Trimesh(B_verts.float(), - bni_result["B_faces"].long(), - process=False, - maintain_order=True) + self.B_trimesh = trimesh.Trimesh( + B_verts.float(), bni_result["B_faces"].long(), process=False, maintain_order=True + ) if __name__ == "__main__": @@ -93,16 +91,18 @@ if __name__ == "__main__": bni_dict = np.load(npy_file, allow_pickle=True).item() default_cfg = {'k': 2, 'lambda1': 1e-4, 'boundary_consist': 1e-6} - + # for k in [1, 2, 4, 10, 100]: # default_cfg['k'] = k # for k in [1e-8, 1e-4, 1e-2, 1e-1, 1]: - # default_cfg['lambda1'] = k + # default_cfg['lambda1'] = k # for k in [1e-4, 1e-2, 0]: - # default_cfg['boundary_consist'] = k - - bni_object = BNI(osp.dirname(npy_file), osp.basename(npy_file), bni_dict, default_cfg, - torch.device('cuda:0')) + # default_cfg['boundary_consist'] = k + + bni_object = BNI( + osp.dirname(npy_file), osp.basename(npy_file), bni_dict, default_cfg, + torch.device('cuda:0') + ) bni_object.extract_surface() bni_object.F_trimesh.export(osp.join(osp.dirname(npy_file), "F.obj")) diff --git a/lib/common/BNI_utils.py b/lib/common/BNI_utils.py index b5f5f873dbdfbaf9b40c268d481a6368824edbec..57deed1d9e7da0379a3017e4f0f44c1bf34314b9 100644 --- a/lib/common/BNI_utils.py +++ b/lib/common/BNI_utils.py @@ -53,8 +53,9 @@ def find_contour(mask, method='all'): contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) else: - contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, - cv2.CHAIN_APPROX_SIMPLE) + contours, _ = cv2.findContours( + mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE + ) contour_cloth = np.array(find_max_list(contours))[:, 0, :] @@ -67,16 +68,19 @@ def mean_value_cordinates(inner_pts, contour_pts): body_edges_c = np.roll(body_edges_a, shift=-1, axis=1) body_edges_b = np.sqrt(((contour_pts - 
np.roll(contour_pts, shift=-1, axis=0))**2).sum(axis=1)) - body_edges = np.concatenate([ - body_edges_a[..., None], body_edges_c[..., None], - np.repeat(body_edges_b[None, :, None], axis=0, repeats=len(inner_pts)) - ], - axis=-1) + body_edges = np.concatenate( + [ + body_edges_a[..., None], body_edges_c[..., None], + np.repeat(body_edges_b[None, :, None], axis=0, repeats=len(inner_pts)) + ], + axis=-1 + ) body_cos = (body_edges[:, :, 0]**2 + body_edges[:, :, 1]**2 - body_edges[:, :, 2]**2) / (2 * body_edges[:, :, 0] * body_edges[:, :, 1]) body_tan_half = np.sqrt( - (1. - np.clip(body_cos, a_max=1., a_min=-1.)) / np.clip(1. + body_cos, 1e-6, 2.)) + (1. - np.clip(body_cos, a_max=1., a_min=-1.)) / np.clip(1. + body_cos, 1e-6, 2.) + ) w = (body_tan_half + np.roll(body_tan_half, shift=1, axis=1)) / body_edges_a w /= w.sum(axis=1, keepdims=True) @@ -97,16 +101,18 @@ def dispCorres(img_size, contour1, contour2, phi, dir_path): contour2 = contour2[None, :, None, :].astype(np.int32) disp = np.zeros((img_size, img_size, 3), dtype=np.uint8) - cv2.drawContours(disp, contour1, -1, (0, 255, 0), 1) # green - cv2.drawContours(disp, contour2, -1, (255, 0, 0), 1) # blue + cv2.drawContours(disp, contour1, -1, (0, 255, 0), 1) # green + cv2.drawContours(disp, contour2, -1, (255, 0, 0), 1) # blue - for i in range(contour1.shape[1]): # do not show all the points when display + for i in range(contour1.shape[1]): # do not show all the points when display # cv2.circle(disp, (contour1[0, i, 0, 0], contour1[0, i, 0, 1]), 1, # (255, 0, 0), -1) corresPoint = contour2[0, phi[i], 0] # cv2.circle(disp, (corresPoint[0], corresPoint[1]), 1, (0, 255, 0), -1) - cv2.line(disp, (contour1[0, i, 0, 0], contour1[0, i, 0, 1]), - (corresPoint[0], corresPoint[1]), (255, 255, 255), 1) + cv2.line( + disp, (contour1[0, i, 0, 0], contour1[0, i, 0, 1]), (corresPoint[0], corresPoint[1]), + (255, 255, 255), 1 + ) cv2.imwrite(osp.join(dir_path, "corres.png"), disp) @@ -162,7 +168,8 @@ def verts_transform(t, depth_scale): t_copy *= depth_scale * 0.5 t_copy += depth_scale * 0.5 t_copy = t_copy[:, [1, 0, 2]] * torch.Tensor([2.0, 2.0, -2.0]) + torch.Tensor( - [0.0, 0.0, depth_scale]) + [0.0, 0.0, depth_scale] + ) return t_copy @@ -328,19 +335,22 @@ def construct_facets_from(mask): facet_move_top_mask = move_top(mask) facet_move_left_mask = move_left(mask) facet_move_top_left_mask = move_top_left(mask) - facet_top_left_mask = (facet_move_top_mask * facet_move_left_mask * facet_move_top_left_mask * - mask) + facet_top_left_mask = ( + facet_move_top_mask * facet_move_left_mask * facet_move_top_left_mask * mask + ) facet_top_right_mask = move_right(facet_top_left_mask) facet_bottom_left_mask = move_bottom(facet_top_left_mask) facet_bottom_right_mask = move_bottom_right(facet_top_left_mask) - return cp.hstack(( - 4 * cp.ones((cp.sum(facet_top_left_mask).item(), 1)), - idx[facet_top_left_mask][:, None], - idx[facet_bottom_left_mask][:, None], - idx[facet_bottom_right_mask][:, None], - idx[facet_top_right_mask][:, None], - )).astype(int) + return cp.hstack( + ( + 4 * cp.ones((cp.sum(facet_top_left_mask).item(), 1)), + idx[facet_top_left_mask][:, None], + idx[facet_bottom_left_mask][:, None], + idx[facet_bottom_right_mask][:, None], + idx[facet_top_right_mask][:, None], + ) + ).astype(int) def map_depth_map_to_point_clouds(depth_map, mask, K=None, step_size=1): @@ -364,8 +374,8 @@ def map_depth_map_to_point_clouds(depth_map, mask, K=None, step_size=1): u[..., 0] = xx u[..., 1] = yy u[..., 2] = 1 - u = u[mask].T # 3 x m - vertices = (cp.linalg.inv(K) @ 
u).T * depth_map[mask, cp.newaxis] # m x 3 + u = u[mask].T # 3 x m + vertices = (cp.linalg.inv(K) @ u).T * depth_map[mask, cp.newaxis] # m x 3 return vertices @@ -374,7 +384,6 @@ def sigmoid(x, k=1): return 1 / (1 + cp.exp(-k * x)) - def boundary_excluded_mask(mask): top_mask = cp.pad(mask, ((1, 0), (0, 0)), "constant", constant_values=0)[:-1, :] bottom_mask = cp.pad(mask, ((0, 1), (0, 0)), "constant", constant_values=0)[1:, :] @@ -410,22 +419,24 @@ def create_boundary_matrix(mask): return B, B_full -def double_side_bilateral_normal_integration(normal_front, - normal_back, - normal_mask, - depth_front=None, - depth_back=None, - depth_mask=None, - k=2, - lambda_normal_back=1, - lambda_depth_front=1e-4, - lambda_depth_back=1e-2, - lambda_boundary_consistency=1, - step_size=1, - max_iter=150, - tol=1e-4, - cg_max_iter=5000, - cg_tol=1e-3): +def double_side_bilateral_normal_integration( + normal_front, + normal_back, + normal_mask, + depth_front=None, + depth_back=None, + depth_mask=None, + k=2, + lambda_normal_back=1, + lambda_depth_front=1e-4, + lambda_depth_back=1e-2, + lambda_boundary_consistency=1, + step_size=1, + max_iter=150, + tol=1e-4, + cg_max_iter=5000, + cg_tol=1e-3 +): # To avoid confusion, we list the coordinate systems in this code as follows # @@ -467,14 +478,12 @@ def double_side_bilateral_normal_integration(normal_front, del normal_map_back # right, left, top, bottom - A3_f, A4_f, A1_f, A2_f = generate_dx_dy(normal_mask, - nz_horizontal=nz_front, - nz_vertical=nz_front, - step_size=step_size) - A3_b, A4_b, A1_b, A2_b = generate_dx_dy(normal_mask, - nz_horizontal=nz_back, - nz_vertical=nz_back, - step_size=step_size) + A3_f, A4_f, A1_f, A2_f = generate_dx_dy( + normal_mask, nz_horizontal=nz_front, nz_vertical=nz_front, step_size=step_size + ) + A3_b, A4_b, A1_b, A2_b = generate_dx_dy( + normal_mask, nz_horizontal=nz_back, nz_vertical=nz_back, step_size=step_size + ) has_left_mask = cp.logical_and(move_right(normal_mask), normal_mask) has_right_mask = cp.logical_and(move_left(normal_mask), normal_mask) @@ -498,29 +507,25 @@ def double_side_bilateral_normal_integration(normal_front, b_back = cp.concatenate((-nx_back, -nx_back, -ny_back, -ny_back)) # initialization - W_front = spdiags(0.5 * cp.ones(4 * num_normals), - 0, - 4 * num_normals, - 4 * num_normals, - format="csr") - W_back = spdiags(0.5 * cp.ones(4 * num_normals), - 0, - 4 * num_normals, - 4 * num_normals, - format="csr") + W_front = spdiags( + 0.5 * cp.ones(4 * num_normals), 0, 4 * num_normals, 4 * num_normals, format="csr" + ) + W_back = spdiags( + 0.5 * cp.ones(4 * num_normals), 0, 4 * num_normals, 4 * num_normals, format="csr" + ) z_front = cp.zeros(num_normals, float) z_back = cp.zeros(num_normals, float) z_combined = cp.concatenate((z_front, z_back)) B, B_full = create_boundary_matrix(normal_mask) - B_mat = lambda_boundary_consistency * coo_matrix(B_full.get().T @ B_full.get()) #bug + B_mat = lambda_boundary_consistency * coo_matrix(B_full.get().T @ B_full.get()) #bug energy_list = [] if depth_mask is not None: - depth_mask_flat = depth_mask[normal_mask].astype(bool) # shape: (num_normals,) - z_prior_front = depth_map_front[normal_mask] # shape: (num_normals,) + depth_mask_flat = depth_mask[normal_mask].astype(bool) # shape: (num_normals,) + z_prior_front = depth_map_front[normal_mask] # shape: (num_normals,) z_prior_front[~depth_mask_flat] = 0 z_prior_back = depth_map_back[normal_mask] z_prior_back[~depth_mask_flat] = 0 @@ -554,40 +559,43 @@ def double_side_bilateral_normal_integration(normal_front, 
vstack((csr_matrix((num_normals, num_normals)), A_mat_back))]) + B_mat b_vec_combined = cp.concatenate((b_vec_front, b_vec_back)) - D = spdiags(1 / cp.clip(A_mat_combined.diagonal(), 1e-5, None), 0, 2 * num_normals, - 2 * num_normals, "csr") # Jacob preconditioner + D = spdiags( + 1 / cp.clip(A_mat_combined.diagonal(), 1e-5, None), 0, 2 * num_normals, 2 * num_normals, + "csr" + ) # Jacob preconditioner - z_combined, _ = cg(A_mat_combined, - b_vec_combined, - M=D, - x0=z_combined, - maxiter=cg_max_iter, - tol=cg_tol) + z_combined, _ = cg( + A_mat_combined, b_vec_combined, M=D, x0=z_combined, maxiter=cg_max_iter, tol=cg_tol + ) z_front = z_combined[:num_normals] z_back = z_combined[num_normals:] - wu_f = sigmoid((A2_f.dot(z_front))**2 - (A1_f.dot(z_front))**2, k) # top - wv_f = sigmoid((A4_f.dot(z_front))**2 - (A3_f.dot(z_front))**2, k) # right + wu_f = sigmoid((A2_f.dot(z_front))**2 - (A1_f.dot(z_front))**2, k) # top + wv_f = sigmoid((A4_f.dot(z_front))**2 - (A3_f.dot(z_front))**2, k) # right wu_f[top_boundnary_mask] = 0.5 wu_f[bottom_boundary_mask] = 0.5 wv_f[left_boundary_mask] = 0.5 wv_f[right_boudnary_mask] = 0.5 - W_front = spdiags(cp.concatenate((wu_f, 1 - wu_f, wv_f, 1 - wv_f)), - 0, - 4 * num_normals, - 4 * num_normals, - format="csr") - - wu_b = sigmoid((A2_b.dot(z_back))**2 - (A1_b.dot(z_back))**2, k) # top - wv_b = sigmoid((A4_b.dot(z_back))**2 - (A3_b.dot(z_back))**2, k) # right + W_front = spdiags( + cp.concatenate((wu_f, 1 - wu_f, wv_f, 1 - wv_f)), + 0, + 4 * num_normals, + 4 * num_normals, + format="csr" + ) + + wu_b = sigmoid((A2_b.dot(z_back))**2 - (A1_b.dot(z_back))**2, k) # top + wv_b = sigmoid((A4_b.dot(z_back))**2 - (A3_b.dot(z_back))**2, k) # right wu_b[top_boundnary_mask] = 0.5 wu_b[bottom_boundary_mask] = 0.5 wv_b[left_boundary_mask] = 0.5 wv_b[right_boudnary_mask] = 0.5 - W_back = spdiags(cp.concatenate((wu_b, 1 - wu_b, wv_b, 1 - wv_b)), - 0, - 4 * num_normals, - 4 * num_normals, - format="csr") + W_back = spdiags( + cp.concatenate((wu_b, 1 - wu_b, wv_b, 1 - wv_b)), + 0, + 4 * num_normals, + 4 * num_normals, + format="csr" + ) energy_old = energy energy = (A_front_data @ z_front - b_front).T @ W_front @ (A_front_data @ z_front - b_front) + \ @@ -603,23 +611,26 @@ def double_side_bilateral_normal_integration(normal_front, if relative_energy < tol: break # del A1, A2, A3, A4, nx, ny - + depth_map_front_est = cp.ones_like(normal_mask, float) * cp.nan depth_map_front_est[normal_mask] = z_front depth_map_back_est = cp.ones_like(normal_mask, float) * cp.nan depth_map_back_est[normal_mask] = z_back - + # manually cut the intersection - normal_mask[depth_map_front_est>=depth_map_back_est] = False + normal_mask[depth_map_front_est >= depth_map_back_est] = False depth_map_front_est[~normal_mask] = cp.nan depth_map_back_est[~normal_mask] = cp.nan vertices_front = cp.asnumpy( - map_depth_map_to_point_clouds(depth_map_front_est, normal_mask, K=None, - step_size=step_size)) + map_depth_map_to_point_clouds( + depth_map_front_est, normal_mask, K=None, step_size=step_size + ) + ) vertices_back = cp.asnumpy( - map_depth_map_to_point_clouds(depth_map_back_est, normal_mask, K=None, step_size=step_size)) + map_depth_map_to_point_clouds(depth_map_back_est, normal_mask, K=None, step_size=step_size) + ) facets_back = cp.asnumpy(construct_facets_from(normal_mask)) @@ -656,7 +667,7 @@ def save_normal_tensor(in_tensor, idx, png_path, thickness=0.0): depth_B_arr = depth2arr(in_tensor["depth_B"][idx]) BNI_dict = {} - + # clothed human BNI_dict["normal_F"] = normal_F_arr BNI_dict["normal_B"] = 
normal_B_arr diff --git a/lib/common/blender_utils.py b/lib/common/blender_utils.py index 45b443f2157d712c8be145458bf4f3197b727521..a02260cc722bd9729dfbeb153543ac5f648deacf 100644 --- a/lib/common/blender_utils.py +++ b/lib/common/blender_utils.py @@ -3,6 +3,7 @@ import sys, os from math import radians import mathutils import bmesh + print(sys.exec_prefix) from tqdm import tqdm import numpy as np @@ -29,7 +30,6 @@ shadows = False # diffuse_color = (18/255., 139/255., 142/255.,1) #correct # diffuse_color = (251/255., 60/255., 60/255.,1) #wrong - smooth = False wireframe = False @@ -47,13 +47,16 @@ compositor_alpha = 0.7 # Helper functions ################################################## + def blender_print(*args, **kwargs): - print (*args, **kwargs, file=sys.stderr) + print(*args, **kwargs, file=sys.stderr) + def using_app(): ''' Returns if script is running through Blender application (GUI or background processing)''' return (not sys.argv[0].endswith('.py')) + def setup_diffuse_transparent_material(target, color, object_transparent, backface_transparent): ''' Sets up diffuse/transparent material with backface culling in cycles''' @@ -110,8 +113,10 @@ def setup_diffuse_transparent_material(target, color, object_transparent, backfa links.new(node_mix_backface.outputs[0], node_output.inputs[0]) return + ################################################## + def setup_scene(): global render global cycles_gpu @@ -150,12 +155,13 @@ def setup_scene(): if cycles_gpu: print('Activating GPU acceleration') bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA' - + if bpy.app.version[0] >= 3: - cuda_devices = bpy.context.preferences.addons['cycles'].preferences.get_devices_for_type(compute_device_type = 'CUDA') + cuda_devices = bpy.context.preferences.addons[ + 'cycles'].preferences.get_devices_for_type(compute_device_type='CUDA') else: - (cuda_devices, opencl_devices) = bpy.context.preferences.addons['cycles'].preferences.get_devices() - + (cuda_devices, opencl_devices + ) = bpy.context.preferences.addons['cycles'].preferences.get_devices() if (len(cuda_devices) < 1): print('ERROR: CUDA GPU acceleration not available') @@ -178,7 +184,7 @@ def setup_scene(): if bpy.app.version[0] < 3: scene.render.tile_x = 64 scene.render.tile_y = 64 - + # Disable Blender 3 denoiser to properly measure Cycles render speed if bpy.app.version[0] >= 3: scene.cycles.use_denoising = False @@ -226,7 +232,6 @@ def setup_scene(): bpy.ops.mesh.mark_freestyle_edge(clear=True) bpy.ops.object.mode_set(mode='OBJECT') - # Setup freestyle mode for wireframe overlay rendering if wireframe: scene.render.use_freestyle = True @@ -245,8 +250,10 @@ def setup_scene(): # Output transparent image when no background is used scene.render.image_settings.color_mode = 'RGBA' + ################################################## + def setup_compositing(): global compositor_image_scale @@ -275,6 +282,7 @@ def setup_compositing(): links.new(blend_node.outputs[0], tree.nodes['Composite'].inputs[0]) + def render_file(input_file, input_dir, output_file, output_dir, yaw, correct): '''Render image of given model file''' global smooth @@ -288,13 +296,13 @@ def render_file(input_file, input_dir, output_file, output_dir, yaw, correct): # Import object into scene bpy.ops.import_scene.obj(filepath=path) object = bpy.context.selected_objects[0] - + object.rotation_euler = (radians(90.0), 0.0, radians(yaw)) - z_bottom = np.min(np.array([vert.co for vert in object.data.vertices])[:,1]) + z_bottom = np.min(np.array([vert.co for vert in 
object.data.vertices])[:, 1]) # z_top = np.max(np.array([vert.co for vert in object.data.vertices])[:,1]) # blender_print(radians(90.0), z_bottom, z_top) object.location -= mathutils.Vector((0.0, 0.0, z_bottom)) - + if quads: bpy.context.view_layer.objects.active = object bpy.ops.object.mode_set(mode='EDIT') @@ -309,11 +317,11 @@ def render_file(input_file, input_dir, output_file, output_dir, yaw, correct): bpy.ops.object.mode_set(mode='EDIT') bpy.ops.mesh.mark_freestyle_edge(clear=False) bpy.ops.object.mode_set(mode='OBJECT') - + if correct: - diffuse_color = (18/255., 139/255., 142/255.,1) #correct + diffuse_color = (18 / 255., 139 / 255., 142 / 255., 1) #correct else: - diffuse_color = (251/255., 60/255., 60/255.,1) #wrong + diffuse_color = (251 / 255., 60 / 255., 60 / 255., 1) #wrong setup_diffuse_transparent_material(object, diffuse_color, object_transparent, mouth_transparent) @@ -336,10 +344,10 @@ def render_file(input_file, input_dir, output_file, output_dir, yaw, correct): bpy.ops.render.render(write_still=True) # Remove temporary output redirection -# sys.stdout.flush() -# os.close(1) -# os.dup(old) -# os.close(old) + # sys.stdout.flush() + # os.close(1) + # os.dup(old) + # os.close(old) # Delete last selected object from scene object.select_set(True) @@ -351,7 +359,7 @@ def process_file(input_file, input_dir, output_file, output_dir, correct=True): global quality_preview if not input_file.endswith('.obj'): - print('ERROR: Invalid input: ' + input_file ) + print('ERROR: Invalid input: ' + input_file) return print('Processing: ' + input_file) @@ -361,7 +369,7 @@ def process_file(input_file, input_dir, output_file, output_dir, correct=True): if quality_preview: output_file = output_file.replace('.png', '-preview.png') - angle = 360.0/views + angle = 360.0 / views pbar = tqdm(range(0, views)) for view in pbar: pbar.set_description(f"{os.path.basename(output_file)} | View:{str(view)}") @@ -369,8 +377,7 @@ def process_file(input_file, input_dir, output_file, output_dir, correct=True): output_file_view = f"{output_file}/{view:03d}.png" if not os.path.exists(os.path.join(output_dir, output_file_view)): render_file(input_file, input_dir, output_file_view, output_dir, yaw, correct) - + cmd = "ffmpeg -loglevel quiet -r 30 -f lavfi -i color=c=white:s=512x512 -i " + os.path.join(output_dir, output_file, '%3d.png') + \ " -shortest -filter_complex \"[0:v][1:v]overlay=shortest=1,format=yuv420p[out]\" -map \"[out]\" -y " + output_dir+"/"+output_file+".mp4" os.system(cmd) - \ No newline at end of file diff --git a/lib/common/cloth_extraction.py b/lib/common/cloth_extraction.py index 7da5f0ec1102f49ff513af27e08597b0bd65bcb7..612a96787e1aa836b097971e7aaf55b284ef178a 100644 --- a/lib/common/cloth_extraction.py +++ b/lib/common/cloth_extraction.py @@ -36,11 +36,13 @@ def load_segmentation(path, shape): xy = np.vstack((x, y)).T coordinates.append(xy) - segmentations.append({ - "type": val["category_name"], - "type_id": val["category_id"], - "coordinates": coordinates, - }) + segmentations.append( + { + "type": val["category_name"], + "type_id": val["category_id"], + "coordinates": coordinates, + } + ) return segmentations @@ -56,9 +58,8 @@ def smpl_to_recon_labels(recon, smpl, k=1): Returns a dictionary containing the bodypart and the corresponding indices """ smpl_vert_segmentation = json.load( - open( - os.path.join(os.path.dirname(__file__), - "smpl_vert_segmentation.json"))) + open(os.path.join(os.path.dirname(__file__), "smpl_vert_segmentation.json")) + ) n = smpl.vertices.shape[0] y = 
np.array([None] * n) for key, val in smpl_vert_segmentation.items(): @@ -71,8 +72,7 @@ def smpl_to_recon_labels(recon, smpl, k=1): recon_labels = {} for key in smpl_vert_segmentation.keys(): - recon_labels[key] = list( - np.argwhere(y_pred == key).flatten().astype(int)) + recon_labels[key] = list(np.argwhere(y_pred == key).flatten().astype(int)) return recon_labels @@ -139,8 +139,7 @@ def extract_cloth(recon, segmentation, K, R, t, smpl=None): if type == 1 or type == 3 or type == 10: body_parts_to_remove += ["leftForeArm", "rightForeArm"] # No sleeves at all or lower body clothes - elif (type == 5 or type == 6 or type == 12 or type == 13 or type == 8 - or type == 9): + elif (type == 5 or type == 6 or type == 12 or type == 13 or type == 8 or type == 9): body_parts_to_remove += [ "leftForeArm", "rightForeArm", @@ -159,8 +158,8 @@ def extract_cloth(recon, segmentation, K, R, t, smpl=None): ] verts_to_remove = list( - itertools.chain.from_iterable( - [recon_labels[part] for part in body_parts_to_remove])) + itertools.chain.from_iterable([recon_labels[part] for part in body_parts_to_remove]) + ) label_mask = np.zeros(num_verts, dtype=bool) label_mask[verts_to_remove] = True diff --git a/lib/common/config.py b/lib/common/config.py index 04a65599f0e32aaa95e007aea1aa106a5c58a868..33c917cc61fce59cccf1f5e4a34056a0f22fc0e3 100644 --- a/lib/common/config.py +++ b/lib/common/config.py @@ -100,6 +100,7 @@ _C.bni.thickness = 0.00 _C.bni.hand_thres = 4e-2 _C.bni.face_thres = 6e-2 _C.bni.hps_type = "pixie" +_C.bni.texture_src = "image" # kernel_size, stride, dilation, padding @@ -170,10 +171,10 @@ _C.dataset.rp_type = "pifu900" _C.dataset.th_type = "train" _C.dataset.input_size = 512 _C.dataset.rotation_num = 3 -_C.dataset.num_precomp = 10 # Number of segmentation classifiers -_C.dataset.num_multiseg = 500 # Number of categories per classifier -_C.dataset.num_knn = 10 # for loss/error -_C.dataset.num_knn_dis = 20 # for accuracy +_C.dataset.num_precomp = 10 # Number of segmentation classifiers +_C.dataset.num_multiseg = 500 # Number of categories per classifier +_C.dataset.num_knn = 10 # for loss/error +_C.dataset.num_knn_dis = 20 # for accuracy _C.dataset.num_verts_max = 20000 _C.dataset.zray_type = False _C.dataset.online_smpl = False @@ -210,8 +211,7 @@ def get_cfg_defaults(): # Alternatively, provide a way to import the defaults as # a global singleton: -cfg = _C # users can `from config import cfg` - +cfg = _C # users can `from config import cfg` # cfg = get_cfg_defaults() # cfg.merge_from_file('./configs/example.yaml') @@ -244,9 +244,7 @@ def parse_args(args): def parse_args_extend(args): if args.resume: if not os.path.exists(args.log_dir): - raise ValueError( - "Experiment are set to resume mode, but log directory does not exist." 
- ) + raise ValueError("Experiment are set to resume mode, but log directory does not exist.") # load log's cfg cfg_file = os.path.join(args.log_dir, "cfg.yaml") diff --git a/lib/common/imutils.py b/lib/common/imutils.py index cc9e09f888ffc2b268308d0a1802debf798db0cd..f96e666b8a80d139aa240275a33f06d1464a6207 100644 --- a/lib/common/imutils.py +++ b/lib/common/imutils.py @@ -3,14 +3,13 @@ import mediapipe as mp import torch import numpy as np import torch.nn.functional as F -from rembg import remove -from rembg.session_factory import new_session from PIL import Image -from torchvision.models import detection - from lib.pymafx.core import constants -from lib.common.cloth_extraction import load_segmentation + +from rembg import remove +from rembg.session_factory import new_session from torchvision import transforms +from kornia.geometry.transform import get_affine_matrix2d, warp_affine def transform_to_tensor(res, mean=None, std=None, is_tensor=False): @@ -24,42 +23,40 @@ def transform_to_tensor(res, mean=None, std=None, is_tensor=False): return transforms.Compose(all_ops) -def aug_matrix(w1, h1, w2, h2): - dx = (w2 - w1) / 2.0 - dy = (h2 - h1) / 2.0 - - matrix_trans = np.array([[1.0, 0, dx], [0, 1.0, dy], [0, 0, 1.0]]) - - scale = np.min([float(w2) / w1, float(h2) / h1]) +def get_affine_matrix_wh(w1, h1, w2, h2): - M = get_affine_matrix(center=(w2 / 2.0, h2 / 2.0), translate=(0, 0), scale=scale) - - M = np.array(M + [0.0, 0.0, 1.0]).reshape(3, 3) - M = M.dot(matrix_trans) + transl = torch.tensor([(w2 - w1) / 2.0, (h2 - h1) / 2.0]).unsqueeze(0) + center = torch.tensor([w1 / 2.0, h1 / 2.0]).unsqueeze(0) + scale = torch.min(torch.tensor([w2 / w1, h2 / h1])).repeat(2).unsqueeze(0) + M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.])) return M -def get_affine_matrix(center, translate, scale): - cx, cy = center - tx, ty = translate - - M = [1, 0, 0, 0, 1, 0] - M = [x * scale for x in M] +def get_affine_matrix_box(boxes, w2, h2): - # Apply translation and of center translation: RSS * C^-1 - M[2] += M[0] * (-cx) + M[1] * (-cy) - M[5] += M[3] * (-cx) + M[4] * (-cy) + # boxes [left, top, right, bottom] + width = boxes[:, 2] - boxes[:, 0] #(N,) + height = boxes[:, 3] - boxes[:, 1] #(N,) + center = torch.tensor( + [(boxes[:, 0] + boxes[:, 2]) / 2.0, (boxes[:, 1] + boxes[:, 3]) / 2.0] + ).T #(N,2) + scale = torch.min(torch.tensor([w2 / width, h2 / height]), + dim=0)[0].unsqueeze(1).repeat(1, 2) * 0.9 #(N,2) + transl = torch.tensor([w2 / 2.0 - center[:, 0], h2 / 2.0 - center[:, 1]]).unsqueeze(0) #(N,2) + M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.])) - # Apply center translation: T * C * RSS * C^-1 - M[2] += cx + tx - M[5] += cy + ty return M def load_img(img_file): img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED) + + # considering 16-bit image + if img.dtype == np.uint16: + img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U) + if len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) @@ -68,11 +65,10 @@ def load_img(img_file): else: img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) - return img + return torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float(), img.shape[:2] def get_keypoints(image): - def collect_xyv(x, body=True): lmk = x.landmark all_lmks = [] @@ -84,8 +80,8 @@ def get_keypoints(image): mp_holistic = mp.solutions.holistic with mp_holistic.Holistic( - static_image_mode=True, - model_complexity=2, + static_image_mode=True, + model_complexity=2, ) as holistic: results = holistic.process(image) @@ -93,9 +89,15 @@ def 
get_keypoints(image): result = {} result["body"] = collect_xyv(results.pose_landmarks) if results.pose_landmarks else fake_kps - result["lhand"] = collect_xyv(results.left_hand_landmarks, False) if results.left_hand_landmarks else fake_kps - result["rhand"] = collect_xyv(results.right_hand_landmarks, False) if results.right_hand_landmarks else fake_kps - result["face"] = collect_xyv(results.face_landmarks, False) if results.face_landmarks else fake_kps + result["lhand"] = collect_xyv( + results.left_hand_landmarks, False + ) if results.left_hand_landmarks else fake_kps + result["rhand"] = collect_xyv( + results.right_hand_landmarks, False + ) if results.right_hand_landmarks else fake_kps + result["face"] = collect_xyv( + results.face_landmarks, False + ) if results.face_landmarks else fake_kps return result @@ -104,13 +106,21 @@ def get_pymafx(image, landmarks): # image [3,512,512] - item = {'img_body': F.interpolate(image.unsqueeze(0), size=224, mode='bicubic', align_corners=True)[0]} + item = { + 'img_body': + F.interpolate(image.unsqueeze(0), size=224, mode='bicubic', align_corners=True)[0] + } for part in ['lhand', 'rhand', 'face']: kp2d = landmarks[part] kp2d_valid = kp2d[kp2d[:, 3] > 0.] if len(kp2d_valid) > 0: - bbox = [min(kp2d_valid[:, 0]), min(kp2d_valid[:, 1]), max(kp2d_valid[:, 0]), max(kp2d_valid[:, 1])] + bbox = [ + min(kp2d_valid[:, 0]), + min(kp2d_valid[:, 1]), + max(kp2d_valid[:, 0]), + max(kp2d_valid[:, 1]) + ] center_part = [(bbox[2] + bbox[0]) / 2., (bbox[3] + bbox[1]) / 2.] scale_part = 2. * max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 @@ -141,20 +151,6 @@ def get_pymafx(image, landmarks): return item -def expand_bbox(bbox, width, height, ratio=0.1): - - bbox = np.around(bbox).astype(np.int16) - bbox_width = bbox[2] - bbox[0] - bbox_height = bbox[3] - bbox[1] - - bbox[1] = max(bbox[1] - bbox_height * ratio, 0) - bbox[3] = min(bbox[3] + bbox_height * ratio, height) - bbox[0] = max(bbox[0] - bbox_width * ratio, 0) - bbox[2] = min(bbox[2] + bbox_width * ratio, width) - - return bbox - - def remove_floats(mask): # 1. 
find all the contours @@ -173,51 +169,48 @@ def remove_floats(mask): return new_mask -def process_image(img_file, hps_type, single, input_res=512): +def process_image(img_file, hps_type, single, input_res, detector): - img_raw = load_img(img_file) - - in_height, in_width = img_raw.shape[:2] - M = aug_matrix(in_width, in_height, input_res * 2, input_res * 2) - - # from rectangle to square by padding (input_res*2, input_res*2) - img_square = cv2.warpAffine(img_raw, M[0:2, :], (input_res * 2, input_res * 2), flags=cv2.INTER_CUBIC) + img_raw, (in_height, in_width) = load_img(img_file) + tgt_res = input_res * 2 + M_square = get_affine_matrix_wh(in_width, in_height, tgt_res, tgt_res) + img_square = warp_affine( + img_raw, + M_square[:, :2], (tgt_res, ) * 2, + mode='bilinear', + padding_mode='zeros', + align_corners=True + ) # detection for bbox - detector = detection.maskrcnn_resnet50_fpn(weights=detection.MaskRCNN_ResNet50_FPN_V2_Weights) - detector.eval() - predictions = detector([torch.from_numpy(img_square).permute(2, 0, 1) / 255.])[0] + predictions = detector(img_square / 255.)[0] if single: top_score = predictions["scores"][predictions["labels"] == 1].max() human_ids = torch.where(predictions["scores"] == top_score)[0] else: - human_ids = torch.logical_and(predictions["labels"] == 1, predictions["scores"] > 0.9).nonzero().squeeze(1) + human_ids = torch.logical_and(predictions["labels"] == 1, + predictions["scores"] > 0.9).nonzero().squeeze(1) boxes = predictions["boxes"][human_ids, :].detach().cpu().numpy() masks = predictions["masks"][human_ids, :, :].permute(0, 2, 3, 1).detach().cpu().numpy() - width = boxes[:, 2] - boxes[:, 0] #(N,) - height = boxes[:, 3] - boxes[:, 1] #(N,) - center = np.array([(boxes[:, 0] + boxes[:, 2]) / 2.0, (boxes[:, 1] + boxes[:, 3]) / 2.0]).T #(N,2) - scale = np.array([width, height]).max(axis=0) / 90. + M_crop = get_affine_matrix_box(boxes, input_res, input_res) img_icon_lst = [] img_crop_lst = [] img_hps_lst = [] img_mask_lst = [] - uncrop_param_lst = [] landmark_lst = [] hands_visibility_lst = [] img_pymafx_lst = [] uncrop_param = { - "center": center, - "scale": scale, "ori_shape": [in_height, in_width], "box_shape": [input_res, input_res], - "crop_shape": [input_res * 2, input_res * 2, 3], - "M": M, + "square_shape": [tgt_res, tgt_res], + "M_square": M_square, + "M_crop": M_crop } for idx in range(len(boxes)): @@ -228,59 +221,74 @@ def process_image(img_file, hps_type, single, input_res=512): else: mask_detection = masks[0] * 0. 
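# NOTE (illustrative sketch, not part of this patch): a self-contained version
# of the square-padding warp that the new get_affine_matrix_wh() builds above,
# using the same kornia helpers this file now imports. The wrapper name
# `pad_to_square` and the default resolution are placeholders.
import torch
from kornia.geometry.transform import get_affine_matrix2d, warp_affine

def pad_to_square(img, tgt_res=1024):
    # img: [1, 3, H, W] float tensor; returns the padded square image together
    # with the 3x3 affine matrix, whose inverse undoes the warp (cf. unwrap()
    # later in this file).
    _, _, h, w = img.shape
    transl = torch.tensor([[(tgt_res - w) / 2.0, (tgt_res - h) / 2.0]])
    center = torch.tensor([[w / 2.0, h / 2.0]])
    scale = torch.min(torch.tensor([tgt_res / w, tgt_res / h])).repeat(2).unsqueeze(0)
    M = get_affine_matrix2d(transl, center, scale, angle=torch.tensor([0.]))    # [1, 3, 3]
    img_square = warp_affine(
        img, M[:, :2], (tgt_res, tgt_res),
        mode='bilinear', padding_mode='zeros', align_corners=True
    )
    return img_square, M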
- img_crop, _ = crop( - np.concatenate([img_square, (mask_detection < 0.4) * 255], axis=2), center[idx], scale[idx], [input_res, input_res]) - - # get accurate segmentation mask of focus person + img_square_rgba = torch.cat( + [img_square.squeeze(0).permute(1, 2, 0), + torch.tensor(mask_detection < 0.4) * 255], + dim=2 + ) + + img_crop = warp_affine( + img_square_rgba.unsqueeze(0).permute(0, 3, 1, 2), + M_crop[idx:idx + 1, :2], (input_res, ) * 2, + mode='bilinear', + padding_mode='zeros', + align_corners=True + ).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8) + + # get accurate person segmentation mask img_rembg = remove(img_crop, post_process_mask=True, session=new_session("u2net")) img_mask = remove_floats(img_rembg[:, :, [3]]) - # required image tensors / arrays - - # img_icon (tensor): (-1, 1), [3,512,512] - # img_hps (tensor): (-2.11, 2.44), [3,224,224] - - # img_np (array): (0, 255), [512,512,3] - # img_rembg (array): (0, 255), [512,512,4] - # img_mask (array): (0, 1), [512,512,1] - # img_crop (array): (0, 255), [512,512,4] - mean_icon = std_icon = (0.5, 0.5, 0.5) img_np = (img_rembg[..., :3] * img_mask).astype(np.uint8) - img_icon = transform_to_tensor(512, mean_icon, std_icon)(Image.fromarray(img_np)) * torch.tensor(img_mask).permute( - 2, 0, 1) - img_hps = transform_to_tensor(224, constants.IMG_NORM_MEAN, constants.IMG_NORM_STD)(Image.fromarray(img_np)) + img_icon = transform_to_tensor(512, mean_icon, std_icon)( + Image.fromarray(img_np) + ) * torch.tensor(img_mask).permute(2, 0, 1) + img_hps = transform_to_tensor(224, constants.IMG_NORM_MEAN, + constants.IMG_NORM_STD)(Image.fromarray(img_np)) landmarks = get_keypoints(img_np) + # get hands visibility + hands_visibility = [True, True] + if landmarks['lhand'][:, -1].mean() == 0.: + hands_visibility[0] = False + if landmarks['rhand'][:, -1].mean() == 0.: + hands_visibility[1] = False + hands_visibility_lst.append(hands_visibility) + if hps_type == 'pymafx': img_pymafx_lst.append( get_pymafx( - transform_to_tensor(512, constants.IMG_NORM_MEAN, constants.IMG_NORM_STD)(Image.fromarray(img_np)), - landmarks)) + transform_to_tensor(512, constants.IMG_NORM_MEAN, + constants.IMG_NORM_STD)(Image.fromarray(img_np)), landmarks + ) + ) img_crop_lst.append(torch.tensor(img_crop).permute(2, 0, 1) / 255.0) img_icon_lst.append(img_icon) img_hps_lst.append(img_hps) img_mask_lst.append(torch.tensor(img_mask[..., 0])) - uncrop_param_lst.append(uncrop_param) landmark_lst.append(landmarks['body']) - hands_visibility = [True, True] - if landmarks['lhand'][:, -1].mean() == 0.: - hands_visibility[0] = False - if landmarks['rhand'][:, -1].mean() == 0.: - hands_visibility[1] = False - hands_visibility_lst.append(hands_visibility) + # required image tensors / arrays + + # img_icon (tensor): (-1, 1), [3,512,512] + # img_hps (tensor): (-2.11, 2.44), [3,224,224] + + # img_np (array): (0, 255), [512,512,3] + # img_rembg (array): (0, 255), [512,512,4] + # img_mask (array): (0, 1), [512,512,1] + # img_crop (array): (0, 255), [512,512,4] return_dict = { - "img_icon": torch.stack(img_icon_lst).float(), #[N, 3, res, res] - "img_crop": torch.stack(img_crop_lst).float(), #[N, 4, res, res] - "img_hps": torch.stack(img_hps_lst).float(), #[N, 3, res, res] - "img_raw": img_raw, #[H, W, 3] - "img_mask": torch.stack(img_mask_lst).float(), #[N, res, res] + "img_icon": torch.stack(img_icon_lst).float(), #[N, 3, res, res] + "img_crop": torch.stack(img_crop_lst).float(), #[N, 4, res, res] + "img_hps": torch.stack(img_hps_lst).float(), #[N, 3, res, res] + "img_raw": 
img_raw, #[1, 3, H, W] + "img_mask": torch.stack(img_mask_lst).float(), #[N, res, res] "uncrop_param": uncrop_param, - "landmark": torch.stack(landmark_lst), #[N, 33, 4] + "landmark": torch.stack(landmark_lst), #[N, 33, 4] "hands_visibility": hands_visibility_lst, } @@ -302,250 +310,51 @@ def process_image(img_file, hps_type, single, input_res=512): return return_dict -def get_transform(center, scale, res): - """Generate transformation matrix.""" - h = 100 * scale - t = np.zeros((3, 3)) - t[0, 0] = float(res[1]) / h - t[1, 1] = float(res[0]) / h - t[0, 2] = res[1] * (-float(center[0]) / h + 0.5) - t[1, 2] = res[0] * (-float(center[1]) / h + 0.5) - t[2, 2] = 1 - - return t - - -def transform(pt, center, scale, res, invert=0): - """Transform pixel location to different reference.""" - t = get_transform(center, scale, res) - if invert: - t = np.linalg.inv(t) - new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T - new_pt = np.dot(t, new_pt) - return np.around(new_pt[:2]).astype(np.int16) - - -def crop(img, center, scale, res): - """Crop image according to the supplied bounding box.""" - - img_height, img_width = img.shape[:2] - - # Upper left point - ul = np.array(transform([0, 0], center, scale, res, invert=1)) - - # Bottom right point - br = np.array(transform(res, center, scale, res, invert=1)) - - new_shape = [br[1] - ul[1], br[0] - ul[0]] - if len(img.shape) > 2: - new_shape += [img.shape[2]] - new_img = np.zeros(new_shape) - - # Range to fill new array - new_x = max(0, -ul[0]), min(br[0], img_width) - ul[0] - new_y = max(0, -ul[1]), min(br[1], img_height) - ul[1] - - # Range to sample from original image - old_x = max(0, ul[0]), min(img_width, br[0]) - old_y = max(0, ul[1]), min(img_height, br[1]) - - new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] - new_img = F.interpolate( - torch.tensor(new_img).permute(2, 0, 1).unsqueeze(0), res, mode='bilinear').permute(0, 2, 3, - 1)[0].numpy().astype(np.uint8) - - return new_img, (old_x, new_x, old_y, new_y, new_shape) - - -def crop_segmentation(org_coord, res, cropping_parameters): - old_x, new_x, old_y, new_y, new_shape = cropping_parameters +def blend_rgb_norm(norms, data): - new_coord = np.zeros((org_coord.shape)) - new_coord[:, 0] = new_x[0] + (org_coord[:, 0] - old_x[0]) - new_coord[:, 1] = new_y[0] + (org_coord[:, 1] - old_y[0]) - - new_coord[:, 0] = res[0] * (new_coord[:, 0] / new_shape[1]) - new_coord[:, 1] = res[1] * (new_coord[:, 1] / new_shape[0]) - - return new_coord - - -def corner_align(ul, br): - - if ul[1] - ul[0] != br[1] - br[0]: - ul[1] = ul[0] + br[1] - br[0] - - return ul, br - - -def uncrop(img, center, scale, orig_shape): - """'Undo' the image cropping/resizing. - This function is used when evaluating mask/part segmentation. 
- """ - - res = img.shape[:2] - - # Upper left point - ul = np.array(transform([0, 0], center, scale, res, invert=1)) - # Bottom right point - br = np.array(transform(res, center, scale, res, invert=1)) - - # quick fix - ul, br = corner_align(ul, br) - - # size of cropped image - crop_shape = [br[1] - ul[1], br[0] - ul[0]] - new_img = np.zeros(orig_shape, dtype=np.uint8) - - # Range to fill new array - new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0] - new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1] + # norms [N, 3, res, res] + masks = (norms.sum(dim=1) != norms[0, :, 0, 0].sum()).float().unsqueeze(1) + norm_mask = F.interpolate( + torch.cat([norms, masks], dim=1).detach(), + size=data["uncrop_param"]["box_shape"], + mode="bilinear", + align_corners=False + ) + final = data["img_raw"].type_as(norm_mask) - # Range to sample from original image - old_x = max(0, ul[0]), min(orig_shape[1], br[0]) - old_y = max(0, ul[1]), min(orig_shape[0], br[1]) + for idx in range(len(norms)): - img = np.array(Image.fromarray(img.astype(np.uint8)).resize(crop_shape)) + norm_pred = (norm_mask[idx:idx + 1, :3, :, :] + 1.0) * 255.0 / 2.0 + mask_pred = norm_mask[idx:idx + 1, 3:4, :, :].repeat(1, 3, 1, 1) - new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]] + norm_ori = unwrap(norm_pred, data["uncrop_param"], idx) + mask_ori = unwrap(mask_pred, data["uncrop_param"], idx) - return new_img + final = final * (1.0 - mask_ori) + norm_ori * mask_ori + return final.detach().cpu() -def rot_aa(aa, rot): - """Rotate axis angle parameters.""" - # pose parameters - R = np.array([ - [np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], - [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], - [0, 0, 1], - ]) - # find the rotation of the body in camera frame - per_rdg, _ = cv2.Rodrigues(aa) - # apply the global rotation to the global orientation - resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg)) - aa = (resrot.T)[0] - return aa +def unwrap(image, uncrop_param, idx): -def flip_img(img): - """Flip rgb images or masks. - channels come last, e.g. (256,256,3). - """ - img = np.fliplr(img) - return img + device = image.device + img_square = warp_affine( + image, + torch.inverse(uncrop_param["M_crop"])[idx:idx + 1, :2].to(device), + uncrop_param["square_shape"], + mode='bilinear', + padding_mode='zeros', + align_corners=True + ) -def flip_kp(kp, is_smpl=False): - """Flip keypoints.""" - if len(kp) == 24: - if is_smpl: - flipped_parts = constants.SMPL_JOINTS_FLIP_PERM - else: - flipped_parts = constants.J24_FLIP_PERM - elif len(kp) == 49: - if is_smpl: - flipped_parts = constants.SMPL_J49_FLIP_PERM - else: - flipped_parts = constants.J49_FLIP_PERM - kp = kp[flipped_parts] - kp[:, 0] = -kp[:, 0] - return kp - - -def flip_pose(pose): - """Flip pose. - The flipping is based on SMPL parameters. 
- """ - flipped_parts = constants.SMPL_POSE_FLIP_PERM - pose = pose[flipped_parts] - # we also negate the second and the third dimension of the axis-angle - pose[1::3] = -pose[1::3] - pose[2::3] = -pose[2::3] - return pose - - -def normalize_2d_kp(kp_2d, crop_size=224, inv=False): - # Normalize keypoints between -1, 1 - if not inv: - ratio = 1.0 / crop_size - kp_2d = 2.0 * kp_2d * ratio - 1.0 - else: - ratio = 1.0 / crop_size - kp_2d = (kp_2d + 1.0) / (2 * ratio) - - return kp_2d - - -def visualize_landmarks(image, joints, color): - - img_w, img_h = image.shape[:2] - - for joint in joints: - image = cv2.circle(image, (int(joint[0] * img_w), int(joint[1] * img_h)), 5, color) - - return image - - -def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None): - """ - param joints: [num_joints, 3] - param joints_vis: [num_joints, 3] - return: target, target_weight(1: visible, 0: invisible) - """ - num_joints = joints.shape[0] - device = joints.device - cur_device = torch.device(device.type, device.index) - if not hasattr(heatmap_size, "__len__"): - # width height - heatmap_size = [heatmap_size, heatmap_size] - assert len(heatmap_size) == 2 - target_weight = np.ones((num_joints, 1), dtype=np.float32) - if joints_vis is not None: - target_weight[:, 0] = joints_vis[:, 0] - target = torch.zeros( - (num_joints, heatmap_size[1], heatmap_size[0]), - dtype=torch.float32, - device=cur_device, + img_ori = warp_affine( + img_square, + torch.inverse(uncrop_param["M_square"])[:, :2].to(device), + uncrop_param["ori_shape"], + mode='bilinear', + padding_mode='zeros', + align_corners=True ) - tmp_size = sigma * 3 - - for joint_id in range(num_joints): - mu_x = int(joints[joint_id][0] * heatmap_size[0] + 0.5) - mu_y = int(joints[joint_id][1] * heatmap_size[1] + 0.5) - # Check that any part of the gaussian is in-bounds - ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] - br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)] - if (ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] or br[0] < 0 or br[1] < 0): - # If not, just return the image as is - target_weight[joint_id] = 0 - continue - - # # Generate gaussian - size = 2 * tmp_size + 1 - # x = np.arange(0, size, 1, np.float32) - # y = x[:, np.newaxis] - # x0 = y0 = size // 2 - # # The gaussian is not normalized, we want the center value to equal 1 - # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) - # g = torch.from_numpy(g.astype(np.float32)) - - x = torch.arange(0, size, dtype=torch.float32, device=cur_device) - y = x.unsqueeze(-1) - x0 = y0 = size // 2 - # The gaussian is not normalized, we want the center value to equal 1 - g = torch.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2)) - - # Usable gaussian range - g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0] - g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1] - # Image range - img_x = max(0, ul[0]), min(br[0], heatmap_size[0]) - img_y = max(0, ul[1]), min(br[1], heatmap_size[1]) - - v = target_weight[joint_id] - if v > 0.5: - target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]] - - return target, target_weight + return img_ori diff --git a/lib/common/libmesh/inside_mesh.py b/lib/common/libmesh/inside_mesh.py index 110ff87407efc49010328764299b824f647708cb..eaac43c2e6fe103c6a1dd4e182642ff0cc6a024a 100644 --- a/lib/common/libmesh/inside_mesh.py +++ b/lib/common/libmesh/inside_mesh.py @@ -5,7 +5,7 @@ from .triangle_hash import TriangleHash as _TriangleHash def check_mesh_contains(mesh, points, hash_resolution=512): 
intersector = MeshIntersector(mesh, hash_resolution) contains, hole_points = intersector.query(points) - return contains, hole_points + return contains, hole_points class MeshIntersector: @@ -25,8 +25,7 @@ class MeshIntersector: # assert(np.allclose(triangles.reshape(-1, 3).max(0), resolution - 0.5)) triangles2d = triangles[:, :, :2] - self._tri_intersector2d = TriangleIntersector2d( - triangles2d, resolution) + self._tri_intersector2d = TriangleIntersector2d(triangles2d, resolution) def query(self, points): # Rescale points @@ -38,8 +37,7 @@ class MeshIntersector: # cull points outside of the axis aligned bounding box # this avoids running ray tests unless points are close - inside_aabb = np.all( - (0 <= points) & (points <= self.resolution), axis=1) + inside_aabb = np.all((0 <= points) & (points <= self.resolution), axis=1) if not inside_aabb.any(): return contains, hole_points @@ -48,14 +46,14 @@ class MeshIntersector: points = points[mask] # Compute intersection depth and check order - points_indices, tri_indices = self._tri_intersector2d.query( - points[:, :2]) + points_indices, tri_indices = self._tri_intersector2d.query(points[:, :2]) triangles_intersect = self._triangles[tri_indices] points_intersect = points[points_indices] depth_intersect, abs_n_2 = self.compute_intersection_depth( - points_intersect, triangles_intersect) + points_intersect, triangles_intersect + ) # Count number of intersections in both directions smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2 @@ -73,7 +71,7 @@ class MeshIntersector: # print('Warning: contains1 != contains2 for some points.') contains[mask] = (contains1 & contains2) hole_points[mask] = np.logical_xor(contains1, contains2) - return contains, hole_points + return contains, hole_points def compute_intersection_depth(self, points, triangles): t1 = triangles[:, 0, :] @@ -150,7 +148,7 @@ class TriangleIntersector2d: sum_uv = u + v contains[mask] = ( - (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA) - & (0 < sum_uv) & (sum_uv < abs_detA) + (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA) & (0 < sum_uv) & + (sum_uv < abs_detA) ) return contains diff --git a/lib/common/libmesh/setup.py b/lib/common/libmesh/setup.py index a565e470dd6bb6a2042c86b47a1524c5f7194d58..38ac162300df4e987134e81306e1a6ad674a5323 100644 --- a/lib/common/libmesh/setup.py +++ b/lib/common/libmesh/setup.py @@ -2,7 +2,4 @@ from setuptools import setup from Cython.Build import cythonize import numpy - -setup(name = 'libmesh', - ext_modules = cythonize("*.pyx"), - include_dirs=[numpy.get_include()]) +setup(name='libmesh', ext_modules=cythonize("*.pyx"), include_dirs=[numpy.get_include()]) diff --git a/lib/common/libvoxelize/setup.py b/lib/common/libvoxelize/setup.py index 7a4056e8914dbc65b4fe99acc4d7e3e9f49a04e6..1a534ece09af40fbabd3221eae2e2f5d7931f80c 100644 --- a/lib/common/libvoxelize/setup.py +++ b/lib/common/libvoxelize/setup.py @@ -1,5 +1,4 @@ from setuptools import setup from Cython.Build import cythonize -setup(name = 'libvoxelize', - ext_modules = cythonize("*.pyx")) +setup(name='libvoxelize', ext_modules=cythonize("*.pyx")) diff --git a/lib/common/local_affine.py b/lib/common/local_affine.py index 6cbaa8f0626214c518f95551cfae1ba78a60fc43..3a6ef580ebb306bc8f3a3fbb87200b33df262c68 100644 --- a/lib/common/local_affine.py +++ b/lib/common/local_affine.py @@ -16,7 +16,6 @@ from lib.common.train_util import init_loss # reference: https://github.com/wuhaozhe/pytorch-nicp class LocalAffine(nn.Module): - def __init__(self, num_points, batch_size=1, 
edges=None): ''' specify the number of points, the number of points should be constant across the batch @@ -26,8 +25,14 @@ class LocalAffine(nn.Module): add additional pooling on top of w matrix ''' super(LocalAffine, self).__init__() - self.A = nn.Parameter(torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(batch_size, num_points, 1, 1)) - self.b = nn.Parameter(torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat(batch_size, num_points, 1, 1)) + self.A = nn.Parameter( + torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(batch_size, num_points, 1, 1) + ) + self.b = nn.Parameter( + torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat( + batch_size, num_points, 1, 1 + ) + ) self.edges = edges self.num_points = num_points @@ -38,24 +43,23 @@ class LocalAffine(nn.Module): ''' if self.edges is None: raise Exception("edges cannot be none when calculate stiff") - idx1 = self.edges[:, 0] - idx2 = self.edges[:, 1] affine_weight = torch.cat((self.A, self.b), dim=3) - w1 = torch.index_select(affine_weight, dim=1, index=idx1) - w2 = torch.index_select(affine_weight, dim=1, index=idx2) + w1 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 0]) + w2 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 1]) w_diff = (w1 - w2)**2 w_rigid = (torch.linalg.det(self.A) - 1.0)**2 return w_diff, w_rigid def forward(self, x): ''' - x should have shape of B * N * 3 + x should have shape of B * N * 3 * 1 ''' x = x.unsqueeze(3) out_x = torch.matmul(self.A, x) out_x = out_x + self.b - stiffness, rigid = self.stiffness() out_x.squeeze_(3) + stiffness, rigid = self.stiffness() + return out_x, stiffness, rigid @@ -75,10 +79,16 @@ def register(target_mesh, src_mesh, device): tgt_mesh = trimesh2meshes(target_mesh).to(device) src_verts = src_mesh.verts_padded().clone() - local_affine_model = LocalAffine(src_mesh.verts_padded().shape[1], - src_mesh.verts_padded().shape[0], src_mesh.edges_packed()).to(device) + local_affine_model = LocalAffine( + src_mesh.verts_padded().shape[1], + src_mesh.verts_padded().shape[0], src_mesh.edges_packed() + ).to(device) - optimizer_cloth = torch.optim.Adam([{'params': local_affine_model.parameters()}], lr=1e-2, amsgrad=True) + optimizer_cloth = torch.optim.Adam( + [{ + 'params': local_affine_model.parameters() + }], lr=1e-2, amsgrad=True + ) scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer_cloth, mode="min", @@ -90,28 +100,27 @@ def register(target_mesh, src_mesh, device): losses = init_loss() - loop_cloth = tqdm(range(200)) + loop_cloth = tqdm(range(100)) for i in loop_cloth: optimizer_cloth.zero_grad() - deformed_verts, stiffness, rigid = local_affine_model(src_verts) + deformed_verts, stiffness, rigid = local_affine_model(x=src_verts) src_mesh = src_mesh.update_padded(deformed_verts) # losses for laplacian, edge, normal consistency update_mesh_shape_prior_losses(src_mesh, losses) losses["cloth"]["value"] = chamfer_distance( - x=src_mesh.verts_padded(), - y=tgt_mesh.verts_padded())[0] - - losses["stiffness"]["value"] = torch.mean(stiffness) + x=src_mesh.verts_padded(), y=tgt_mesh.verts_padded() + )[0] + losses["stiff"]["value"] = torch.mean(stiffness) losses["rigid"]["value"] = torch.mean(rigid) # Weighted sum of the losses cloth_loss = torch.tensor(0.0, requires_grad=True).to(device) - pbar_desc = "Register SMPL-X towards ECON --- " + pbar_desc = "Register SMPL-X -> d-BiNI -- " for k in losses.keys(): if losses[k]["weight"] > 0.0 and losses[k]["value"] != 0.0: @@ -119,7 +128,7 @@ def register(target_mesh, src_mesh, device): 
losses[k]["value"] * losses[k]["weight"] pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.3f} | " - pbar_desc += f"Total: {cloth_loss:.5f}" + pbar_desc += f"TOTAL: {cloth_loss:.3f}" loop_cloth.set_description(pbar_desc) # update params @@ -131,6 +140,7 @@ def register(target_mesh, src_mesh, device): src_mesh.verts_packed().detach().squeeze(0).cpu(), src_mesh.faces_packed().detach().squeeze(0).cpu(), process=False, - maintains_order=True) + maintains_order=True + ) return final diff --git a/lib/common/render.py b/lib/common/render.py index c4407bb3e0035d21300c87dd198f573d767d21e1..392566099a1034eb9138d905923248465a625aad 100644 --- a/lib/common/render.py +++ b/lib/common/render.py @@ -31,7 +31,8 @@ from pytorch3d.renderer import ( ) from pytorch3d.renderer.mesh import TexturesVertex from pytorch3d.structures import Meshes -from lib.dataset.mesh_util import get_visibility, blend_rgb_norm +from lib.dataset.mesh_util import get_visibility +from lib.common.imutils import blend_rgb_norm import lib.common.render_utils as util import torch @@ -74,20 +75,23 @@ def query_color(verts, faces, image, device): (xy, z) = verts.split([2, 1], dim=1) visibility = get_visibility(xy, z, faces[:, [0, 2, 1]]).flatten() - uv = xy.unsqueeze(0).unsqueeze(2) # [B, N, 2] + uv = xy.unsqueeze(0).unsqueeze(2) # [B, N, 2] uv = uv * torch.tensor([1.0, -1.0]).type_as(uv) colors = ( - (torch.nn.functional.grid_sample(image, uv, align_corners=True)[0, :, :, 0].permute(1, 0) + - 1.0) * 0.5 * 255.0) + ( + torch.nn.functional.grid_sample(image, uv, align_corners=True)[0, :, :, + 0].permute(1, 0) + 1.0 + ) * 0.5 * 255.0 + ) colors[visibility == 0.0] = ( (Meshes(verts.unsqueeze(0), faces.unsqueeze(0)).verts_normals_padded().squeeze(0) + 1.0) * - 0.5 * 255.0)[visibility == 0.0] + 0.5 * 255.0 + )[visibility == 0.0] return colors.detach().cpu() class cleanShader(torch.nn.Module): - def __init__(self, blend_params=None): super().__init__() self.blend_params = blend_params if blend_params is not None else BlendParams() @@ -103,7 +107,6 @@ class cleanShader(torch.nn.Module): class Render: - def __init__(self, size=512, device=torch.device("cuda:0")): self.device = device self.size = size @@ -119,21 +122,30 @@ class Render: self.cam_pos = { "frontback": - torch.tensor([ - (0, self.mesh_y_center, self.dis), - (0, self.mesh_y_center, -self.dis), - ]), + torch.tensor( + [ + (0, self.mesh_y_center, self.dis), + (0, self.mesh_y_center, -self.dis), + ] + ), "four": - torch.tensor([ - (0, self.mesh_y_center, self.dis), - (self.dis, self.mesh_y_center, 0), - (0, self.mesh_y_center, -self.dis), - (-self.dis, self.mesh_y_center, 0), - ]), + torch.tensor( + [ + (0, self.mesh_y_center, self.dis), + (self.dis, self.mesh_y_center, 0), + (0, self.mesh_y_center, -self.dis), + (-self.dis, self.mesh_y_center, 0), + ] + ), "around": - torch.tensor([(100.0 * math.cos(np.pi / 180 * angle), self.mesh_y_center, - 100.0 * math.sin(np.pi / 180 * angle)) - for angle in range(0, 360, self.step)]) + torch.tensor( + [ + ( + 100.0 * math.cos(np.pi / 180 * angle), self.mesh_y_center, + 100.0 * math.sin(np.pi / 180 * angle) + ) for angle in range(0, 360, self.step) + ] + ) } self.type = "color" @@ -153,8 +165,8 @@ class Render: R, T = look_at_view_transform( eye=self.cam_pos[type][idx], - at=((0, self.mesh_y_center, 0),), - up=((0, 1, 0),), + at=((0, self.mesh_y_center, 0), ), + up=((0, 1, 0), ), ) cameras = FoVOrthographicCameras( @@ -167,7 +179,7 @@ class Render: min_y=-100.0, max_x=100.0, min_x=-100.0, - scale_xyz=(self.scale * np.ones(3),) * len(R), + 
scale_xyz=(self.scale * np.ones(3), ) * len(R), ) return cameras @@ -202,15 +214,17 @@ class Render: cull_backfaces=True, ) - self.silhouetteRas = MeshRasterizer(cameras=camera, - raster_settings=self.raster_settings_silhouette) - self.renderer = MeshRenderer(rasterizer=self.silhouetteRas, - shader=SoftSilhouetteShader()) + self.silhouetteRas = MeshRasterizer( + cameras=camera, raster_settings=self.raster_settings_silhouette + ) + self.renderer = MeshRenderer( + rasterizer=self.silhouetteRas, shader=SoftSilhouetteShader() + ) elif type == "pointcloud": - self.raster_settings_pcd = PointsRasterizationSettings(image_size=self.size, - radius=0.006, - points_per_pixel=10) + self.raster_settings_pcd = PointsRasterizationSettings( + image_size=self.size, radius=0.006, points_per_pixel=10 + ) self.pcdRas = PointsRasterizer(cameras=camera, raster_settings=self.raster_settings_pcd) self.renderer = PointsRenderer( @@ -230,8 +244,12 @@ class Render: V_lst = [] F_lst = [] for V, F in zip(verts, faces): - V_lst.append(torch.tensor(V).float().to(self.device)) - F_lst.append(torch.tensor(F).long().to(self.device)) + if not torch.is_tensor(V): + V_lst.append(torch.tensor(V).float().to(self.device)) + F_lst.append(torch.tensor(F).long().to(self.device)) + else: + V_lst.append(V.float().to(self.device)) + F_lst.append(F.long().to(self.device)) self.meshes = Meshes(V_lst, F_lst).to(self.device) else: # array or tensor @@ -248,7 +266,8 @@ class Render: # texture only support single mesh if len(self.meshes) == 1: self.meshes.textures = TexturesVertex( - verts_features=(self.meshes.verts_normals_padded() + 1.0) * 0.5) + verts_features=(self.meshes.verts_normals_padded() + 1.0) * 0.5 + ) def get_image(self, cam_type="frontback", type="rgb", bg="gray"): @@ -260,7 +279,8 @@ class Render: current_mesh = self.meshes[mesh_id] current_mesh.textures = TexturesVertex( - verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5) + verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5 + ) if type == "depth": fragments = self.meshRas(current_mesh.extend(len(self.cam_pos[cam_type]))) @@ -276,7 +296,7 @@ class Render: print(f"unknown {type}") if cam_type == 'frontback': - images[1] = torch.flip(images[1], dims=(-1,)) + images[1] = torch.flip(images[1], dims=(-1, )) # images [N_render, 3, res, res] img_lst.append(images.unsqueeze(1)) @@ -287,9 +307,8 @@ class Render: return list(meshes) def get_rendered_video_multi(self, data, save_path): - - width = data["img_raw"].shape[1] - height = data["img_raw"].shape[0] + + height, width = data["img_raw"].shape[2:] fourcc = cv2.VideoWriter_fourcc(*"mp4v") video = cv2.VideoWriter( @@ -302,14 +321,15 @@ class Render: pbar = tqdm(range(len(self.meshes))) pbar.set_description(colored(f"Normal Rendering {os.path.basename(save_path)}...", "blue")) - mesh_renders = [] #[(N_cam, 3, res, res)*N_mesh] + mesh_renders = [] #[(N_cam, 3, res, res)*N_mesh] # render all the normals for mesh_id in pbar: current_mesh = self.meshes[mesh_id] current_mesh.textures = TexturesVertex( - verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5) + verts_features=(current_mesh.verts_normals_padded() + 1.0) * 0.5 + ) norm_lst = [] @@ -320,21 +340,33 @@ class Render: self.init_renderer(batch_cams, "mesh", "gray") norm_lst.append( - self.renderer(current_mesh.extend(len(batch_cams_idx)))[..., :3].permute( - 0, 3, 1, 2)) + self.renderer(current_mesh.extend(len(batch_cams_idx)) + )[..., :3].permute(0, 3, 1, 2) + ) mesh_renders.append(torch.cat(norm_lst).detach().cpu()) # generate video frame 
by frame pbar = tqdm(range(len(self.cam_pos["around"]))) pbar.set_description(colored(f"Video Exporting {os.path.basename(save_path)}...", "blue")) + for cam_id in pbar: - img_raw = data["img_raw"].astype(np.uint8) + img_raw = data["img_raw"] num_obj = len(mesh_renders) // 2 - img_smpl = blend_rgb_norm((torch.stack(mesh_renders)[:num_obj, cam_id] - 0.5) * 2.0, data) - img_cloth = blend_rgb_norm((torch.stack(mesh_renders)[num_obj:, cam_id] - 0.5) * 2.0, data) + img_smpl = blend_rgb_norm( + (torch.stack(mesh_renders)[:num_obj, cam_id] - 0.5) * 2.0, data + ) + img_cloth = blend_rgb_norm( + (torch.stack(mesh_renders)[num_obj:, cam_id] - 0.5) * 2.0, data + ) - top_img = cv2.resize(np.concatenate([img_raw, img_smpl], axis=1), (width, height // 2)) - final_img = np.concatenate([top_img, img_cloth], axis=0) + top_img = cv2.resize( + torch.cat([img_raw, img_smpl], + dim=-1).squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8), + (width, height // 2) + ) + final_img = np.concatenate( + [top_img, img_cloth.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8)], axis=0 + ) video.write(final_img[:, :, ::-1]) video.release() diff --git a/lib/common/render_utils.py b/lib/common/render_utils.py index 013f625ec7c62a54e1fd5f5bf8579afdb6561023..cb2ca46f420c063c7a1c6a82276d41c42852e451 100644 --- a/lib/common/render_utils.py +++ b/lib/common/render_utils.py @@ -25,9 +25,7 @@ from pytorch3d.renderer.mesh import rasterize_meshes Tensor = NewType("Tensor", torch.Tensor) -def solid_angles(points: Tensor, - triangles: Tensor, - thresh: float = 1e-8) -> Tensor: +def solid_angles(points: Tensor, triangles: Tensor, thresh: float = 1e-8) -> Tensor: """Compute solid angle between the input points and triangles Follows the method described in: The Solid Angle of a Plane Triangle @@ -55,9 +53,7 @@ def solid_angles(points: Tensor, norms = torch.norm(centered_tris, dim=-1) # Should be BxQxFx3 - cross_prod = torch.cross(centered_tris[:, :, :, 1], - centered_tris[:, :, :, 2], - dim=-1) + cross_prod = torch.cross(centered_tris[:, :, :, 1], centered_tris[:, :, :, 2], dim=-1) # Should be BxQxF numerator = (centered_tris[:, :, :, 0] * cross_prod).sum(dim=-1) del cross_prod @@ -67,8 +63,10 @@ def solid_angles(points: Tensor, dot02 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 2]).sum(dim=-1) del centered_tris - denominator = (norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] + - dot02 * norms[:, :, :, 1] + dot12 * norms[:, :, :, 0]) + denominator = ( + norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] + dot02 * norms[:, :, :, 1] + + dot12 * norms[:, :, :, 0] + ) del dot01, dot12, dot02, norms # Should be BxQ @@ -80,9 +78,7 @@ def solid_angles(points: Tensor, return 2 * solid_angle -def winding_numbers(points: Tensor, - triangles: Tensor, - thresh: float = 1e-8) -> Tensor: +def winding_numbers(points: Tensor, triangles: Tensor, thresh: float = 1e-8) -> Tensor: """Uses winding_numbers to compute inside/outside Robust inside-outside segmentation using generalized winding numbers Alec Jacobson, @@ -109,8 +105,7 @@ def winding_numbers(points: Tensor, """ # The generalized winding number is the sum of solid angles of the point # with respect to all triangles. 
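# NOTE (illustration, not part of this patch): the value returned here is the
# generalized winding number
#     w(p) = (1 / (4 * pi)) * sum_f Omega_f(p),
# where Omega_f(p) is the signed solid angle of triangle f seen from point p
# (computed by solid_angles above). For a closed, consistently oriented mesh,
# w(p) is ~1 for interior points and ~0 for exterior ones, so a common
# occupancy test (a convention assumed here, not fixed by this function) is
#     inside = winding_numbers(points, triangles) > 0.5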
- return (1 / (4 * math.pi) * - solid_angles(points, triangles, thresh=thresh).sum(dim=-1)) + return (1 / (4 * math.pi) * solid_angles(points, triangles, thresh=thresh).sum(dim=-1)) def batch_contains(verts, faces, points): @@ -124,8 +119,7 @@ def batch_contains(verts, faces, points): contains = torch.zeros(B, N) for i in range(B): - contains[i] = torch.as_tensor( - trimesh.Trimesh(verts[i], faces[i]).contains(points[i])) + contains[i] = torch.as_tensor(trimesh.Trimesh(verts[i], faces[i]).contains(points[i])) return 2.0 * (contains - 0.5) @@ -155,8 +149,7 @@ def face_vertices(vertices, faces): bs, nv = vertices.shape[:2] bs, nf = faces.shape[:2] device = vertices.device - faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * - nv)[:, None, None] + faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] vertices = vertices.reshape((bs * nv, vertices.shape[-1])) return vertices[faces.long()] @@ -168,7 +161,6 @@ class Pytorch3dRasterizer(nn.Module): x,y,z are in image space, normalized can only render squared image now """ - def __init__(self, image_size=224, blur_radius=0.0, faces_per_pixel=1): """ use fixed raster_settings for rendering faces @@ -189,8 +181,7 @@ class Pytorch3dRasterizer(nn.Module): def forward(self, vertices, faces, attributes=None): fixed_vertices = vertices.clone() fixed_vertices[..., :2] = -fixed_vertices[..., :2] - meshes_screen = Meshes(verts=fixed_vertices.float(), - faces=faces.long()) + meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long()) raster_settings = self.raster_settings pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( meshes_screen, @@ -204,8 +195,9 @@ class Pytorch3dRasterizer(nn.Module): vismask = (pix_to_face > -1).float() D = attributes.shape[-1] attributes = attributes.clone() - attributes = attributes.view(attributes.shape[0] * attributes.shape[1], - 3, attributes.shape[-1]) + attributes = attributes.view( + attributes.shape[0] * attributes.shape[1], 3, attributes.shape[-1] + ) N, H, W, K, _ = bary_coords.shape mask = pix_to_face == -1 pix_to_face = pix_to_face.clone() @@ -213,8 +205,7 @@ class Pytorch3dRasterizer(nn.Module): idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D) pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2) - pixel_vals[mask] = 0 # Replace masked values in output. + pixel_vals[mask] = 0 # Replace masked values in output. 
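# NOTE (illustration, not part of this patch): the attribute interpolation in
# this block is plain barycentric blending. For a pixel whose rasterized face
# has vertex attributes a_0, a_1, a_2 and barycentric coordinates
# (b_0, b_1, b_2) with b_0 + b_1 + b_2 = 1, the interpolated value is
#     pixel_val = b_0 * a_0 + b_1 * a_1 + b_2 * a_2,
# which is what (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
# computes for every pixel and every face sample K; pixels that hit no face
# (pix_to_face == -1) are zeroed, and their visibility mask is concatenated as
# an extra output channel just below.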
pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2) - pixel_vals = torch.cat( - [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1) + pixel_vals = torch.cat([pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1) return pixel_vals diff --git a/lib/common/seg3d_lossless.py b/lib/common/seg3d_lossless.py index 6afdc822369608dc89a4215c7f68f4041523ed30..4f5cba2a1edb3a5df14d17beabb9d296203865c1 100644 --- a/lib/common/seg3d_lossless.py +++ b/lib/common/seg3d_lossless.py @@ -31,7 +31,6 @@ logging.getLogger("lightning").setLevel(logging.ERROR) class Seg3dLossless(nn.Module): - def __init__( self, query_func, @@ -53,19 +52,14 @@ class Seg3dLossless(nn.Module): """ super().__init__() self.query_func = query_func - self.register_buffer( - "b_min", - torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3] - self.register_buffer( - "b_max", - torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3] + self.register_buffer("b_min", torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3] + self.register_buffer("b_max", torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3] # ti.init(arch=ti.cuda) # self.mciso_taichi = MCISO(dim=3, N=resolutions[-1]-1) if type(resolutions[0]) is int: - resolutions = torch.tensor([(res, res, res) - for res in resolutions]) + resolutions = torch.tensor([(res, res, res) for res in resolutions]) else: resolutions = torch.tensor(resolutions) self.register_buffer("resolutions", resolutions) @@ -87,45 +81,36 @@ class Seg3dLossless(nn.Module): ), f"resolution {resolution} need to be odd becuase of align_corner." # init first resolution - init_coords = create_grid3D(0, - resolutions[-1] - 1, - steps=resolutions[0]) # [N, 3] - init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1, - 1) # [bz, N, 3] + init_coords = create_grid3D(0, resolutions[-1] - 1, steps=resolutions[0]) # [N, 3] + init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1, 1) # [bz, N, 3] self.register_buffer("init_coords", init_coords) # some useful tensors calculated = torch.zeros( - (self.resolutions[-1][2], self.resolutions[-1][1], - self.resolutions[-1][0]), + (self.resolutions[-1][2], self.resolutions[-1][1], self.resolutions[-1][0]), dtype=torch.bool, ) self.register_buffer("calculated", calculated) - gird8_offsets = (torch.stack( - torch.meshgrid( - [ - torch.tensor([-1, 0, 1]), - torch.tensor([-1, 0, 1]), - torch.tensor([-1, 0, 1]), - ], - indexing="ij", - )).int().view(3, -1).t()) # [27, 3] + gird8_offsets = ( + torch.stack( + torch.meshgrid( + [ + torch.tensor([-1, 0, 1]), + torch.tensor([-1, 0, 1]), + torch.tensor([-1, 0, 1]), + ], + indexing="ij", + ) + ).int().view(3, -1).t() + ) # [27, 3] self.register_buffer("gird8_offsets", gird8_offsets) # smooth convs - self.smooth_conv3x3 = SmoothConv3D(in_channels=1, - out_channels=1, - kernel_size=3) - self.smooth_conv5x5 = SmoothConv3D(in_channels=1, - out_channels=1, - kernel_size=5) - self.smooth_conv7x7 = SmoothConv3D(in_channels=1, - out_channels=1, - kernel_size=7) - self.smooth_conv9x9 = SmoothConv3D(in_channels=1, - out_channels=1, - kernel_size=9) + self.smooth_conv3x3 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=3) + self.smooth_conv5x5 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=5) + self.smooth_conv7x7 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=7) + self.smooth_conv9x9 = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=9) @torch.no_grad() def batch_eval(self, coords, **kwargs): @@ -144,7 +129,7 @@ class Seg3dLossless(nn.Module): # query function occupancys = 
self.query_func(**kwargs, points=coords2D) if type(occupancys) is list: - occupancys = torch.stack(occupancys) # [bz, C, N] + occupancys = torch.stack(occupancys) # [bz, C, N] assert ( len(occupancys.size()) == 3 ), "query_func should return a occupancy with shape of [bz, C, N]" @@ -175,10 +160,9 @@ class Seg3dLossless(nn.Module): # first step if torch.equal(resolution, self.resolutions[0]): - coords = self.init_coords.clone() # torch.long + coords = self.init_coords.clone() # torch.long occupancys = self.batch_eval(coords, **kwargs) - occupancys = occupancys.view(self.batchsize, self.channels, D, - H, W) + occupancys = occupancys.view(self.batchsize, self.channels, D, H, W) if (occupancys > 0.5).sum() == 0: # return F.interpolate( # occupancys, size=(final_D, final_H, final_W), @@ -239,23 +223,22 @@ class Seg3dLossless(nn.Module): with torch.no_grad(): if torch.equal(resolution, self.resolutions[1]): - is_boundary = (self.smooth_conv9x9(is_boundary.float()) - > 0)[0, 0] + is_boundary = (self.smooth_conv9x9(is_boundary.float()) > 0)[0, 0] elif torch.equal(resolution, self.resolutions[2]): - is_boundary = (self.smooth_conv7x7(is_boundary.float()) - > 0)[0, 0] + is_boundary = (self.smooth_conv7x7(is_boundary.float()) > 0)[0, 0] else: - is_boundary = (self.smooth_conv3x3(is_boundary.float()) - > 0)[0, 0] + is_boundary = (self.smooth_conv3x3(is_boundary.float()) > 0)[0, 0] coords_accum = coords_accum.long() is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1], coords_accum[0, :, 0], ] = False - point_coords = (is_boundary.permute( - 2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)) - point_indices = (point_coords[:, :, 2] * H * W + - point_coords[:, :, 1] * W + - point_coords[:, :, 0]) + point_coords = ( + is_boundary.permute(2, 1, 0).nonzero(as_tuple=False).unsqueeze(0) + ) + point_indices = ( + point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W + + point_coords[:, :, 0] + ) R, C, D, H, W = occupancys.shape @@ -269,13 +252,15 @@ class Seg3dLossless(nn.Module): # put mask point predictions to the right places on the upsampled grid. 
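# NOTE (illustrative sketch, not part of this patch): the reshape/scatter_/view
# chain below writes the refined per-point occupancies back into the dense
# [R, C, D, H, W] grid via their flattened indices. Toy shapes and values:
import torch

R, C, D, H, W = 1, 1, 4, 4, 4
occ = torch.zeros(R, C, D, H, W)
point_indices = torch.tensor([[0, 5, 9]])                  # flat indices into D*H*W, [R, P]
occupancys_topk = torch.tensor([[[0.25, 0.5, 0.75]]])      # refined values, [R, C, P]

idx = point_indices.unsqueeze(1).expand(-1, C, -1)         # [R, C, P]
occ = occ.reshape(R, C, D * H * W).scatter_(2, idx, occupancys_topk).view(R, C, D, H, W)
assert occ[0, 0, 0, 1, 1] == 0.5                           # flat index 5 -> (d=0, h=1, w=1)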
R, C, D, H, W = occupancys.shape point_indices = point_indices.unsqueeze(1).expand(-1, C, -1) - occupancys = (occupancys.reshape(R, C, D * H * W).scatter_( - 2, point_indices, occupancys_topk).view(R, C, D, H, W)) + occupancys = ( + occupancys.reshape(R, C, + D * H * W).scatter_(2, point_indices, + occupancys_topk).view(R, C, D, H, W) + ) with torch.no_grad(): voxels = coords / stride - coords_accum = torch.cat([voxels, coords_accum], - dim=1).unique(dim=1) + coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1) return occupancys[0, 0] @@ -300,18 +285,16 @@ class Seg3dLossless(nn.Module): # first step if torch.equal(resolution, self.resolutions[0]): - coords = self.init_coords.clone() # torch.long + coords = self.init_coords.clone() # torch.long occupancys = self.batch_eval(coords, **kwargs) - occupancys = occupancys.view(self.batchsize, self.channels, D, - H, W) + occupancys = occupancys.view(self.batchsize, self.channels, D, H, W) if self.visualize: self.plot(occupancys, coords, final_D, final_H, final_W) with torch.no_grad(): coords_accum = coords / stride - calculated[coords[0, :, 2], coords[0, :, 1], - coords[0, :, 0]] = True + calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True # next steps else: @@ -338,35 +321,34 @@ class Seg3dLossless(nn.Module): with torch.no_grad(): # TODO - if self.use_shadow and torch.equal(resolution, - self.resolutions[-1]): + if self.use_shadow and torch.equal(resolution, self.resolutions[-1]): # larger z means smaller depth here depth_res = resolution[2].item() - depth_index = torch.linspace(0, - depth_res - 1, - steps=depth_res).type_as( - occupancys.device) - depth_index_max = (torch.max( - (occupancys > self.balance_value) * - (depth_index + 1), - dim=-1, - keepdim=True, - )[0] - 1) + depth_index = torch.linspace(0, depth_res - 1, + steps=depth_res).type_as(occupancys.device) + depth_index_max = ( + torch.max( + (occupancys > self.balance_value) * (depth_index + 1), + dim=-1, + keepdim=True, + )[0] - 1 + ) shadow = depth_index < depth_index_max is_boundary[shadow] = False is_boundary = is_boundary[0, 0] else: - is_boundary = (self.smooth_conv3x3(is_boundary.float()) - > 0)[0, 0] + is_boundary = (self.smooth_conv3x3(is_boundary.float()) > 0)[0, 0] # is_boundary = is_boundary[0, 0] is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1], coords_accum[0, :, 0], ] = False - point_coords = (is_boundary.permute( - 2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)) - point_indices = (point_coords[:, :, 2] * H * W + - point_coords[:, :, 1] * W + - point_coords[:, :, 0]) + point_coords = ( + is_boundary.permute(2, 1, 0).nonzero(as_tuple=False).unsqueeze(0) + ) + point_indices = ( + point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W + + point_coords[:, :, 0] + ) R, C, D, H, W = occupancys.shape # interpolated value @@ -388,28 +370,28 @@ class Seg3dLossless(nn.Module): # put mask point predictions to the right places on the upsampled grid. 
R, C, D, H, W = occupancys.shape point_indices = point_indices.unsqueeze(1).expand(-1, C, -1) - occupancys = (occupancys.reshape(R, C, D * H * W).scatter_( - 2, point_indices, occupancys_topk).view(R, C, D, H, W)) + occupancys = ( + occupancys.reshape(R, C, + D * H * W).scatter_(2, point_indices, + occupancys_topk).view(R, C, D, H, W) + ) with torch.no_grad(): # conflicts - conflicts = ((occupancys_interp - self.balance_value) * - (occupancys_topk - self.balance_value) < 0)[0, - 0] + conflicts = ( + (occupancys_interp - self.balance_value) * + (occupancys_topk - self.balance_value) < 0 + )[0, 0] if self.visualize: - self.plot(occupancys, coords, final_D, final_H, - final_W) + self.plot(occupancys, coords, final_D, final_H, final_W) voxels = coords / stride - coords_accum = torch.cat([voxels, coords_accum], - dim=1).unique(dim=1) - calculated[coords[0, :, 2], coords[0, :, 1], - coords[0, :, 0]] = True + coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1) + calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True while conflicts.sum() > 0: - if self.use_shadow and torch.equal(resolution, - self.resolutions[-1]): + if self.use_shadow and torch.equal(resolution, self.resolutions[-1]): break with torch.no_grad(): @@ -426,25 +408,27 @@ class Seg3dLossless(nn.Module): ) conflicts_boundary = ( - (conflicts_coords.int() + - self.gird8_offsets.unsqueeze(1) * - stride.int()).reshape(-1, 3).long().unique(dim=0)) - conflicts_boundary[:, - 0] = conflicts_boundary[:, 0].clamp( - 0, - calculated.size(2) - 1) - conflicts_boundary[:, - 1] = conflicts_boundary[:, 1].clamp( - 0, - calculated.size(1) - 1) - conflicts_boundary[:, - 2] = conflicts_boundary[:, 2].clamp( - 0, - calculated.size(0) - 1) - - coords = conflicts_boundary[calculated[ - conflicts_boundary[:, 2], conflicts_boundary[:, 1], - conflicts_boundary[:, 0], ] == False] + ( + conflicts_coords.int() + + self.gird8_offsets.unsqueeze(1) * stride.int() + ).reshape(-1, 3).long().unique(dim=0) + ) + conflicts_boundary[:, 0] = conflicts_boundary[:, 0].clamp( + 0, + calculated.size(2) - 1 + ) + conflicts_boundary[:, 1] = conflicts_boundary[:, 1].clamp( + 0, + calculated.size(1) - 1 + ) + conflicts_boundary[:, 2] = conflicts_boundary[:, 2].clamp( + 0, + calculated.size(0) - 1 + ) + + coords = conflicts_boundary[calculated[conflicts_boundary[:, 2], + conflicts_boundary[:, 1], + conflicts_boundary[:, 0], ] == False] if self.debug: self.plot( @@ -458,9 +442,10 @@ class Seg3dLossless(nn.Module): coords = coords.unsqueeze(0) point_coords = coords / stride - point_indices = (point_coords[:, :, 2] * H * W + - point_coords[:, :, 1] * W + - point_coords[:, :, 0]) + point_indices = ( + point_coords[:, :, 2] * H * W + point_coords[:, :, 1] * W + + point_coords[:, :, 0] + ) R, C, D, H, W = occupancys.shape # interpolated value @@ -481,44 +466,37 @@ class Seg3dLossless(nn.Module): with torch.no_grad(): # conflicts - conflicts = ((occupancys_interp - self.balance_value) * - (occupancys_topk - self.balance_value) < - 0)[0, 0] + conflicts = ( + (occupancys_interp - self.balance_value) * + (occupancys_topk - self.balance_value) < 0 + )[0, 0] # put mask point predictions to the right places on the upsampled grid. 
- point_indices = point_indices.unsqueeze(1).expand( - -1, C, -1) - occupancys = (occupancys.reshape(R, C, D * H * W).scatter_( - 2, point_indices, occupancys_topk).view(R, C, D, H, W)) + point_indices = point_indices.unsqueeze(1).expand(-1, C, -1) + occupancys = ( + occupancys.reshape(R, C, + D * H * W).scatter_(2, point_indices, + occupancys_topk).view(R, C, D, H, W) + ) with torch.no_grad(): voxels = coords / stride - coords_accum = torch.cat([voxels, coords_accum], - dim=1).unique(dim=1) - calculated[coords[0, :, 2], coords[0, :, 1], - coords[0, :, 0]] = True + coords_accum = torch.cat([voxels, coords_accum], dim=1).unique(dim=1) + calculated[coords[0, :, 2], coords[0, :, 1], coords[0, :, 0]] = True if self.visualize: this_stage_coords = torch.cat(this_stage_coords, dim=1) - self.plot(occupancys, this_stage_coords, final_D, final_H, - final_W) + self.plot(occupancys, this_stage_coords, final_D, final_H, final_W) return occupancys[0, 0] - def plot(self, - occupancys, - coords, - final_D, - final_H, - final_W, - title="", - **kwargs): + def plot(self, occupancys, coords, final_D, final_H, final_W, title="", **kwargs): final = F.interpolate( occupancys.float(), size=(final_D, final_H, final_W), mode="trilinear", align_corners=True, - ) # here true is correct! + ) # here true is correct! x = coords[0, :, 0].to("cpu") y = coords[0, :, 1].to("cpu") z = coords[0, :, 2].to("cpu") @@ -548,20 +526,19 @@ class Seg3dLossless(nn.Module): sdf_all = sdf.permute(2, 1, 0) # shadow - grad_v = (sdf_all > 0.5) * torch.linspace( - resolution, 1, steps=resolution).to(sdf.device) - grad_c = torch.ones_like(sdf_all) * torch.linspace( - 0, resolution - 1, steps=resolution).to(sdf.device) + grad_v = (sdf_all > 0.5) * torch.linspace(resolution, 1, steps=resolution).to(sdf.device) + grad_c = torch.ones_like(sdf_all) * torch.linspace(0, resolution - 1, + steps=resolution).to(sdf.device) max_v, max_c = grad_v.max(dim=2) shadow = grad_c > max_c.view(resolution, resolution, 1) keep = (sdf_all > 0.5) & (~shadow) - p1 = keep.nonzero(as_tuple=False).t() # [3, N] - p2 = p1.clone() # z + p1 = keep.nonzero(as_tuple=False).t() # [3, N] + p2 = p1.clone() # z p2[2, :] = (p2[2, :] - 2).clamp(0, resolution) - p3 = p1.clone() # y + p3 = p1.clone() # y p3[1, :] = (p3[1, :] - 2).clamp(0, resolution) - p4 = p1.clone() # x + p4 = p1.clone() # x p4[0, :] = (p4[0, :] - 2).clamp(0, resolution) v1 = sdf_all[p1[0, :], p1[1, :], p1[2, :]] @@ -569,10 +546,10 @@ class Seg3dLossless(nn.Module): v3 = sdf_all[p3[0, :], p3[1, :], p3[2, :]] v4 = sdf_all[p4[0, :], p4[1, :], p4[2, :]] - X = p1[0, :].long() # [N,] - Y = p1[1, :].long() # [N,] - Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + p1[2, :].float() * ( - v2 - 0.5) / (v2 - v1) # [N,] + X = p1[0, :].long() # [N,] + Y = p1[1, :].long() # [N,] + Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + p1[2, :].float() * (v2 - 0.5 + ) / (v2 - v1) # [N,] Z = Z.clamp(0, resolution) # normal @@ -588,8 +565,7 @@ class Seg3dLossless(nn.Module): @torch.no_grad() def render_normal(self, resolution, X, Y, Z, norm): - image = torch.ones((1, 3, resolution, resolution), - dtype=torch.float32).to(norm.device) + image = torch.ones((1, 3, resolution, resolution), dtype=torch.float32).to(norm.device) color = (norm + 1) / 2.0 color = color.clamp(0, 1) image[0, :, Y, X] = color.t() @@ -617,9 +593,9 @@ class Seg3dLossless(nn.Module): def export_mesh(self, occupancys): final = occupancys[1:, 1:, 1:].contiguous() - + verts, faces = marching_cubes(final.unsqueeze(0), isolevel=0.5) verts = verts[0].cpu().float() - faces = 
faces[0].cpu().long()[:,[0,2,1]] - + faces = faces[0].cpu().long()[:, [0, 2, 1]] + return verts, faces diff --git a/lib/common/seg3d_utils.py b/lib/common/seg3d_utils.py index 958fb338f02d814c1d73c37edbcad777c2cff9fe..bee264615a54777bb948414c82f502c678664329 100644 --- a/lib/common/seg3d_utils.py +++ b/lib/common/seg3d_utils.py @@ -20,11 +20,7 @@ import torch.nn.functional as F import matplotlib.pyplot as plt -def plot_mask2D(mask, - title="", - point_coords=None, - figsize=10, - point_marker_size=5): +def plot_mask2D(mask, title="", point_coords=None, figsize=10, point_marker_size=5): ''' Simple plotting tool to show intermediate mask predictions and points where PointRend is applied. @@ -46,26 +42,19 @@ def plot_mask2D(mask, plt.xlabel(W, fontsize=30) plt.xticks([], []) plt.yticks([], []) - plt.imshow(mask.detach(), - interpolation="nearest", - cmap=plt.get_cmap('gray')) + plt.imshow(mask.detach(), interpolation="nearest", cmap=plt.get_cmap('gray')) if point_coords is not None: - plt.scatter(x=point_coords[0], - y=point_coords[1], - color="red", - s=point_marker_size, - clip_on=True) + plt.scatter( + x=point_coords[0], y=point_coords[1], color="red", s=point_marker_size, clip_on=True + ) plt.xlim(-0.5, W - 0.5) plt.ylim(H - 0.5, -0.5) plt.show() -def plot_mask3D(mask=None, - title="", - point_coords=None, - figsize=1500, - point_marker_size=8, - interactive=True): +def plot_mask3D( + mask=None, title="", point_coords=None, figsize=1500, point_marker_size=8, interactive=True +): ''' Simple plotting tool to show intermediate mask predictions and points where PointRend is applied. @@ -90,7 +79,8 @@ def plot_mask3D(mask=None, # marching cube to find surface verts, faces, normals, values = measure.marching_cubes_lewiner( - mask, 0.5, gradient_direction='ascent') + mask, 0.5, gradient_direction='ascent' + ) # create a mesh mesh = trimesh.Trimesh(verts, faces) @@ -110,57 +100,49 @@ def plot_mask3D(mask=None, pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red') vis_list.append(pc) - vp.show(*vis_list, - bg="white", - axes=1, - interactive=interactive, - azimuth=30, - elevation=30) + vp.show(*vis_list, bg="white", axes=1, interactive=interactive, azimuth=30, elevation=30) def create_grid3D(min, max, steps): if type(min) is int: - min = (min, min, min) # (x, y, z) + min = (min, min, min) # (x, y, z) if type(max) is int: - max = (max, max, max) # (x, y) + max = (max, max, max) # (x, y) if type(steps) is int: - steps = (steps, steps, steps) # (x, y, z) + steps = (steps, steps, steps) # (x, y, z) arrangeX = torch.linspace(min[0], max[0], steps[0]).long() arrangeY = torch.linspace(min[1], max[1], steps[1]).long() arrangeZ = torch.linspace(min[2], max[2], steps[2]).long() - gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX], - indexing='ij') - coords = torch.stack([gridW, girdH, - gridD]) # [2, steps[0], steps[1], steps[2]] - coords = coords.view(3, -1).t() # [N, 3] + gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX], indexing='ij') + coords = torch.stack([gridW, girdH, gridD]) # [2, steps[0], steps[1], steps[2]] + coords = coords.view(3, -1).t() # [N, 3] return coords def create_grid2D(min, max, steps): if type(min) is int: - min = (min, min) # (x, y) + min = (min, min) # (x, y) if type(max) is int: - max = (max, max) # (x, y) + max = (max, max) # (x, y) if type(steps) is int: - steps = (steps, steps) # (x, y) + steps = (steps, steps) # (x, y) arrangeX = torch.linspace(min[0], max[0], steps[0]).long() arrangeY = torch.linspace(min[1], max[1], 
steps[1]).long() girdH, gridW = torch.meshgrid([arrangeY, arrangeX], indexing='ij') - coords = torch.stack([gridW, girdH]) # [2, steps[0], steps[1]] - coords = coords.view(2, -1).t() # [N, 2] + coords = torch.stack([gridW, girdH]) # [2, steps[0], steps[1]] + coords = coords.view(2, -1).t() # [N, 2] return coords class SmoothConv2D(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=3): super().__init__() assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}" self.padding = (kernel_size - 1) // 2 weight = torch.ones( - (in_channels, out_channels, kernel_size, kernel_size), - dtype=torch.float32) / (kernel_size**2) + (in_channels, out_channels, kernel_size, kernel_size), dtype=torch.float32 + ) / (kernel_size**2) self.register_buffer('weight', weight) def forward(self, input): @@ -168,53 +150,49 @@ class SmoothConv2D(nn.Module): class SmoothConv3D(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=3): super().__init__() assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}" self.padding = (kernel_size - 1) // 2 weight = torch.ones( - (in_channels, out_channels, kernel_size, kernel_size, kernel_size), - dtype=torch.float32) / (kernel_size**3) + (in_channels, out_channels, kernel_size, kernel_size, kernel_size), dtype=torch.float32 + ) / (kernel_size**3) self.register_buffer('weight', weight) def forward(self, input): return F.conv3d(input, self.weight, padding=self.padding) -def build_smooth_conv3D(in_channels=1, - out_channels=1, - kernel_size=3, - padding=1): - smooth_conv = torch.nn.Conv3d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - padding=padding) +def build_smooth_conv3D(in_channels=1, out_channels=1, kernel_size=3, padding=1): + smooth_conv = torch.nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding + ) smooth_conv.weight.data = torch.ones( - (in_channels, out_channels, kernel_size, kernel_size, kernel_size), - dtype=torch.float32) / (kernel_size**3) + (in_channels, out_channels, kernel_size, kernel_size, kernel_size), dtype=torch.float32 + ) / (kernel_size**3) smooth_conv.bias.data = torch.zeros(out_channels) return smooth_conv -def build_smooth_conv2D(in_channels=1, - out_channels=1, - kernel_size=3, - padding=1): - smooth_conv = torch.nn.Conv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - padding=padding) +def build_smooth_conv2D(in_channels=1, out_channels=1, kernel_size=3, padding=1): + smooth_conv = torch.nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding + ) smooth_conv.weight.data = torch.ones( - (in_channels, out_channels, kernel_size, kernel_size), - dtype=torch.float32) / (kernel_size**2) + (in_channels, out_channels, kernel_size, kernel_size), dtype=torch.float32 + ) / (kernel_size**2) smooth_conv.bias.data = torch.zeros(out_channels) return smooth_conv -def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points, - **kwargs): +def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points, **kwargs): """ Find `num_points` most uncertain points from `uncertainty_map` grid. 
Args: @@ -233,28 +211,21 @@ def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points, # d_step = 1.0 / float(D) num_points = min(D * H * W, num_points) - point_scores, point_indices = torch.topk(uncertainty_map.view( - R, D * H * W), - k=num_points, - dim=1) - point_coords = torch.zeros(R, - num_points, - 3, - dtype=torch.float, - device=uncertainty_map.device) + point_scores, point_indices = torch.topk( + uncertainty_map.view(R, D * H * W), k=num_points, dim=1 + ) + point_coords = torch.zeros(R, num_points, 3, dtype=torch.float, device=uncertainty_map.device) # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step - point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x - point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y - point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z - print(f"resolution {D} x {H} x {W}", point_scores.min(), - point_scores.max()) + point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x + point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y + point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z + print(f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max()) return point_indices, point_coords -def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points, - clip_min): +def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points, clip_min): """ Find `num_points` most uncertain points from `uncertainty_map` grid. Args: @@ -276,28 +247,21 @@ def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points, uncertainty_map = uncertainty_map.view(D * H * W) indices = (uncertainty_map >= clip_min).nonzero().squeeze(1) num_points = min(num_points, indices.size(0)) - point_scores, point_indices = torch.topk(uncertainty_map[indices], - k=num_points, - dim=0) + point_scores, point_indices = torch.topk(uncertainty_map[indices], k=num_points, dim=0) point_indices = indices[point_indices].unsqueeze(0) - point_coords = torch.zeros(R, - num_points, - 3, - dtype=torch.float, - device=uncertainty_map.device) + point_coords = torch.zeros(R, num_points, 3, dtype=torch.float, device=uncertainty_map.device) # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step - point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x - point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y - point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z + point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x + point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y + point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z # print (f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max()) return point_indices, point_coords -def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points, - **kwargs): +def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points, **kwargs): """ Find `num_points` most uncertain points from `uncertainty_map` grid. 
Args: @@ -315,14 +279,8 @@ def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points, # w_step = 1.0 / float(W) num_points = min(H * W, num_points) - point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W), - k=num_points, - dim=1) - point_coords = torch.zeros(R, - num_points, - 2, - dtype=torch.long, - device=uncertainty_map.device) + point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1) + point_coords = torch.zeros(R, num_points, 2, dtype=torch.long, device=uncertainty_map.device) # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step point_coords[:, :, 0] = (point_indices % W).to(torch.long) @@ -331,8 +289,7 @@ def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points, return point_indices, point_coords -def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points, - clip_min): +def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points, clip_min): """ Find `num_points` most uncertain points from `uncertainty_map` grid. Args: @@ -353,16 +310,10 @@ def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points, uncertainty_map = uncertainty_map.view(H * W) indices = (uncertainty_map >= clip_min).nonzero().squeeze(1) num_points = min(num_points, indices.size(0)) - point_scores, point_indices = torch.topk(uncertainty_map[indices], - k=num_points, - dim=0) + point_scores, point_indices = torch.topk(uncertainty_map[indices], k=num_points, dim=0) point_indices = indices[point_indices].unsqueeze(0) - point_coords = torch.zeros(R, - num_points, - 2, - dtype=torch.long, - device=uncertainty_map.device) + point_coords = torch.zeros(R, num_points, 2, dtype=torch.long, device=uncertainty_map.device) # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step point_coords[:, :, 0] = (point_indices % W).to(torch.long) @@ -388,7 +339,6 @@ def calculate_uncertainty(logits, classes=None, balance_value=0.5): if logits.shape[1] == 1: gt_class_logits = logits else: - gt_class_logits = logits[ - torch.arange(logits.shape[0], device=logits.device), - classes].unsqueeze(1) + gt_class_logits = logits[torch.arange(logits.shape[0], device=logits.device), + classes].unsqueeze(1) return -torch.abs(gt_class_logits - balance_value) diff --git a/lib/common/train_util.py b/lib/common/train_util.py index 06ff842e30fab3f29b4046a518f48f46bcb8eb76..a39102a5849e25d056b9d96d2df9538790bec6ea 100644 --- a/lib/common/train_util.py +++ b/lib/common/train_util.py @@ -14,63 +14,62 @@ # # Contact: ps-license@tuebingen.mpg.de -import yaml -import os.path as osp import torch -import numpy as np from ..dataset.mesh_util import * from ..net.geometry import orthogonal -import cv2, PIL -from tqdm import tqdm -import os from termcolor import colored import pytorch_lightning as pl +class Format: + end = '\033[0m' + start = '\033[4m' + + def init_loss(): losses = { - # Cloth: Normal_recon - Normal_pred + # Cloth: chamfer distance "cloth": { "weight": 1e3, "value": 0.0 }, - # Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2) - "stiffness": { + # Stiffness: [RT]_v1 - [RT]_v2 (v1-edge-v2) + "stiff": { "weight": 1e5, "value": 0.0 }, - # Cloth: det(R) = 1 + # Cloth: det(R) = 1 "rigid": { "weight": 1e5, "value": 0.0 }, - # Cloth: edge length + # Cloth: edge length "edge": { "weight": 0, "value": 0.0 }, - # 
Cloth: normal consistency + # Cloth: normal consistency "nc": { "weight": 0, "value": 0.0 }, - # Cloth: laplacian smoonth - "laplacian": { + # Cloth: laplacian smoonth + "lapla": { "weight": 1e2, "value": 0.0 }, - # Body: Normal_pred - Normal_smpl + # Body: Normal_pred - Normal_smpl "normal": { "weight": 1e0, "value": 0.0 }, - # Body: Silhouette_pred - Silhouette_smpl + # Body: Silhouette_pred - Silhouette_smpl "silhouette": { "weight": 1e0, "value": 0.0 }, - # Joint: reprojected joints difference + # Joint: reprojected joints difference "joint": { "weight": 5e0, "value": 0.0 @@ -81,7 +80,6 @@ def init_loss(): class SubTrainer(pl.Trainer): - def save_checkpoint(self, filepath, weights_only=False): """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -101,214 +99,6 @@ class SubTrainer(pl.Trainer): pl.utilities.cloud_io.atomic_save(_checkpoint, filepath) -def rename(old_dict, old_name, new_name): - new_dict = {} - for key, value in zip(old_dict.keys(), old_dict.values()): - new_key = key if key != old_name else new_name - new_dict[new_key] = old_dict[key] - return new_dict - - -def load_normal_networks(model, normal_path): - - pretrained_dict = torch.load( - normal_path, - map_location=model.device)["state_dict"] - model_dict = model.state_dict() - - # 1. filter out unnecessary keys - pretrained_dict = { - k: v - for k, v in pretrained_dict.items() - if k in model_dict and v.shape == model_dict[k].shape - } - - # # 2. overwrite entries in the existing state dict - model_dict.update(pretrained_dict) - # 3. load the new state dict - model.load_state_dict(model_dict) - - del pretrained_dict - del model_dict - - print(colored(f"Resume Normal weights from {normal_path}", "green")) - - -def load_networks(model, mlp_path, normal_path=None): - - model_dict = model.state_dict() - main_dict = {} - normal_dict = {} - - # MLP part loading - if os.path.exists(mlp_path) and mlp_path.endswith("ckpt"): - main_dict = torch.load( - mlp_path, - map_location=model.device)["state_dict"] - - main_dict = { - k: v - for k, v in main_dict.items() - if k in model_dict and v.shape == model_dict[k].shape and ( - "reconEngine" not in k) and ("normal_filter" not in k) and ( - "voxelization" not in k) - } - print(colored(f"Resume MLP weights from {mlp_path}", "green")) - - # normal network part loading - if normal_path is not None and os.path.exists(normal_path) and normal_path.endswith("ckpt"): - normal_dict = torch.load( - normal_path, - map_location=model.device)["state_dict"] - - for key in normal_dict.keys(): - normal_dict = rename(normal_dict, key, - key.replace("netG", "netG.normal_filter")) - - normal_dict = { - k: v - for k, v in normal_dict.items() - if k in model_dict and v.shape == model_dict[k].shape - } - print(colored(f"Resume normal model from {normal_path}", "green")) - - model_dict.update(main_dict) - model_dict.update(normal_dict) - model.load_state_dict(model_dict) - - # clean unused GPU memory - del main_dict - del normal_dict - del model_dict - torch.cuda.empty_cache() - - -def reshape_sample_tensor(sample_tensor, num_views): - if num_views == 1: - return sample_tensor - # Need to repeat sample_tensor along the batch dim num_views times - sample_tensor = sample_tensor.unsqueeze(dim=1) - sample_tensor = sample_tensor.repeat(1, num_views, 1, 1) - sample_tensor = sample_tensor.view( - sample_tensor.shape[0] * sample_tensor.shape[1], - sample_tensor.shape[2], - sample_tensor.shape[3], - ) - return sample_tensor - - -def adjust_learning_rate(optimizer, epoch, lr, 
schedule, gamma): - """Sets the learning rate to the initial LR decayed by schedule""" - if epoch in schedule: - lr *= gamma - for param_group in optimizer.param_groups: - param_group["lr"] = lr - return lr - - -def compute_acc(pred, gt, thresh=0.5): - """ - return: - IOU, precision, and recall - """ - with torch.no_grad(): - vol_pred = pred > thresh - vol_gt = gt > thresh - - union = vol_pred | vol_gt - inter = vol_pred & vol_gt - - true_pos = inter.sum().float() - - union = union.sum().float() - if union == 0: - union = 1 - vol_pred = vol_pred.sum().float() - if vol_pred == 0: - vol_pred = 1 - vol_gt = vol_gt.sum().float() - if vol_gt == 0: - vol_gt = 1 - return true_pos / union, true_pos / vol_pred, true_pos / vol_gt - -def calc_error(opt, net, cuda, dataset, num_tests): - if num_tests > len(dataset): - num_tests = len(dataset) - with torch.no_grad(): - erorr_arr, IOU_arr, prec_arr, recall_arr = [], [], [], [] - for idx in tqdm(range(num_tests)): - data = dataset[idx * len(dataset) // num_tests] - # retrieve the data - image_tensor = data["img"].to(device=cuda) - calib_tensor = data["calib"].to(device=cuda) - sample_tensor = data["samples"].to(device=cuda).unsqueeze(0) - if opt.num_views > 1: - sample_tensor = reshape_sample_tensor(sample_tensor, - opt.num_views) - label_tensor = data["labels"].to(device=cuda).unsqueeze(0) - - res, error = net.forward(image_tensor, - sample_tensor, - calib_tensor, - labels=label_tensor) - - IOU, prec, recall = compute_acc(res, label_tensor) - - # print( - # '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}' - # .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item())) - erorr_arr.append(error.item()) - IOU_arr.append(IOU.item()) - prec_arr.append(prec.item()) - recall_arr.append(recall.item()) - - return ( - np.average(erorr_arr), - np.average(IOU_arr), - np.average(prec_arr), - np.average(recall_arr), - ) - - -def calc_error_color(opt, netG, netC, cuda, dataset, num_tests): - if num_tests > len(dataset): - num_tests = len(dataset) - with torch.no_grad(): - error_color_arr = [] - - for idx in tqdm(range(num_tests)): - data = dataset[idx * len(dataset) // num_tests] - # retrieve the data - image_tensor = data["img"].to(device=cuda) - calib_tensor = data["calib"].to(device=cuda) - color_sample_tensor = data["color_samples"].to( - device=cuda).unsqueeze(0) - - if opt.num_views > 1: - color_sample_tensor = reshape_sample_tensor( - color_sample_tensor, opt.num_views) - - rgb_tensor = data["rgbs"].to(device=cuda).unsqueeze(0) - - netG.filter(image_tensor) - _, errorC = netC.forward( - image_tensor, - netG.get_im_feat(), - color_sample_tensor, - calib_tensor, - labels=rgb_tensor, - ) - - # print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}' - # .format(idx, num_tests, errorG.item(), errorC.item())) - error_color_arr.append(errorC.item()) - - return np.average(error_color_arr) - - -# pytorch lightning training related fucntions - - def query_func(opt, netG, features, points, proj_matrix=None): """ - points: size of (bz, N, 3) @@ -317,7 +107,7 @@ def query_func(opt, netG, features, points, proj_matrix=None): """ assert len(points) == 1 samples = points.repeat(opt.num_views, 1, 1) - samples = samples.permute(0, 2, 1) # [bz, 3, N] + samples = samples.permute(0, 2, 1) # [bz, 3, N] # view specific query if proj_matrix is not None: @@ -337,85 +127,25 @@ def query_func(opt, netG, features, points, proj_matrix=None): return preds + def query_func_IF(batch, netG, points): """ - points: size of (bz, N, 3) return: size of 
(bz, 1, N) """ - + batch["samples_geo"] = points batch["calib"] = torch.stack([torch.eye(4).float()], dim=0).type_as(points) - + preds = netG(batch) return preds.unsqueeze(1) -def isin(ar1, ar2): - return (ar1[..., None] == ar2).any(-1) - - -def in1d(ar1, ar2): - mask = ar2.new_zeros((max(ar1.max(), ar2.max()) + 1, ), dtype=torch.bool) - mask[ar2.unique()] = True - return mask[ar1] - def batch_mean(res, key): - return torch.stack([ - x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key]) - for x in res - ]).mean() - - -def tf_log_convert(log_dict): - new_log_dict = log_dict.copy() - for k, v in log_dict.items(): - new_log_dict[k.replace("_", "/")] = v - del new_log_dict[k] - - return new_log_dict - - -def bar_log_convert(log_dict, name=None, rot=None): - from decimal import Decimal - - new_log_dict = {} - - if name is not None: - new_log_dict["name"] = name[0] - if rot is not None: - new_log_dict["rot"] = rot[0] - - for k, v in log_dict.items(): - color = "yellow" - if "loss" in k: - color = "red" - k = k.replace("loss", "L") - elif "acc" in k: - color = "green" - k = k.replace("acc", "A") - elif "iou" in k: - color = "green" - k = k.replace("iou", "I") - elif "prec" in k: - color = "green" - k = k.replace("prec", "P") - elif "recall" in k: - color = "green" - k = k.replace("recall", "R") - - if "lr" not in k: - new_log_dict[colored(k.split("_")[1], - color)] = colored(f"{v:.3f}", color) - else: - new_log_dict[colored(k.split("_")[1], - color)] = colored(f"{Decimal(str(v)):.1E}", - color) - - if "loss" in new_log_dict.keys(): - del new_log_dict["loss"] - - return new_log_dict + return torch.stack( + [x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key]) for x in res] + ).mean() def accumulate(outputs, rot_num, split): @@ -430,160 +160,10 @@ def accumulate(outputs, rot_num, split): keyword = f"{dataset}/{metric}" if keyword not in hparam_log_dict.keys(): hparam_log_dict[keyword] = 0 - for idx in range(split[dataset][0] * rot_num, - split[dataset][1] * rot_num): + for idx in range(split[dataset][0] * rot_num, split[dataset][1] * rot_num): hparam_log_dict[keyword] += outputs[idx][metric].item() - hparam_log_dict[keyword] /= (split[dataset][1] - - split[dataset][0]) * rot_num + hparam_log_dict[keyword] /= (split[dataset][1] - split[dataset][0]) * rot_num print(colored(hparam_log_dict, "green")) return hparam_log_dict - - -def calc_error_N(outputs, targets): - """calculate the error of normal (IGR) - - Args: - outputs (torch.tensor): [B, 3, N] - target (torch.tensor): [B, N, 3] - - # manifold loss and grad_loss in IGR paper - grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean() - normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean() - - Returns: - torch.tensor: error of valid normals on the surface - """ - # outputs = torch.tanh(-outputs.permute(0,2,1).reshape(-1,3)) - outputs = -outputs.permute(0, 2, 1).reshape(-1, 1) - targets = targets.reshape(-1, 3)[:, 2:3] - with_normals = targets.sum(dim=1).abs() > 0.0 - - # eikonal loss - grad_loss = ((outputs[with_normals].norm(2, dim=-1) - 1)**2).mean() - # normals loss - normal_loss = (outputs - targets)[with_normals].abs().norm(2, dim=1).mean() - - return grad_loss * 0.0 + normal_loss - - -def calc_knn_acc(preds, carn_verts, labels, pick_num): - """calculate knn accuracy - - Args: - preds (torch.tensor): [B, 3, N] - carn_verts (torch.tensor): [SMPLX_V_num, 3] - labels (torch.tensor): [B, N_knn, N] - """ - N_knn_full = labels.shape[1] - preds = preds.permute(0, 2, 1).reshape(-1, 3) - labels = labels.permute(0, 
2, 1).reshape(-1, N_knn_full) # [BxN, num_knn] - labels = labels[:, :pick_num] - - dist = torch.cdist(preds, carn_verts, p=2) # [BxN, SMPL_V_num] - knn = dist.topk(k=pick_num, dim=1, largest=False)[1] # [BxN, num_knn] - cat_mat = torch.sort(torch.cat((knn, labels), dim=1))[0] - bool_col = torch.zeros_like(cat_mat)[:, 0] - for i in range(pick_num * 2 - 1): - bool_col += cat_mat[:, i] == cat_mat[:, i + 1] - acc = (bool_col > 0).sum() / len(bool_col) - - return acc - - -def calc_acc_seg(output, target, num_multiseg): - from pytorch_lightning.metrics import Accuracy - - return Accuracy()(output.reshape(-1, num_multiseg).cpu(), - target.flatten().cpu()) - - -def add_watermark(imgs, titles): - - # Write some Text - - font = cv2.FONT_HERSHEY_SIMPLEX - bottomLeftCornerOfText = (350, 50) - bottomRightCornerOfText = (800, 50) - fontScale = 1 - fontColor = (1.0, 1.0, 1.0) - lineType = 2 - - for i in range(len(imgs)): - - title = titles[i + 1] - cv2.putText(imgs[i], title, bottomLeftCornerOfText, font, fontScale, - fontColor, lineType) - - if i == 0: - cv2.putText( - imgs[i], - str(titles[i][0]), - bottomRightCornerOfText, - font, - fontScale, - fontColor, - lineType, - ) - - result = np.concatenate(imgs, axis=0).transpose(2, 0, 1) - - return result - - -def make_test_gif(img_dir): - - if img_dir is not None and len(os.listdir(img_dir)) > 0: - for dataset in os.listdir(img_dir): - for subject in sorted(os.listdir(osp.join(img_dir, dataset))): - img_lst = [] - im1 = None - for file in sorted( - os.listdir(osp.join(img_dir, dataset, subject))): - if file[-3:] not in ["obj", "gif"]: - img_path = os.path.join(img_dir, dataset, subject, - file) - if im1 == None: - im1 = PIL.Image.open(img_path) - else: - img_lst.append(PIL.Image.open(img_path)) - - print(os.path.join(img_dir, dataset, subject, "out.gif")) - im1.save( - os.path.join(img_dir, dataset, subject, "out.gif"), - save_all=True, - append_images=img_lst, - duration=500, - loop=0, - ) - - -def export_cfg(logger, dir, cfg): - - cfg_export_file = osp.join(dir, f"cfg_{logger.version}.yaml") - - if not osp.exists(cfg_export_file): - os.makedirs(osp.dirname(cfg_export_file), exist_ok=True) - with open(cfg_export_file, "w+") as file: - _ = yaml.dump(cfg, file) - - -from yacs.config import CfgNode - -_VALID_TYPES = {tuple, list, str, int, float, bool} - - -def convert_to_dict(cfg_node, key_list=[]): - """ Convert a config node to dictionary """ - if not isinstance(cfg_node, CfgNode): - if type(cfg_node) not in _VALID_TYPES: - print( - "Key {} with value {} is not a valid type; valid types: {}". - format(".".join(key_list), type(cfg_node), _VALID_TYPES), ) - return cfg_node - else: - cfg_dict = dict(cfg_node) - for k, v in cfg_dict.items(): - cfg_dict[k] = convert_to_dict(v, key_list + [k]) - return cfg_dict diff --git a/lib/common/voxelize.py b/lib/common/voxelize.py index 112eb4ad82dc763e132ad7da49bb4eb409b15629..f792189ccc185e9a7b596eae5a9230fe21482aef 100644 --- a/lib/common/voxelize.py +++ b/lib/common/voxelize.py @@ -13,6 +13,7 @@ from lib.common.libmesh.inside_mesh import check_mesh_contains # From Occupancy Networks, Mescheder et. al. CVPR'19 + def make_3d_grid(bb_min, bb_max, shape): ''' Makes a 3D grid. 
@@ -37,7 +38,7 @@ def make_3d_grid(bb_min, bb_max, shape): class VoxelGrid: def __init__(self, data, loc=(0., 0., 0.), scale=1): - assert(data.shape[0] == data.shape[1] == data.shape[2]) + assert (data.shape[0] == data.shape[1] == data.shape[2]) data = np.asarray(data, dtype=np.bool) loc = np.asarray(loc) self.data = data @@ -53,7 +54,7 @@ class VoxelGrid: # Default scale, scales the mesh to [-0.45, 0.45]^3 if scale is None: - scale = (bounds[1] - bounds[0]).max()/0.9 + scale = (bounds[1] - bounds[0]).max() / 0.9 loc = np.asarray(loc) scale = float(scale) @@ -61,7 +62,7 @@ class VoxelGrid: # Transform mesh mesh = mesh.copy() mesh.apply_translation(-loc) - mesh.apply_scale(1/scale) + mesh.apply_scale(1 / scale) # Apply method if method == 'ray': @@ -75,7 +76,7 @@ class VoxelGrid: def down_sample(self, factor=2): if not (self.resolution % factor) == 0: raise ValueError('Resolution must be divisible by factor.') - new_data = block_reduce(self.data, (factor,) * 3, np.max) + new_data = block_reduce(self.data, (factor, ) * 3, np.max) return VoxelGrid(new_data, self.loc, self.scale) def to_mesh(self): @@ -103,9 +104,9 @@ class VoxelGrid: f2 = f2_r | f2_l f3 = f3_r | f3_l - assert(f1.shape == (nx + 1, ny, nz)) - assert(f2.shape == (nx, ny + 1, nz)) - assert(f3.shape == (nx, ny, nz + 1)) + assert (f1.shape == (nx + 1, ny, nz)) + assert (f2.shape == (nx, ny + 1, nz)) + assert (f3.shape == (nx, ny, nz + 1)) # Determine if vertex present v = np.full(grid_shape, False) @@ -146,53 +147,76 @@ class VoxelGrid: f2_r_x, f2_r_y, f2_r_z = np.where(f2_r) f3_r_x, f3_r_y, f3_r_z = np.where(f3_r) - faces_1_l = np.stack([ - v_idx[f1_l_x, f1_l_y, f1_l_z], - v_idx[f1_l_x, f1_l_y, f1_l_z + 1], - v_idx[f1_l_x, f1_l_y + 1, f1_l_z + 1], - v_idx[f1_l_x, f1_l_y + 1, f1_l_z], - ], axis=1) - - faces_1_r = np.stack([ - v_idx[f1_r_x, f1_r_y, f1_r_z], - v_idx[f1_r_x, f1_r_y + 1, f1_r_z], - v_idx[f1_r_x, f1_r_y + 1, f1_r_z + 1], - v_idx[f1_r_x, f1_r_y, f1_r_z + 1], - ], axis=1) - - faces_2_l = np.stack([ - v_idx[f2_l_x, f2_l_y, f2_l_z], - v_idx[f2_l_x + 1, f2_l_y, f2_l_z], - v_idx[f2_l_x + 1, f2_l_y, f2_l_z + 1], - v_idx[f2_l_x, f2_l_y, f2_l_z + 1], - ], axis=1) - - faces_2_r = np.stack([ - v_idx[f2_r_x, f2_r_y, f2_r_z], - v_idx[f2_r_x, f2_r_y, f2_r_z + 1], - v_idx[f2_r_x + 1, f2_r_y, f2_r_z + 1], - v_idx[f2_r_x + 1, f2_r_y, f2_r_z], - ], axis=1) - - faces_3_l = np.stack([ - v_idx[f3_l_x, f3_l_y, f3_l_z], - v_idx[f3_l_x, f3_l_y + 1, f3_l_z], - v_idx[f3_l_x + 1, f3_l_y + 1, f3_l_z], - v_idx[f3_l_x + 1, f3_l_y, f3_l_z], - ], axis=1) - - faces_3_r = np.stack([ - v_idx[f3_r_x, f3_r_y, f3_r_z], - v_idx[f3_r_x + 1, f3_r_y, f3_r_z], - v_idx[f3_r_x + 1, f3_r_y + 1, f3_r_z], - v_idx[f3_r_x, f3_r_y + 1, f3_r_z], - ], axis=1) - - faces = np.concatenate([ - faces_1_l, faces_1_r, - faces_2_l, faces_2_r, - faces_3_l, faces_3_r, - ], axis=0) + faces_1_l = np.stack( + [ + v_idx[f1_l_x, f1_l_y, f1_l_z], + v_idx[f1_l_x, f1_l_y, f1_l_z + 1], + v_idx[f1_l_x, f1_l_y + 1, f1_l_z + 1], + v_idx[f1_l_x, f1_l_y + 1, f1_l_z], + ], + axis=1 + ) + + faces_1_r = np.stack( + [ + v_idx[f1_r_x, f1_r_y, f1_r_z], + v_idx[f1_r_x, f1_r_y + 1, f1_r_z], + v_idx[f1_r_x, f1_r_y + 1, f1_r_z + 1], + v_idx[f1_r_x, f1_r_y, f1_r_z + 1], + ], + axis=1 + ) + + faces_2_l = np.stack( + [ + v_idx[f2_l_x, f2_l_y, f2_l_z], + v_idx[f2_l_x + 1, f2_l_y, f2_l_z], + v_idx[f2_l_x + 1, f2_l_y, f2_l_z + 1], + v_idx[f2_l_x, f2_l_y, f2_l_z + 1], + ], + axis=1 + ) + + faces_2_r = np.stack( + [ + v_idx[f2_r_x, f2_r_y, f2_r_z], + v_idx[f2_r_x, f2_r_y, f2_r_z + 1], + v_idx[f2_r_x + 1, f2_r_y, 
f2_r_z + 1], + v_idx[f2_r_x + 1, f2_r_y, f2_r_z], + ], + axis=1 + ) + + faces_3_l = np.stack( + [ + v_idx[f3_l_x, f3_l_y, f3_l_z], + v_idx[f3_l_x, f3_l_y + 1, f3_l_z], + v_idx[f3_l_x + 1, f3_l_y + 1, f3_l_z], + v_idx[f3_l_x + 1, f3_l_y, f3_l_z], + ], + axis=1 + ) + + faces_3_r = np.stack( + [ + v_idx[f3_r_x, f3_r_y, f3_r_z], + v_idx[f3_r_x + 1, f3_r_y, f3_r_z], + v_idx[f3_r_x + 1, f3_r_y + 1, f3_r_z], + v_idx[f3_r_x, f3_r_y + 1, f3_r_z], + ], + axis=1 + ) + + faces = np.concatenate( + [ + faces_1_l, + faces_1_r, + faces_2_l, + faces_2_r, + faces_3_l, + faces_3_r, + ], axis=0 + ) vertices = self.loc + self.scale * vertices mesh = trimesh.Trimesh(vertices, faces, process=False) @@ -200,7 +224,7 @@ class VoxelGrid: @property def resolution(self): - assert(self.data.shape[0] == self.data.shape[1] == self.data.shape[2]) + assert (self.data.shape[0] == self.data.shape[1] == self.data.shape[2]) return self.data.shape[0] def contains(self, points): @@ -211,12 +235,9 @@ class VoxelGrid: # Discretize points to [0, nx-1]^3 points_i = ((points + 0.5) * nx).astype(np.int32) # i1, i2, i3 have sizes (batch_size, T) - i1, i2, i3 = points_i[..., 0], points_i[..., 1], points_i[..., 2] + i1, i2, i3 = points_i[..., 0], points_i[..., 1], points_i[..., 2] # Only use indices inside bounding box - mask = ( - (i1 >= 0) & (i2 >= 0) & (i3 >= 0) - & (nx > i1) & (nx > i2) & (nx > i3) - ) + mask = ((i1 >= 0) & (i2 >= 0) & (i3 >= 0) & (nx > i1) & (nx > i2) & (nx > i3)) # Prevent out of bounds error i1 = i1[mask] i2 = i2[mask] @@ -254,7 +275,7 @@ def voxelize_surface(mesh, resolution): vertices = (vertices + 0.5) * resolution face_loc = vertices[faces] - occ = np.full((resolution,) * 3, 0, dtype=np.int32) + occ = np.full((resolution, ) * 3, 0, dtype=np.int32) face_loc = face_loc.astype(np.float32) voxelize_mesh_(occ, face_loc) @@ -264,9 +285,9 @@ def voxelize_surface(mesh, resolution): def voxelize_interior(mesh, resolution): - shape = (resolution,) * 3 - bb_min = (0.5,) * 3 - bb_max = (resolution - 0.5,) * 3 + shape = (resolution, ) * 3 + bb_min = (0.5, ) * 3 + bb_max = (resolution - 0.5, ) * 3 # Create points. 
Add noise to break symmetry points = make_3d_grid(bb_min, bb_max, shape=shape).numpy() points = points + 0.1 * (np.random.rand(*points.shape) - 0.5) @@ -280,14 +301,9 @@ def check_voxel_occupied(occupancy_grid): occ = occupancy_grid occupied = ( - occ[..., :-1, :-1, :-1] - & occ[..., :-1, :-1, 1:] - & occ[..., :-1, 1:, :-1] - & occ[..., :-1, 1:, 1:] - & occ[..., 1:, :-1, :-1] - & occ[..., 1:, :-1, 1:] - & occ[..., 1:, 1:, :-1] - & occ[..., 1:, 1:, 1:] + occ[..., :-1, :-1, :-1] & occ[..., :-1, :-1, 1:] & occ[..., :-1, 1:, :-1] & + occ[..., :-1, 1:, 1:] & occ[..., 1:, :-1, :-1] & occ[..., 1:, :-1, 1:] & + occ[..., 1:, 1:, :-1] & occ[..., 1:, 1:, 1:] ) return occupied @@ -296,14 +312,9 @@ def check_voxel_unoccupied(occupancy_grid): occ = occupancy_grid unoccupied = ~( - occ[..., :-1, :-1, :-1] - | occ[..., :-1, :-1, 1:] - | occ[..., :-1, 1:, :-1] - | occ[..., :-1, 1:, 1:] - | occ[..., 1:, :-1, :-1] - | occ[..., 1:, :-1, 1:] - | occ[..., 1:, 1:, :-1] - | occ[..., 1:, 1:, 1:] + occ[..., :-1, :-1, :-1] | occ[..., :-1, :-1, 1:] | occ[..., :-1, 1:, :-1] | + occ[..., :-1, 1:, 1:] | occ[..., 1:, :-1, :-1] | occ[..., 1:, :-1, 1:] | + occ[..., 1:, 1:, :-1] | occ[..., 1:, 1:, 1:] ) return unoccupied diff --git a/lib/dataset/Evaluator.py b/lib/dataset/Evaluator.py index 6e3f1c7218d607174a58c6ac9f6406dbef3262d2..b215a9bb2f81b88029d63b7a83d8d76a842559e3 100644 --- a/lib/dataset/Evaluator.py +++ b/lib/dataset/Evaluator.py @@ -37,7 +37,6 @@ class _PointFaceDistance(Function): """ Torch autograd Function wrapper PointFaceDistance Cuda implementation """ - @staticmethod def forward( ctx, @@ -92,12 +91,15 @@ class _PointFaceDistance(Function): grad_dists = grad_dists.contiguous() points, tris, idxs = ctx.saved_tensors min_triangle_area = ctx.min_triangle_area - grad_points, grad_tris = _C.point_face_dist_backward(points, tris, idxs, grad_dists, min_triangle_area) + grad_points, grad_tris = _C.point_face_dist_backward( + points, tris, idxs, grad_dists, min_triangle_area + ) return grad_points, None, grad_tris, None, None, None -def _rand_barycentric_coords(size1, size2, dtype: torch.dtype, - device: torch.device) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def _rand_barycentric_coords( + size1, size2, dtype: torch.dtype, device: torch.device +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Helper function to generate random barycentric coordinates which are uniformly distributed over a triangle. @@ -167,19 +169,21 @@ def sample_points_from_meshes(meshes, num_samples: int = 10000): faces = meshes.faces_packed() mesh_to_face = meshes.mesh_to_faces_packed_first_idx() num_meshes = len(meshes) - num_valid_meshes = torch.sum(meshes.valid) # Non empty meshes. + num_valid_meshes = torch.sum(meshes.valid) # Non empty meshes. # Initialize samples tensor with fill value 0 for empty meshes. samples = torch.zeros((num_meshes, num_samples, 3), device=meshes.device) # Only compute samples for non empty meshes with torch.no_grad(): - areas, _ = mesh_face_areas_normals(verts, faces) # Face areas can be zero. + areas, _ = mesh_face_areas_normals(verts, faces) # Face areas can be zero. max_faces = meshes.num_faces_per_mesh().max().item() - areas_padded = packed_to_padded(areas, mesh_to_face[meshes.valid], max_faces) # (N, F) + areas_padded = packed_to_padded(areas, mesh_to_face[meshes.valid], max_faces) # (N, F) # TODO (gkioxari) Confirm multinomial bug is not present with real data. 
- samples_face_idxs = areas_padded.multinomial(num_samples, replacement=True) # (N, num_samples) + samples_face_idxs = areas_padded.multinomial( + num_samples, replacement=True + ) # (N, num_samples) samples_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1) # Randomly generate barycentric coords. @@ -200,23 +204,25 @@ def point_mesh_distance(meshes, pcls, weighted=True): raise ValueError("meshes and pointclouds must be equal sized batches") # packed representation for pointclouds - points = pcls.points_packed() # (P, 3) + points = pcls.points_packed() # (P, 3) points_first_idx = pcls.cloud_to_packed_first_idx() max_points = pcls.num_points_per_cloud().max().item() # packed representation for faces verts_packed = meshes.verts_packed() faces_packed = meshes.faces_packed() - tris = verts_packed[faces_packed] # (T, 3, 3) + tris = verts_packed[faces_packed] # (T, 3, 3) tris_first_idx = meshes.mesh_to_faces_packed_first_idx() # point to face distance: shape (P,) - point_to_face, idxs = _PointFaceDistance.apply(points, points_first_idx, tris, tris_first_idx, max_points, 5e-3) + point_to_face, idxs = _PointFaceDistance.apply( + points, points_first_idx, tris, tris_first_idx, max_points, 5e-3 + ) if weighted: # weight each example by the inverse of number of points in the example - point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),) - num_points_per_cloud = pcls.num_points_per_cloud() # (N,) + point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),) + num_points_per_cloud = pcls.num_points_per_cloud() # (N,) weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx) weights_p = 1.0 / weights_p.float() point_to_face = torch.sqrt(point_to_face) * weights_p @@ -225,7 +231,6 @@ def point_mesh_distance(meshes, pcls, weighted=True): class Evaluator: - def __init__(self, device): self.render = Render(size=512, device=device) @@ -253,8 +258,8 @@ class Evaluator: self.render.meshes = self.tgt_mesh tgt_normal_imgs = self.render.get_image(cam_type="four", bg="black") - src_normal_arr = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4, padding=0) # [-1,1] - tgt_normal_arr = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4, padding=0) # [-1,1] + src_normal_arr = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4, padding=0) # [-1,1] + tgt_normal_arr = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4, padding=0) # [-1,1] src_norm = torch.norm(src_normal_arr, dim=0, keepdim=True) tgt_norm = torch.norm(tgt_normal_arr, dim=0, keepdim=True) @@ -274,8 +279,11 @@ class Evaluator: # error_hf = ((((src_normal_arr - tgt_normal_arr) * sim_mask)**2).sum(dim=0).mean()) * 4.0 normal_img = Image.fromarray( - (torch.cat([src_normal_arr, tgt_normal_arr], dim=1).permute(1, 2, 0).detach().cpu().numpy() * 255.0).astype( - np.uint8)) + ( + torch.cat([src_normal_arr, tgt_normal_arr], + dim=1).permute(1, 2, 0).detach().cpu().numpy() * 255.0 + ).astype(np.uint8) + ) normal_img.save(normal_path) return error @@ -291,7 +299,9 @@ class Evaluator: p2s_dist_all, _ = point_mesh_distance(self.src_mesh, tgt_points) * 100.0 p2s_dist = p2s_dist_all.sum() - chamfer_dist = (point_mesh_distance(self.tgt_mesh, src_points)[0].sum() * 100.0 + p2s_dist) * 0.5 + chamfer_dist = ( + point_mesh_distance(self.tgt_mesh, src_points)[0].sum() * 100.0 + p2s_dist + ) * 0.5 return chamfer_dist, p2s_dist diff --git a/lib/dataset/NormalDataset.py b/lib/dataset/NormalDataset.py index 1e532b3c820885a8ea96ee65439796ad23de9230..3567ac8cd5a83517a93c80c008bbb9b8d23616a7 100644 --- a/lib/dataset/NormalDataset.py +++ 
b/lib/dataset/NormalDataset.py @@ -23,7 +23,6 @@ import torchvision.transforms as transforms class NormalDataset: - def __init__(self, cfg, split="train"): self.split = split @@ -44,8 +43,7 @@ class NormalDataset: if self.split != "train": self.rotations = range(0, 360, 120) else: - self.rotations = np.arange(0, 360, 360 // - self.opt.rotation_num).astype(np.int) + self.rotations = np.arange(0, 360, 360 // self.opt.rotation_num).astype(np.int) self.datasets_dict = {} @@ -54,26 +52,29 @@ class NormalDataset: dataset_dir = osp.join(self.root, dataset) self.datasets_dict[dataset] = { - "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), - dtype=str), + "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), dtype=str), "scale": self.scales[dataset_id], } self.subject_list = self.get_subject_list(split) # PIL to tensor - self.image_to_tensor = transforms.Compose([ - transforms.Resize(self.input_size), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ]) + self.image_to_tensor = transforms.Compose( + [ + transforms.Resize(self.input_size), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) # PIL to tensor - self.mask_to_tensor = transforms.Compose([ - transforms.Resize(self.input_size), - transforms.ToTensor(), - transforms.Normalize((0.0, ), (1.0, )), - ]) + self.mask_to_tensor = transforms.Compose( + [ + transforms.Resize(self.input_size), + transforms.ToTensor(), + transforms.Normalize((0.0, ), (1.0, )), + ] + ) def get_subject_list(self, split): @@ -88,16 +89,12 @@ class NormalDataset: subject_list += np.loadtxt(split_txt, dtype=str).tolist() if self.split != "test": - subject_list += subject_list[:self.bsize - - len(subject_list) % self.bsize] + subject_list += subject_list[:self.bsize - len(subject_list) % self.bsize] print(colored(f"total: {len(subject_list)}", "yellow")) - bug_list = sorted( - np.loadtxt(osp.join(self.root, 'bug.txt'), dtype=str).tolist()) + bug_list = sorted(np.loadtxt(osp.join(self.root, 'bug.txt'), dtype=str).tolist()) - subject_list = [ - subject for subject in subject_list if (subject not in bug_list) - ] + subject_list = [subject for subject in subject_list if (subject not in bug_list)] # subject_list = ["thuman2/0008"] return subject_list @@ -113,48 +110,41 @@ class NormalDataset: rotation = self.rotations[rid] subject = self.subject_list[mid].split("/")[1] dataset = self.subject_list[mid].split("/")[0] - render_folder = "/".join( - [dataset + f"_{self.opt.rotation_num}views", subject]) + render_folder = "/".join([dataset + f"_{self.opt.rotation_num}views", subject]) if not osp.exists(osp.join(self.root, render_folder)): render_folder = "/".join([dataset + f"_36views", subject]) # setup paths data_dict = { - "dataset": - dataset, - "subject": - subject, - "rotation": - rotation, - "scale": - self.datasets_dict[dataset]["scale"], - "image_path": - osp.join(self.root, render_folder, "render", - f"{rotation:03d}.png"), + "dataset": dataset, + "subject": subject, + "rotation": rotation, + "scale": self.datasets_dict[dataset]["scale"], + "image_path": osp.join(self.root, render_folder, "render", f"{rotation:03d}.png"), } # image/normal/depth loader for name, channel in zip(self.in_total, self.in_total_dim): if f"{name}_path" not in data_dict.keys(): - data_dict.update({ - f"{name}_path": - osp.join(self.root, render_folder, name, - f"{rotation:03d}.png") - }) - - data_dict.update({ - name: - self.imagepath2tensor(data_dict[f"{name}_path"], - channel, - inv=False, - erasing=False) - }) - - 
path_keys = [ - key for key in data_dict.keys() if "_path" in key or "_dir" in key - ] + data_dict.update( + { + f"{name}_path": + osp.join(self.root, render_folder, name, f"{rotation:03d}.png") + } + ) + + data_dict.update( + { + name: + self.imagepath2tensor( + data_dict[f"{name}_path"], channel, inv=False, erasing=False + ) + } + ) + + path_keys = [key for key in data_dict.keys() if "_path" in key or "_dir" in key] for key in path_keys: del data_dict[key] @@ -172,10 +162,9 @@ class NormalDataset: # simulate occlusion if erasing: - mask = kornia.augmentation.RandomErasing(p=0.2, - scale=(0.01, 0.2), - ratio=(0.3, 3.3), - keepdim=True)(mask) + mask = kornia.augmentation.RandomErasing( + p=0.2, scale=(0.01, 0.2), ratio=(0.3, 3.3), keepdim=True + )(mask) image = (image * mask)[:channel] return (image * (0.5 - inv) * 2.0).float() diff --git a/lib/dataset/NormalModule.py b/lib/dataset/NormalModule.py index fbf0a4533f2e23db5ec25cd95511894b9f6296f9..ff672b3c42f5951f4ebf6c8446014d1d277ab02c 100644 --- a/lib/dataset/NormalModule.py +++ b/lib/dataset/NormalModule.py @@ -22,7 +22,6 @@ import pytorch_lightning as pl class NormalModule(pl.LightningDataModule): - def __init__(self, cfg): super(NormalModule, self).__init__() self.cfg = cfg @@ -40,7 +39,7 @@ class NormalModule(pl.LightningDataModule): self.train_dataset = NormalDataset(cfg=self.cfg, split="train") self.val_dataset = NormalDataset(cfg=self.cfg, split="val") self.test_dataset = NormalDataset(cfg=self.cfg, split="test") - + self.data_size = { "train": len(self.train_dataset), "val": len(self.val_dataset), @@ -69,7 +68,7 @@ class NormalModule(pl.LightningDataModule): ) return val_data_loader - + def val_dataloader(self): test_data_loader = DataLoader( diff --git a/lib/dataset/PointFeat.py b/lib/dataset/PointFeat.py index 4ade2b319aaa383c3d86cc319f013254d7ccebd9..457b949e5ce712a1eace33b1306fd48613ba8887 100644 --- a/lib/dataset/PointFeat.py +++ b/lib/dataset/PointFeat.py @@ -6,7 +6,6 @@ from lib.dataset.mesh_util import SMPLX, barycentric_coordinates_of_projection class PointFeat: - def __init__(self, verts, faces): # verts [B, N_vert, 3] @@ -23,7 +22,10 @@ class PointFeat: if verts.shape[1] == 10475: faces = faces[:, ~SMPLX().smplx_eyeball_fid_mask] - mouth_faces = (torch.as_tensor(SMPLX().smplx_mouth_fid).unsqueeze(0).repeat(self.Bsize, 1, 1).to(self.device)) + mouth_faces = ( + torch.as_tensor(SMPLX().smplx_mouth_fid).unsqueeze(0).repeat(self.Bsize, 1, + 1).to(self.device) + ) self.faces = torch.cat([faces, mouth_faces], dim=1).long() self.verts = verts.float() @@ -35,11 +37,15 @@ class PointFeat: points = points.float() residues, pts_ind = point_mesh_distance(self.mesh, Pointclouds(points), weighted=False) - closest_triangles = torch.gather(self.triangles, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3) + closest_triangles = torch.gather( + self.triangles, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3) + ).view(-1, 3, 3) bary_weights = barycentric_coordinates_of_projection(points.view(-1, 3), closest_triangles) feat_normals = face_vertices(self.mesh.verts_normals_padded(), self.faces) - closest_normals = torch.gather(feat_normals, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3) + closest_normals = torch.gather( + feat_normals, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3) + ).view(-1, 3, 3) shoot_verts = ((closest_triangles * bary_weights[:, :, None]).sum(1).unsqueeze(0)) pts2shoot_normals = points - shoot_verts diff --git a/lib/dataset/TestDataset.py b/lib/dataset/TestDataset.py 
index 49d6187bbfd131ad95cf63d5f59984e26acb1071..e016d3d94b3d9f043ef5a8526d2d1be67be6f4a7 100644 --- a/lib/dataset/TestDataset.py +++ b/lib/dataset/TestDataset.py @@ -25,6 +25,7 @@ from lib.pixielib.utils.config import cfg as pixie_cfg from lib.pixielib.pixie import PIXIE from lib.pixielib.models.SMPLX import SMPLX as PIXIE_SMPLX from lib.common.imutils import process_image +from lib.common.train_util import Format from lib.net.geometry import rotation_matrix_to_angle_axis, rot6d_to_rotmat from lib.pymafx.core import path_config @@ -36,8 +37,9 @@ from lib.dataset.body_model import TetraSMPLModel from lib.dataset.mesh_util import get_visibility, SMPLX import torch.nn.functional as F from torchvision import transforms +from torchvision.models import detection + import os.path as osp -import os import torch import glob import numpy as np @@ -48,7 +50,6 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True class TestDataset: - def __init__(self, cfg, device): self.image_dir = cfg["image_dir"] @@ -65,7 +66,9 @@ class TestDataset: keep_lst = sorted(glob.glob(f"{self.image_dir}/*")) img_fmts = ["jpg", "png", "jpeg", "JPG", "bmp"] - self.subject_list = sorted([item for item in keep_lst if item.split(".")[-1] in img_fmts], reverse=False) + self.subject_list = sorted( + [item for item in keep_lst if item.split(".")[-1] in img_fmts], reverse=False + ) # smpl related self.smpl_data = SMPLX() @@ -80,7 +83,16 @@ class TestDataset: self.smpl_model = PIXIE_SMPLX(pixie_cfg.model).to(self.device) - print(colored(f"Use {self.hps_type.upper()} to estimate human pose and shape", "green")) + self.detector = detection.maskrcnn_resnet50_fpn( + weights=detection.MaskRCNN_ResNet50_FPN_V2_Weights + ) + self.detector.eval() + + print( + colored( + f"SMPL-X estimate with {Format.start} {self.hps_type.upper()} {Format.end}", "green" + ) + ) self.render = Render(size=512, device=self.device) @@ -90,7 +102,9 @@ class TestDataset: def compute_vis_cmap(self, smpl_verts, smpl_faces): (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=-1) - smpl_vis = get_visibility(xy, z, torch.as_tensor(smpl_faces).long()[:, :, [0, 2, 1]]).unsqueeze(-1) + smpl_vis = get_visibility(xy, z, + torch.as_tensor(smpl_faces).long()[:, :, + [0, 2, 1]]).unsqueeze(-1) smpl_cmap = self.smpl_data.cmap_smpl_vids(self.smpl_type).unsqueeze(0) return { @@ -109,7 +123,8 @@ class TestDataset: depth_FB[:, ~depth_mask[0]] = 0. # Important: index_long = depth_value - 1 - index_z = (((depth_FB + 1.) * 0.5 * self.vol_res) - 1).clip(0, self.vol_res - 1).permute(1, 2, 0) + index_z = (((depth_FB + 1.) 
* 0.5 * self.vol_res) - 1).clip(0, self.vol_res - + 1).permute(1, 2, 0) index_z_ceil = torch.ceil(index_z).long() index_z_floor = torch.floor(index_z).long() index_z_frac = torch.frac(index_z) @@ -121,7 +136,7 @@ class TestDataset: F.one_hot(index_z_floor[..., 1], self.vol_res) * (1.0 - index_z_frac[..., 1]) voxels[index_mask] *= 0 - voxels = torch.flip(voxels, [2]).permute(2, 0, 1).float() #[x-2, y-0, z-1] + voxels = torch.flip(voxels, [2]).permute(2, 0, 1).float() #[x-2, y-0, z-1] return { "depth_voxels": voxels.flip([ @@ -139,18 +154,25 @@ class TestDataset: smpl_model.set_params(rotation_matrix_to_angle_axis(rot6d_to_rotmat(pose)), beta=betas[0]) verts = ( - np.concatenate([smpl_model.verts, smpl_model.verts_added], axis=0) * scale.item() + trans.detach().cpu().numpy()) + np.concatenate([smpl_model.verts, smpl_model.verts_added], axis=0) * scale.item() + + trans.detach().cpu().numpy() + ) faces = ( np.loadtxt( osp.join(self.smpl_data.tedra_dir, "tetrahedrons_neutral_adult.txt"), dtype=np.int32, - ) - 1) + ) - 1 + ) pad_v_num = int(8000 - verts.shape[0]) pad_f_num = int(25100 - faces.shape[0]) - verts = (np.pad(verts, ((0, pad_v_num), (0, 0)), mode="constant", constant_values=0.0).astype(np.float32) * 0.5) - faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode="constant", constant_values=0.0).astype(np.int32) + verts = ( + np.pad(verts, ((0, pad_v_num), + (0, 0)), mode="constant", constant_values=0.0).astype(np.float32) * 0.5 + ) + faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode="constant", + constant_values=0.0).astype(np.int32) verts[:, 2] *= -1.0 @@ -168,7 +190,7 @@ class TestDataset: img_path = self.subject_list[index] img_name = img_path.split("/")[-1].rsplit(".", 1)[0] - arr_dict = process_image(img_path, self.hps_type, self.single, 512) + arr_dict = process_image(img_path, self.hps_type, self.single, 512, self.detector) arr_dict.update({"name": img_name}) with torch.no_grad(): @@ -179,7 +201,10 @@ class TestDataset: preds_dict, _ = self.hps.forward(batch) arr_dict["smpl_faces"] = ( - torch.as_tensor(self.smpl_data.smplx_faces.astype(np.int64)).unsqueeze(0).long().to(self.device)) + torch.as_tensor(self.smpl_data.smplx_faces.astype(np.int64)).unsqueeze(0).long().to( + self.device + ) + ) arr_dict["type"] = self.smpl_type if self.hps_type == "pymafx": @@ -198,13 +223,16 @@ class TestDataset: elif self.hps_type == "pixie": arr_dict.update(preds_dict) arr_dict["global_orient"] = preds_dict["global_pose"] - arr_dict["betas"] = preds_dict["shape"] #200 + arr_dict["betas"] = preds_dict["shape"] #200 arr_dict["smpl_verts"] = preds_dict["vertices"] scale, tranX, tranY = preds_dict["cam"].split(1, dim=1) # 1.1435, 0.0128, 0.3520 arr_dict["scale"] = scale.unsqueeze(1) - arr_dict["trans"] = (torch.cat([tranX, tranY, torch.zeros_like(tranX)], dim=1).unsqueeze(1).to(self.device).float()) + arr_dict["trans"] = ( + torch.cat([tranX, tranY, torch.zeros_like(tranX)], + dim=1).unsqueeze(1).to(self.device).float() + ) # data_dict info (key-shape): # scale, tranX, tranY - tensor.float @@ -230,4 +258,4 @@ class TestDataset: # render optimized mesh (normal, T_normal, image [-1,1]) self.render.load_meshes(verts, faces) - return self.render.get_image(type="depth") \ No newline at end of file + return self.render.get_image(type="depth") diff --git a/lib/dataset/body_model.py b/lib/dataset/body_model.py index f41f481ae67e124cbb45c5c8f34179c5c6f49311..cebb105591cab29d833f2965ec609c85fd522881 100644 --- a/lib/dataset/body_model.py +++ b/lib/dataset/body_model.py @@ -21,7 +21,6 @@ import os class 
SMPLModel: - def __init__(self, model_path, age): """ SMPL model. @@ -49,20 +48,16 @@ class SMPLModel: if age == "kid": v_template_smil = np.load( - os.path.join(os.path.dirname(model_path), - "smpl/smpl_kid_template.npy")) + os.path.join(os.path.dirname(model_path), "smpl/smpl_kid_template.npy") + ) v_template_smil -= np.mean(v_template_smil, axis=0) - v_template_diff = np.expand_dims(v_template_smil - self.v_template, - axis=2) + v_template_diff = np.expand_dims(v_template_smil - self.v_template, axis=2) self.shapedirs = np.concatenate( - (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff), - axis=2) + (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff), axis=2 + ) self.beta_shape[0] += 1 - id_to_col = { - self.kintree_table[1, i]: i - for i in range(self.kintree_table.shape[1]) - } + id_to_col = {self.kintree_table[1, i]: i for i in range(self.kintree_table.shape[1])} self.parent = { i: id_to_col[self.kintree_table[0, i]] for i in range(1, self.kintree_table.shape[1]) @@ -121,33 +116,30 @@ class SMPLModel: pose_cube = self.pose.reshape((-1, 1, 3)) # rotation matrix for each joint self.R = self.rodrigues(pose_cube) - I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), - (self.R.shape[0] - 1, 3, 3)) + I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), (self.R.shape[0] - 1, 3, 3)) lrotmin = (self.R[1:] - I_cube).ravel() # how pose affect body shape in zero pose v_posed = v_shaped + self.posedirs.dot(lrotmin) # world transformation of each joint G = np.empty((self.kintree_table.shape[1], 4, 4)) - G[0] = self.with_zeros( - np.hstack((self.R[0], self.J[0, :].reshape([3, 1])))) + G[0] = self.with_zeros(np.hstack((self.R[0], self.J[0, :].reshape([3, 1])))) for i in range(1, self.kintree_table.shape[1]): G[i] = G[self.parent[i]].dot( self.with_zeros( - np.hstack([ - self.R[i], - ((self.J[i, :] - self.J[self.parent[i], :]).reshape( - [3, 1])), - ]))) + np.hstack( + [ + self.R[i], + ((self.J[i, :] - self.J[self.parent[i], :]).reshape([3, 1])), + ] + ) + ) + ) # remove the transformation due to the rest pose - G = G - self.pack( - np.matmul( - G, - np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1]))) + G = G - self.pack(np.matmul(G, np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1]))) # transformation of each vertex T = np.tensordot(self.weights, G, axes=[[1], [0]]) rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1]))) - v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, - 4])[:, :3] + v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3] self.verts = v + self.trans.reshape([1, 3]) self.G = G @@ -171,19 +163,20 @@ class SMPLModel: r_hat = r / theta cos = np.cos(theta) z_stick = np.zeros(theta.shape[0]) - m = np.dstack([ - z_stick, - -r_hat[:, 0, 2], - r_hat[:, 0, 1], - r_hat[:, 0, 2], - z_stick, - -r_hat[:, 0, 0], - -r_hat[:, 0, 1], - r_hat[:, 0, 0], - z_stick, - ]).reshape([-1, 3, 3]) - i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), - [theta.shape[0], 3, 3]) + m = np.dstack( + [ + z_stick, + -r_hat[:, 0, 2], + r_hat[:, 0, 1], + r_hat[:, 0, 2], + z_stick, + -r_hat[:, 0, 0], + -r_hat[:, 0, 1], + r_hat[:, 0, 0], + z_stick, + ] + ).reshape([-1, 3, 3]) + i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), [theta.shape[0], 3, 3]) A = np.transpose(r_hat, axes=[0, 2, 1]) B = r_hat dot = np.matmul(A, B) @@ -238,12 +231,7 @@ class SMPLModel: class TetraSMPLModel: - - def __init__(self, - model_path, - model_addition_path, - age="adult", - v_template=None): + def __init__(self, model_path, 
model_addition_path, age="adult", v_template=None): """ SMPL model. @@ -276,10 +264,7 @@ class TetraSMPLModel: self.posedirs_added = params_added["posedirs_added"] self.tetrahedrons = params_added["tetrahedrons"] - id_to_col = { - self.kintree_table[1, i]: i - for i in range(self.kintree_table.shape[1]) - } + id_to_col = {self.kintree_table[1, i]: i for i in range(self.kintree_table.shape[1])} self.parent = { i: id_to_col[self.kintree_table[0, i]] for i in range(1, self.kintree_table.shape[1]) @@ -291,14 +276,13 @@ class TetraSMPLModel: if age == "kid": v_template_smil = np.load( - os.path.join(os.path.dirname(model_path), - "smpl_kid_template.npy")) + os.path.join(os.path.dirname(model_path), "smpl_kid_template.npy") + ) v_template_smil -= np.mean(v_template_smil, axis=0) - v_template_diff = np.expand_dims(v_template_smil - self.v_template, - axis=2) + v_template_diff = np.expand_dims(v_template_smil - self.v_template, axis=2) self.shapedirs = np.concatenate( - (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff), - axis=2) + (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff), axis=2 + ) self.beta_shape[0] += 1 self.pose = np.zeros(self.pose_shape) @@ -356,50 +340,42 @@ class TetraSMPLModel: """ # how beta affect body shape v_shaped = self.shapedirs.dot(self.beta) + self.v_template - v_shaped_added = self.shapedirs_added.dot( - self.beta) + self.v_template_added + v_shaped_added = self.shapedirs_added.dot(self.beta) + self.v_template_added # joints location self.J = self.J_regressor.dot(v_shaped) pose_cube = self.pose.reshape((-1, 1, 3)) # rotation matrix for each joint self.R = self.rodrigues(pose_cube) - I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), - (self.R.shape[0] - 1, 3, 3)) + I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), (self.R.shape[0] - 1, 3, 3)) lrotmin = (self.R[1:] - I_cube).ravel() # how pose affect body shape in zero pose v_posed = v_shaped + self.posedirs.dot(lrotmin) v_posed_added = v_shaped_added + self.posedirs_added.dot(lrotmin) # world transformation of each joint G = np.empty((self.kintree_table.shape[1], 4, 4)) - G[0] = self.with_zeros( - np.hstack((self.R[0], self.J[0, :].reshape([3, 1])))) + G[0] = self.with_zeros(np.hstack((self.R[0], self.J[0, :].reshape([3, 1])))) for i in range(1, self.kintree_table.shape[1]): G[i] = G[self.parent[i]].dot( self.with_zeros( - np.hstack([ - self.R[i], - ((self.J[i, :] - self.J[self.parent[i], :]).reshape( - [3, 1])), - ]))) + np.hstack( + [ + self.R[i], + ((self.J[i, :] - self.J[self.parent[i], :]).reshape([3, 1])), + ] + ) + ) + ) # remove the transformation due to the rest pose - G = G - self.pack( - np.matmul( - G, - np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1]))) + G = G - self.pack(np.matmul(G, np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1]))) self.G = G # transformation of each vertex T = np.tensordot(self.weights, G, axes=[[1], [0]]) rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1]))) - v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, - 4])[:, :3] + v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3] self.verts = v + self.trans.reshape([1, 3]) T_added = np.tensordot(self.weights_added, G, axes=[[1], [0]]) - rest_shape_added_h = np.hstack( - (v_posed_added, np.ones([v_posed_added.shape[0], 1]))) - v_added = np.matmul(T_added, - rest_shape_added_h.reshape([-1, 4, - 1])).reshape([-1, 4 - ])[:, :3] + rest_shape_added_h = np.hstack((v_posed_added, np.ones([v_posed_added.shape[0], 1]))) + v_added = 
np.matmul(T_added, rest_shape_added_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3] self.verts_added = v_added + self.trans.reshape([1, 3]) def rodrigues(self, r): @@ -422,19 +398,20 @@ class TetraSMPLModel: r_hat = r / theta cos = np.cos(theta) z_stick = np.zeros(theta.shape[0]) - m = np.dstack([ - z_stick, - -r_hat[:, 0, 2], - r_hat[:, 0, 1], - r_hat[:, 0, 2], - z_stick, - -r_hat[:, 0, 0], - -r_hat[:, 0, 1], - r_hat[:, 0, 0], - z_stick, - ]).reshape([-1, 3, 3]) - i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), - [theta.shape[0], 3, 3]) + m = np.dstack( + [ + z_stick, + -r_hat[:, 0, 2], + r_hat[:, 0, 1], + r_hat[:, 0, 2], + z_stick, + -r_hat[:, 0, 0], + -r_hat[:, 0, 1], + r_hat[:, 0, 0], + z_stick, + ] + ).reshape([-1, 3, 3]) + i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), [theta.shape[0], 3, 3]) A = np.transpose(r_hat, axes=[0, 2, 1]) B = r_hat dot = np.matmul(A, B) diff --git a/lib/dataset/mesh_util.py b/lib/dataset/mesh_util.py index 52496a221fc995f891abdebe892acbf2f701aeca..b639763e5ea98fd79698878a47c3089b08f4b86e 100644 --- a/lib/dataset/mesh_util.py +++ b/lib/dataset/mesh_util.py @@ -14,32 +14,33 @@ # # Contact: ps-license@tuebingen.mpg.de +import os import numpy as np -import cv2 -import pymeshlab import torch import torchvision import trimesh -import os -from termcolor import colored +import open3d as o3d +import tinyobjloader import os.path as osp import _pickle as cPickle +from termcolor import colored from scipy.spatial import cKDTree from pytorch3d.structures import Meshes import torch.nn.functional as F import lib.smplx as smplx +from lib.common.render_utils import Pytorch3dRasterizer from pytorch3d.renderer.mesh import rasterize_meshes from PIL import Image, ImageFont, ImageDraw from pytorch3d.loss import mesh_laplacian_smoothing, mesh_normal_consistency -import tinyobjloader -from lib.common.imutils import uncrop -from lib.common.render_utils import Pytorch3dRasterizer +class Format: + end = '\033[0m' + start = '\033[4m' -class SMPLX: +class SMPLX: def __init__(self): self.current_dir = osp.join(osp.dirname(__file__), "../../data/smpl_related") @@ -54,10 +55,14 @@ class SMPLX: self.smplx_eyeball_fid_path = osp.join(self.current_dir, "smpl_data/eyeball_fid.npy") self.smplx_fill_mouth_fid_path = osp.join(self.current_dir, "smpl_data/fill_mouth_fid.npy") - self.smplx_flame_vid_path = osp.join(self.current_dir, "smpl_data/FLAME_SMPLX_vertex_ids.npy") + self.smplx_flame_vid_path = osp.join( + self.current_dir, "smpl_data/FLAME_SMPLX_vertex_ids.npy" + ) self.smplx_mano_vid_path = osp.join(self.current_dir, "smpl_data/MANO_SMPLX_vertex_ids.pkl") self.front_flame_path = osp.join(self.current_dir, "smpl_data/FLAME_face_mask_ids.npy") - self.smplx_vertex_lmkid_path = osp.join(self.current_dir, "smpl_data/smplx_vertex_lmkid.npy") + self.smplx_vertex_lmkid_path = osp.join( + self.current_dir, "smpl_data/smplx_vertex_lmkid.npy" + ) self.smplx_faces = np.load(self.smplx_faces_path) self.smplx_verts = np.load(self.smplx_verts_path) @@ -68,84 +73,51 @@ class SMPLX: self.smplx_eyeball_fid_mask = np.load(self.smplx_eyeball_fid_path) self.smplx_mouth_fid = np.load(self.smplx_fill_mouth_fid_path) self.smplx_mano_vid_dict = np.load(self.smplx_mano_vid_path, allow_pickle=True) - self.smplx_mano_vid = np.concatenate([self.smplx_mano_vid_dict["left_hand"], self.smplx_mano_vid_dict["right_hand"]]) + self.smplx_mano_vid = np.concatenate( + [self.smplx_mano_vid_dict["left_hand"], self.smplx_mano_vid_dict["right_hand"]] + ) self.smplx_flame_vid = 
np.load(self.smplx_flame_vid_path, allow_pickle=True) self.smplx_front_flame_vid = self.smplx_flame_vid[np.load(self.front_flame_path)] # hands - self.mano_vertex_mask = torch.zeros(self.smplx_verts.shape[0],).index_fill_(0, torch.tensor(self.smplx_mano_vid), 1.0) + self.mano_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_( + 0, torch.tensor(self.smplx_mano_vid), 1.0 + ) # face - self.front_flame_vertex_mask = torch.zeros(self.smplx_verts.shape[0],).index_fill_( - 0, torch.tensor(self.smplx_front_flame_vid), 1.0) - self.eyeball_vertex_mask = torch.zeros(self.smplx_verts.shape[0],).index_fill_( - 0, torch.tensor(self.smplx_faces[self.smplx_eyeball_fid_mask].flatten()), 1.0) + self.front_flame_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_( + 0, torch.tensor(self.smplx_front_flame_vid), 1.0 + ) + self.eyeball_vertex_mask = torch.zeros(self.smplx_verts.shape[0], ).index_fill_( + 0, torch.tensor(self.smplx_faces[self.smplx_eyeball_fid_mask].flatten()), 1.0 + ) self.smplx_to_smpl = cPickle.load(open(self.smplx_to_smplx_path, "rb")) self.model_dir = osp.join(self.current_dir, "models") self.tedra_dir = osp.join(self.current_dir, "../tedra_data") - self.ghum_smpl_pairs = torch.tensor([ - (0, 24), - (2, 26), - (5, 25), - (7, 28), - (8, 27), - (11, 16), - (12, 17), - (13, 18), - (14, 19), - (15, 20), - (16, 21), - (17, 39), - (18, 44), - (19, 36), - (20, 41), - (21, 35), - (22, 40), - (23, 1), - (24, 2), - (25, 4), - (26, 5), - (27, 7), - (28, 8), - (29, 31), - (30, 34), - (31, 29), - (32, 32), - ]).long() + self.ghum_smpl_pairs = torch.tensor( + [ + (0, 24), (2, 26), (5, 25), (7, 28), (8, 27), (11, 16), (12, 17), (13, 18), (14, 19), + (15, 20), (16, 21), (17, 39), (18, 44), (19, 36), (20, 41), (21, 35), (22, 40), + (23, 1), (24, 2), (25, 4), (26, 5), (27, 7), (28, 8), (29, 31), (30, 34), (31, 29), + (32, 32) + ] + ).long() # smpl-smplx correspondence self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73] self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [61 + 68, 72 + 68] - self.smpl_joint_ids_45 = (np.arange(22).tolist() + [68, 73] + np.arange(55, 76).tolist()) - - self.extra_joint_ids = ( - np.array([ - 61, - 72, - 66, - 69, - 58, - 68, - 57, - 56, - 64, - 59, - 67, - 75, - 70, - 65, - 60, - 61, - 63, - 62, - 76, - 71, - 72, - 74, - 73, - ]) + 68) + self.smpl_joint_ids_45 = np.arange(22).tolist() + [68, 73] + np.arange(55, 76).tolist() + + self.extra_joint_ids = np.array( + [ + 61, 72, 66, 69, 58, 68, 57, 56, 64, 59, 67, 75, 70, 65, 60, 61, 63, 62, 76, 71, 72, + 74, 73 + ] + ) + + self.extra_joint_ids += 68 self.smpl_joint_ids_45_pixie = (np.arange(22).tolist() + self.extra_joint_ids.tolist()) @@ -222,27 +194,6 @@ def load_fit_body(fitted_path, scale, smpl_type="smplx", smpl_gender="neutral", return smpl_mesh, smpl_joints -def create_grid_points_from_xyz_bounds(bound, res): - - min_x, max_x, min_y, max_y, min_z, max_z = bound - x = torch.linspace(min_x, max_x, res) - y = torch.linspace(min_y, max_y, res) - z = torch.linspace(min_z, max_z, res) - X, Y, Z = torch.meshgrid(x, y, z, indexing='ij') - - return torch.stack([X, Y, Z], dim=-1) - - -def create_grid_points_from_xy_bounds(bound, res): - - min_x, max_x, min_y, max_y = bound - x = torch.linspace(min_x, max_x, res) - y = torch.linspace(min_y, max_y, res) - X, Y = torch.meshgrid(x, y, indexing='ij') - - return torch.stack([X, Y], dim=-1) - - def apply_face_mask(mesh, face_mask): mesh.update_faces(face_mask) @@ -277,7 +228,8 @@ def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, 
region, clean=Tr part_extractor = PointFeat( torch.tensor(part_mesh.vertices).unsqueeze(0).to(device), - torch.tensor(part_mesh.faces).unsqueeze(0).to(device)) + torch.tensor(part_mesh.faces).unsqueeze(0).to(device) + ) (part_dist, _) = part_extractor.query(torch.tensor(full_mesh.vertices).unsqueeze(0).to(device)) @@ -286,12 +238,20 @@ def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, region, clean=Tr if region == "hand": _, idx = smpl_tree.query(full_mesh.vertices, k=1) full_lmkid = SMPL_container.smplx_vertex_lmkid[idx] - remove_mask = torch.logical_and(remove_mask, torch.tensor(full_lmkid >= 20).type_as(remove_mask).unsqueeze(0)) + remove_mask = torch.logical_and( + remove_mask, + torch.tensor(full_lmkid >= 20).type_as(remove_mask).unsqueeze(0) + ) elif region == "face": _, idx = smpl_tree.query(full_mesh.vertices, k=5) - face_space_mask = torch.isin(torch.tensor(idx), torch.tensor(SMPL_container.smplx_front_flame_vid)) - remove_mask = torch.logical_and(remove_mask, face_space_mask.any(dim=1).type_as(remove_mask).unsqueeze(0)) + face_space_mask = torch.isin( + torch.tensor(idx), torch.tensor(SMPL_container.smplx_front_flame_vid) + ) + remove_mask = torch.logical_and( + remove_mask, + face_space_mask.any(dim=1).type_as(remove_mask).unsqueeze(0) + ) BNI_part_mask = ~(remove_mask).flatten()[full_mesh.faces].any(dim=1) full_mesh.update_faces(BNI_part_mask.detach().cpu()) @@ -303,109 +263,6 @@ def part_removal(full_mesh, part_mesh, thres, device, smpl_obj, region, clean=Tr return full_mesh -def cross(triangles): - """ - Returns the cross product of two edges from input triangles - Parameters - -------------- - triangles: (n, 3, 3) float - Vertices of triangles - Returns - -------------- - crosses : (n, 3) float - Cross product of two edge vectors - """ - vectors = np.diff(triangles, axis=1) - crosses = np.cross(vectors[:, 0], vectors[:, 1]) - return crosses - - -def tri_area(triangles=None, crosses=None, sum=False): - """ - Calculates the sum area of input triangles - Parameters - ---------- - triangles : (n, 3, 3) float - Vertices of triangles - crosses : (n, 3) float or None - As a speedup don't re- compute cross products - sum : bool - Return summed area or individual triangle area - Returns - ---------- - area : (n,) float or float - Individual or summed area depending on `sum` argument - """ - if crosses is None: - crosses = cross(triangles) - area = (np.sum(crosses**2, axis=1)**.5) * .5 - if sum: - return np.sum(area) - return area - - -def sample_surface(triangles, count, area=None): - """ - Sample the surface of a mesh, returning the specified - number of points - For individual triangle sampling uses this method: - http://mathworld.wolfram.com/TrianglePointPicking.html - Parameters - --------- - triangles : (n, 3, 3) float - Vertices of triangles - count : int - Number of points to return - Returns - --------- - samples : (count, 3) float - Points in space on the surface of mesh - face_index : (count,) int - Indices of faces for each sampled point - """ - - # len(mesh.faces) float, array of the areas - # of each face of the mesh - if area is None: - area = tri_area(triangles) - - # total area (float) - area_sum = np.sum(area) - # cumulative area (len(mesh.faces)) - area_cum = np.cumsum(area) - face_pick = np.random.random(count) * area_sum - face_index = np.searchsorted(area_cum, face_pick) - - # pull triangles into the form of an origin + 2 vectors - tri_origins = triangles[:, 0] - tri_vectors = triangles[:, 1:].copy() - tri_vectors -= np.tile(tri_origins, (1, 
2)).reshape((-1, 2, 3)) - - # pull the vectors for the faces we are going to sample from - tri_origins = tri_origins[face_index] - tri_vectors = tri_vectors[face_index] - - # randomly generate two 0-1 scalar components to multiply edge vectors by - random_lengths = np.random.random((len(tri_vectors), 2, 1)) - - # points will be distributed on a quadrilateral if we use 2 0-1 samples - # if the two scalar components sum less than 1.0 the point will be - # inside the triangle, so we find vectors longer than 1.0 and - # transform them to be inside the triangle - random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0 - random_lengths[random_test] -= 1.0 - random_lengths = np.abs(random_lengths) - - # multiply triangle edge vectors by the random lengths and sum - sample_vector = (tri_vectors * random_lengths).sum(axis=1) - - # finally, offset by the origin to generate - # (n,3) points in space on the triangle - samples = torch.tensor(sample_vector + tri_origins).float() - - return samples, face_index - - def obj_loader(path, with_uv=True): # Create reader. reader = tinyobjloader.ObjReader() @@ -424,8 +281,8 @@ def obj_loader(path, with_uv=True): f_vt = tri[:, [2, 5, 8]] if with_uv: - face_uvs = vt[f_vt].mean(axis=1) #[m, 2] - vert_uvs = np.zeros((v.shape[0], 2), dtype=np.float32) #[n, 2] + face_uvs = vt[f_vt].mean(axis=1) #[m, 2] + vert_uvs = np.zeros((v.shape[0], 2), dtype=np.float32) #[n, 2] vert_uvs[f_v.reshape(-1)] = vt[f_vt.reshape(-1)] return v, f_v, vert_uvs, face_uvs @@ -434,7 +291,6 @@ def obj_loader(path, with_uv=True): class HoppeMesh: - def __init__(self, verts, faces, uvs=None, texture=None): """ The HoppeSDF calculates signed distance towards a predefined oriented point cloud @@ -459,34 +315,20 @@ class HoppeMesh: - points: [n, 3] - return: [n, 4] rgba """ - triangles = self.verts[faces] #[n, 3, 3] - barycentric = trimesh.triangles.points_to_barycentric(triangles, points) #[n, 3] - vert_colors = self.vertex_colors[faces] #[n, 3, 4] + triangles = self.verts[faces] #[n, 3, 3] + barycentric = trimesh.triangles.points_to_barycentric(triangles, points) #[n, 3] + vert_colors = self.vertex_colors[faces] #[n, 3, 4] point_colors = torch.tensor((barycentric[:, :, None] * vert_colors).sum(axis=1)).float() return point_colors def triangles(self): - return self.verts[self.faces].numpy() #[n, 3, 3] + return self.verts[self.faces].numpy() #[n, 3, 3] def tensor2variable(tensor, device): return tensor.requires_grad_(True).to(device) -class GMoF(torch.nn.Module): - - def __init__(self, rho=1): - super(GMoF, self).__init__() - self.rho = rho - - def extra_repr(self): - return "rho = {}".format(self.rho) - - def forward(self, residual): - dist = torch.div(residual, residual + self.rho**2) - return self.rho**2 * dist - - def mesh_edge_loss(meshes, target_length: float = 0.0): """ Computes mesh edge length regularization loss averaged across all meshes @@ -508,10 +350,10 @@ def mesh_edge_loss(meshes, target_length: float = 0.0): return torch.tensor([0.0], dtype=torch.float32, device=meshes.device, requires_grad=True) N = len(meshes) - edges_packed = meshes.edges_packed() # (sum(E_n), 3) - verts_packed = meshes.verts_packed() # (sum(V_n), 3) - edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), ) - num_edges_per_mesh = meshes.num_edges_per_mesh() # N + edges_packed = meshes.edges_packed() # (sum(E_n), 3) + verts_packed = meshes.verts_packed() # (sum(V_n), 3) + edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), ) + num_edges_per_mesh = meshes.num_edges_per_mesh() # N # 
Determine the weight for each edge based on the number of edges in the # mesh it corresponds to. @@ -531,99 +373,37 @@ def mesh_edge_loss(meshes, target_length: float = 0.0): return loss_all -def remesh(obj, obj_path): - - obj.export(obj_path) - ms = pymeshlab.MeshSet() - ms.load_new_mesh(obj_path) - # ms.meshing_decimation_quadric_edge_collapse(targetfacenum=100000) - ms.meshing_isotropic_explicit_remeshing(targetlen=pymeshlab.Percentage(0.5), adaptive=True) - ms.apply_coord_laplacian_smoothing() - ms.save_current_mesh(obj_path[:-4] + "_remesh.obj") - polished_mesh = trimesh.load_mesh(obj_path[:-4] + "_remesh.obj") +def remesh_laplacian(mesh, obj_path): - return polished_mesh - - -def poisson_remesh(obj_path): - - ms = pymeshlab.MeshSet() - ms.load_new_mesh(obj_path) - ms.meshing_decimation_quadric_edge_collapse(targetfacenum=50000) - # ms.apply_coord_laplacian_smoothing() - ms.save_current_mesh(obj_path) - # ms.save_current_mesh(obj_path.replace(".obj", ".ply")) - polished_mesh = trimesh.load_mesh(obj_path) + mesh = mesh.simplify_quadratic_decimation(50000) + mesh = trimesh.smoothing.filter_humphrey( + mesh, alpha=0.1, beta=0.5, iterations=10, laplacian_operator=None + ) + mesh.export(obj_path) - return polished_mesh + return mesh def poisson(mesh, obj_path, depth=10): - from pypoisson import poisson_reconstruction - faces, vertices = poisson_reconstruction(mesh.vertices, mesh.vertex_normals, depth=depth) - - new_meshes = trimesh.Trimesh(vertices, faces) - new_mesh_lst = new_meshes.split(only_watertight=False) - comp_num = [new_mesh.vertices.shape[0] for new_mesh in new_mesh_lst] - final_mesh = new_mesh_lst[comp_num.index(max(comp_num))] - final_mesh.export(obj_path) + pcd_path = obj_path[:-4] + ".ply" + assert (mesh.vertex_normals.shape[1] == 3) + mesh.export(pcd_path) + pcl = o3d.io.read_point_cloud(pcd_path) + with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Error) as cm: + mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson( + pcl, depth=depth, n_threads=-1 + ) + print(colored(f"\n Poisson completion to {Format.start} {obj_path} {Format.end}", "yellow")) - final_mesh = poisson_remesh(obj_path) + # only keep the largest component + largest_mesh = keep_largest(trimesh.Trimesh(np.array(mesh.vertices), np.array(mesh.triangles))) + largest_mesh.export(obj_path) - return final_mesh + # mesh decimation for faster rendering + low_res_mesh = largest_mesh.simplify_quadratic_decimation(50000) - -def get_mask(tensor, dim): - - mask = torch.abs(tensor).sum(dim=dim, keepdims=True) > 0.0 - mask = mask.type_as(tensor) - - return mask - - -def blend_rgb_norm(norms, data): - - # norms [N, 3, res, res] - - masks = (norms.sum(dim=1) != norms[0, :, 0, 0].sum()).float().unsqueeze(1) - norm_mask = F.interpolate( - torch.cat([norms, masks], dim=1).detach().cpu(), - size=data["uncrop_param"]["box_shape"], - mode="bilinear", - align_corners=False).permute(0, 2, 3, 1).numpy() - final = data["img_raw"] - - for idx in range(len(norms)): - - norm_pred = (norm_mask[idx, :, :, :3] + 1.0) * 255.0 / 2.0 - mask_pred = np.repeat(norm_mask[idx, :, :, 3:4], 3, axis=-1) - - norm_ori = unwrap(norm_pred, data["uncrop_param"], idx) - mask_ori = unwrap(mask_pred, data["uncrop_param"], idx) - - final = final * (1.0 - mask_ori) + norm_ori * mask_ori - - return final.astype(np.uint8) - - -def unwrap(image, uncrop_param, idx): - - img_uncrop = uncrop( - image, - uncrop_param["center"][idx], - uncrop_param["scale"][idx], - uncrop_param["crop_shape"], - ) - - img_orig = cv2.warpAffine( - 
img_uncrop, - np.linalg.inv(uncrop_param["M"])[:2, :], - uncrop_param["ori_shape"][::-1], - flags=cv2.INTER_CUBIC, - ) - - return img_orig + return low_res_mesh # Losses to smooth / regularize the mesh shape @@ -634,60 +414,7 @@ def update_mesh_shape_prior_losses(mesh, losses): # mesh normal consistency losses["nc"]["value"] = mesh_normal_consistency(mesh) # mesh laplacian smoothing - losses["laplacian"]["value"] = mesh_laplacian_smoothing(mesh, method="uniform") - - -def rename(old_dict, old_name, new_name): - new_dict = {} - for key, value in zip(old_dict.keys(), old_dict.values()): - new_key = key if key != old_name else new_name - new_dict[new_key] = old_dict[key] - return new_dict - - -def load_checkpoint(model, cfg): - - model_dict = model.state_dict() - main_dict = {} - normal_dict = {} - - device = torch.device(f"cuda:{cfg['test_gpus'][0]}") - - if os.path.exists(cfg.resume_path) and cfg.resume_path.endswith("ckpt"): - main_dict = torch.load(cfg.resume_path, map_location=device)["state_dict"] - - main_dict = { - k: v for k, v in main_dict.items() if k in model_dict and v.shape == model_dict[k].shape and - ("reconEngine" not in k) and ("normal_filter" not in k) and ("voxelization" not in k) - } - print(colored(f"Resume MLP weights from {cfg.resume_path}", "green")) - - if os.path.exists(cfg.normal_path) and cfg.normal_path.endswith("ckpt"): - normal_dict = torch.load(cfg.normal_path, map_location=device)["state_dict"] - - for key in normal_dict.keys(): - normal_dict = rename(normal_dict, key, key.replace("netG", "netG.normal_filter")) - - normal_dict = {k: v for k, v in normal_dict.items() if k in model_dict and v.shape == model_dict[k].shape} - print(colored(f"Resume normal model from {cfg.normal_path}", "green")) - - model_dict.update(main_dict) - model_dict.update(normal_dict) - model.load_state_dict(model_dict) - - model.netG = model.netG.to(device) - model.reconEngine = model.reconEngine.to(device) - - model.netG.training = False - model.netG.eval() - - del main_dict - del normal_dict - del model_dict - - torch.cuda.empty_cache() - - return model + losses["lapla"]["value"] = mesh_laplacian_smoothing(mesh, method="uniform") def read_smpl_constants(folder): @@ -706,8 +433,10 @@ def read_smpl_constants(folder): smpl_vertex_code = np.float32(np.copy(smpl_vtx_std)) """Load smpl faces & tetrahedrons""" smpl_faces = np.loadtxt(os.path.join(folder, "faces.txt"), dtype=np.int32) - 1 - smpl_face_code = (smpl_vertex_code[smpl_faces[:, 0]] + smpl_vertex_code[smpl_faces[:, 1]] + - smpl_vertex_code[smpl_faces[:, 2]]) / 3.0 + smpl_face_code = ( + smpl_vertex_code[smpl_faces[:, 0]] + smpl_vertex_code[smpl_faces[:, 1]] + + smpl_vertex_code[smpl_faces[:, 2]] + ) / 3.0 smpl_tetras = (np.loadtxt(os.path.join(folder, "tetrahedrons.txt"), dtype=np.int32) - 1) return_dict = { @@ -720,19 +449,6 @@ def read_smpl_constants(folder): return return_dict -def feat_select(feat, select): - - # feat [B, featx2, N] - # select [B, 1, N] - # return [B, feat, N] - - dim = feat.shape[1] // 2 - idx = torch.tile((1 - select), (1, dim, 1)) * dim + torch.arange(0, dim).unsqueeze(0).unsqueeze(2).type_as(select) - feat_select = torch.gather(feat, 1, idx.long()) - - return feat_select - - def get_visibility(xy, z, faces, img_res=2**12, blur_radius=0.0, faces_per_pixel=1): """get the visibility of vertices @@ -771,7 +487,9 @@ def get_visibility(xy, z, faces, img_res=2**12, blur_radius=0.0, faces_per_pixel for idx in range(N_body): Num_faces = len(faces[idx]) - vis_vertices_id = 
torch.unique(faces[idx][torch.unique(pix_to_face[idx][pix_to_face[idx] != -1]) - Num_faces * idx, :]) + vis_vertices_id = torch.unique( + faces[idx][torch.unique(pix_to_face[idx][pix_to_face[idx] != -1]) - Num_faces * idx, :] + ) vis_mask[idx, vis_vertices_id] = 1.0 # print("------------------------\n") @@ -825,7 +543,7 @@ def orthogonal(points, calibrations, transforms=None): """ rot = calibrations[:, :3, :3] trans = calibrations[:, :3, 3:4] - pts = torch.baddbmm(trans, rot, points) # [B, 3, N] + pts = torch.baddbmm(trans, rot, points) # [B, 3, N] if transforms is not None: scale = transforms[:2, :2] shift = transforms[:2, 2:3] @@ -925,37 +643,14 @@ def compute_normal_batch(vertices, faces): return vert_norm -def calculate_mIoU(outputs, labels): - - SMOOTH = 1e-6 - - outputs = outputs.int() - labels = labels.int() - - intersection = ((outputs & labels).float().sum()) # Will be zero if Truth=0 or Prediction=0 - union = (outputs | labels).float().sum() # Will be zzero if both are 0 - - iou = (intersection + SMOOTH) / (union + SMOOTH) # We smooth our devision to avoid 0/0 - - thresholded = (torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10) # This is equal to comparing with thresolds - - return (thresholded.mean().detach().cpu().numpy() - ) # Or thresholded.mean() if you are interested in average across the batch - - -def add_alpha(colors, alpha=0.7): - - colors_pad = np.pad(colors, ((0, 0), (0, 1)), mode="constant", constant_values=alpha) - - return colors_pad - - def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type="smpl"): font_path = os.path.join(os.path.dirname(__file__), "tbfo.ttf") font = ImageFont.truetype(font_path, 30) grid_img = torchvision.utils.make_grid(torch.cat(per_loop_lst, dim=0), nrow=nrow, padding=0) - grid_img = Image.fromarray(((grid_img.permute(1, 2, 0).detach().cpu().numpy() + 1.0) * 0.5 * 255.0).astype(np.uint8)) + grid_img = Image.fromarray( + ((grid_img.permute(1, 2, 0).detach().cpu().numpy() + 1.0) * 0.5 * 255.0).astype(np.uint8) + ) if False: # add text @@ -965,16 +660,20 @@ def get_optim_grid_image(per_loop_lst, loss=None, nrow=4, type="smpl"): draw.text((10, 5), f"error: {loss:.3f}", (255, 0, 0), font=font) if type == "smpl": - for col_id, col_txt in enumerate([ + for col_id, col_txt in enumerate( + [ "image", "smpl-norm(render)", "cloth-norm(pred)", "diff-norm", "diff-mask", - ]): + ] + ): draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font) elif type == "cloth": - for col_id, col_txt in enumerate(["image", "cloth-norm(recon)", "cloth-norm(pred)", "diff-norm"]): + for col_id, col_txt in enumerate( + ["image", "cloth-norm(recon)", "cloth-norm(pred)", "diff-norm"] + ): draw.text((10 + (col_id * grid_size), 5), col_txt, (255, 0, 0), font=font) for col_id, col_txt in enumerate(["0", "90", "180", "270"]): draw.text( @@ -996,12 +695,9 @@ def clean_mesh(verts, faces): device = verts.device mesh_lst = trimesh.Trimesh(verts.detach().cpu().numpy(), faces.detach().cpu().numpy()) - mesh_lst = mesh_lst.split(only_watertight=False) - comp_num = [mesh.vertices.shape[0] for mesh in mesh_lst] - - mesh_clean = mesh_lst[comp_num.index(max(comp_num))] - final_verts = torch.as_tensor(mesh_clean.vertices).float().to(device) - final_faces = torch.as_tensor(mesh_clean.faces).long().to(device) + largest_mesh = keep_largest(mesh_lst) + final_verts = torch.as_tensor(largest_mesh.vertices).float().to(device) + final_faces = torch.as_tensor(largest_mesh.faces).long().to(device) return final_verts, final_faces diff --git a/lib/net/BasePIFuNet.py 
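# --- Editorial note (sketch, not repo code) ----------------------------------------------
# The rewritten poisson() in lib/dataset/mesh_util.py above now uses Open3D's screened
# Poisson reconstruction followed by a keep-largest-component step, replacing the removed
# pypoisson path. A minimal in-memory equivalent (the diff itself round-trips through a
# .ply file and calls the repo's keep_largest() helper, which is not shown in this hunk):
import numpy as np
import open3d as o3d
import trimesh

def poisson_sketch(mesh: trimesh.Trimesh, depth: int = 10) -> trimesh.Trimesh:
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(mesh.vertices)
    pcd.normals = o3d.utility.Vector3dVector(mesh.vertex_normals)
    recon, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=depth)
    recon = trimesh.Trimesh(np.asarray(recon.vertices), np.asarray(recon.triangles))
    # keep only the largest connected component, mirroring keep_largest()
    return max(recon.split(only_watertight=False), key=lambda m: len(m.vertices))
# ------------------------------------------------------------------------------------------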
b/lib/net/BasePIFuNet.py index 6793a1d771fe6be62d38d9a9ea621002195b8ab9..eb18dbb3245d57c9e030c18094322a58e874db93 100644 --- a/lib/net/BasePIFuNet.py +++ b/lib/net/BasePIFuNet.py @@ -21,11 +21,10 @@ from .geometry import index, orthogonal, perspective class BasePIFuNet(pl.LightningModule): - def __init__( - self, - projection_mode="orthogonal", - error_term=nn.MSELoss(), + self, + projection_mode="orthogonal", + error_term=nn.MSELoss(), ): """ :param projection_mode: diff --git a/lib/net/Discriminator.py b/lib/net/Discriminator.py index 83dc1ac393e2fb9130b3f8904f84e76f83329f98..c60acdde000d414c78af0705ba268af3117c6ec9 100644 --- a/lib/net/Discriminator.py +++ b/lib/net/Discriminator.py @@ -9,17 +9,18 @@ from lib.torch_utils.ops.native_ops import FusedLeakyReLU, fused_leaky_relu, upf class DiscriminatorHead(nn.Module): - def __init__(self, in_channel, disc_stddev=False): super().__init__() self.disc_stddev = disc_stddev stddev_dim = 1 if disc_stddev else 0 - self.conv_stddev = ConvLayer2d(in_channel=in_channel + stddev_dim, - out_channel=in_channel, - kernel_size=3, - activate=True) + self.conv_stddev = ConvLayer2d( + in_channel=in_channel + stddev_dim, + out_channel=in_channel, + kernel_size=3, + activate=True + ) self.final_linear = nn.Sequential( nn.Flatten(), @@ -32,8 +33,8 @@ class DiscriminatorHead(nn.Module): inv_perm = torch.argsort(perm) batch, channel, height, width = x.shape - x = x[ - perm] # shuffle inputs so that all views in a single trajectory don't get put together + x = x[perm + ] # shuffle inputs so that all views in a single trajectory don't get put together group = min(batch, stddev_group) stddev = x.view(group, -1, stddev_feat, channel // stddev_feat, height, width) @@ -41,7 +42,7 @@ class DiscriminatorHead(nn.Module): stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) stddev = stddev.repeat(group, 1, height, width) - stddev = stddev[inv_perm] # reorder inputs + stddev = stddev[inv_perm] # reorder inputs x = x[inv_perm] out = torch.cat([x, stddev], 1) @@ -56,7 +57,6 @@ class DiscriminatorHead(nn.Module): class ConvDecoder(nn.Module): - def __init__(self, in_channel, out_channel, in_res, out_res): super().__init__() @@ -68,20 +68,22 @@ class ConvDecoder(nn.Module): for i in range(log_size_in, log_size_out): out_ch = in_ch // 2 self.layers.append( - ConvLayer2d(in_channel=in_ch, - out_channel=out_ch, - kernel_size=3, - upsample=True, - bias=True, - activate=True)) + ConvLayer2d( + in_channel=in_ch, + out_channel=out_ch, + kernel_size=3, + upsample=True, + bias=True, + activate=True + ) + ) in_ch = out_ch self.layers.append( - ConvLayer2d(in_channel=in_ch, - out_channel=out_channel, - kernel_size=3, - bias=True, - activate=False)) + ConvLayer2d( + in_channel=in_ch, out_channel=out_channel, kernel_size=3, bias=True, activate=False + ) + ) self.layers = nn.Sequential(*self.layers) def forward(self, x): @@ -89,7 +91,6 @@ class ConvDecoder(nn.Module): class StyleDiscriminator(nn.Module): - def __init__(self, in_channel, in_res, ch_mul=64, ch_max=512, **kwargs): super().__init__() @@ -104,7 +105,8 @@ class StyleDiscriminator(nn.Module): for i in range(log_size_in, log_size_out, -1): out_channels = int(min(in_channels * 2, ch_max)) self.layers.append( - ConvResBlock2d(in_channel=in_channels, out_channel=out_channels, downsample=True)) + ConvResBlock2d(in_channel=in_channels, out_channel=out_channels, downsample=True) + ) in_channels = out_channels self.layers = nn.Sequential(*self.layers) @@ -147,7 +149,6 @@ class Blur(nn.Module): Upsample factor. 
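# --- Editorial note (assumption; the kernel helper itself is not part of this hunk) -------
# Blur / Upsample / Downsample below accept a 1-D kernel such as [1, 3, 3, 1]. The usual
# StyleGAN2-style preparation, assumed here, is an outer product plus normalisation before
# the FIR filtering op imported at the top of the file:
import torch

def make_kernel_sketch(k):
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]    # 1-D -> separable 2-D kernel
    return k / k.sum()                 # e.g. [1, 3, 3, 1] -> 4x4 kernel summing to 1
# ------------------------------------------------------------------------------------------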
""" - def __init__(self, kernel, pad, upsample_factor=1): super().__init__() @@ -177,7 +178,6 @@ class Upsample(nn.Module): Upsampling factor. """ - def __init__(self, kernel=[1, 3, 3, 1], factor=2): super().__init__() @@ -208,7 +208,6 @@ class Downsample(nn.Module): Downsampling factor. """ - def __init__(self, kernel=[1, 3, 3, 1], factor=2): super().__init__() @@ -250,7 +249,6 @@ class EqualLinear(nn.Module): Apply leakyReLU activation. """ - def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False): super().__init__() @@ -300,7 +298,6 @@ class EqualConv2d(nn.Module): Use bias term. """ - def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): super().__init__() @@ -316,16 +313,20 @@ class EqualConv2d(nn.Module): self.bias = None def forward(self, input): - out = F.conv2d(input, - self.weight * self.scale, - bias=self.bias, - stride=self.stride, - padding=self.padding) + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding + ) return out def __repr__(self): - return (f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," - f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})") + return ( + f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," + f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" + ) class EqualConvTranspose2d(nn.Module): @@ -353,15 +354,16 @@ class EqualConvTranspose2d(nn.Module): Use bias term. """ - - def __init__(self, - in_channel, - out_channel, - kernel_size, - stride=1, - padding=0, - output_padding=0, - bias=True): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + stride=1, + padding=0, + output_padding=0, + bias=True + ): super().__init__() self.weight = nn.Parameter(torch.randn(in_channel, out_channel, kernel_size, kernel_size)) @@ -388,12 +390,13 @@ class EqualConvTranspose2d(nn.Module): return out def __repr__(self): - return (f'{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]},' - f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})') + return ( + f'{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) class ConvLayer2d(nn.Sequential): - def __init__( self, in_channel, @@ -415,12 +418,15 @@ class ConvLayer2d(nn.Sequential): pad1 = p // 2 + 1 layers.append( - EqualConvTranspose2d(in_channel, - out_channel, - kernel_size, - padding=0, - stride=2, - bias=bias and not activate)) + EqualConvTranspose2d( + in_channel, + out_channel, + kernel_size, + padding=0, + stride=2, + bias=bias and not activate + ) + ) layers.append(Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)) if downsample: @@ -431,23 +437,29 @@ class ConvLayer2d(nn.Sequential): layers.append(Blur(blur_kernel, pad=(pad0, pad1))) layers.append( - EqualConv2d(in_channel, - out_channel, - kernel_size, - padding=0, - stride=2, - bias=bias and not activate)) + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=0, + stride=2, + bias=bias and not activate + ) + ) if (not downsample) and (not upsample): padding = kernel_size // 2 layers.append( - EqualConv2d(in_channel, - out_channel, - kernel_size, - padding=padding, - stride=1, - bias=bias and not activate)) + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=padding, + stride=1, + bias=bias and not activate + ) + ) if activate: 
layers.append(FusedLeakyReLU(out_channel, bias=bias)) @@ -472,7 +484,6 @@ class ConvResBlock2d(nn.Module): Apply downsampling via strided convolution in the second conv. """ - def __init__(self, in_channel, out_channel, upsample=False, downsample=False): super().__init__() diff --git a/lib/net/FBNet.py b/lib/net/FBNet.py index 0122d2fb47aa7316075d29f10cd8fe8012e7862a..f4797667d4d800019967d7ee2ed944ec8b8550fc 100644 --- a/lib/net/FBNet.py +++ b/lib/net/FBNet.py @@ -51,17 +51,17 @@ def get_norm_layer(norm_type="instance"): def define_G( - input_nc, - output_nc, - ngf, - netG, - n_downsample_global=3, - n_blocks_global=9, - n_local_enhancers=1, - n_blocks_local=3, - norm="instance", - gpu_ids=[], - last_op=nn.Tanh(), + input_nc, + output_nc, + ngf, + netG, + n_downsample_global=3, + n_blocks_global=9, + n_local_enhancers=1, + n_blocks_local=3, + norm="instance", + gpu_ids=[], + last_op=nn.Tanh(), ): norm_layer = get_norm_layer(norm_type=norm) if netG == "global": @@ -97,17 +97,20 @@ def define_G( return netG -def define_D(input_nc, - ndf, - n_layers_D, - norm='instance', - use_sigmoid=False, - num_D=1, - getIntermFeat=False, - gpu_ids=[]): +def define_D( + input_nc, + ndf, + n_layers_D, + norm='instance', + use_sigmoid=False, + num_D=1, + getIntermFeat=False, + gpu_ids=[] +): norm_layer = get_norm_layer(norm_type=norm) - netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, - getIntermFeat) + netD = MultiscaleDiscriminator( + input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat + ) if len(gpu_ids) > 0: assert (torch.cuda.is_available()) netD.cuda(gpu_ids[0]) @@ -129,7 +132,6 @@ def print_network(net): # Generator ############################################################################## class LocalEnhancer(pl.LightningModule): - def __init__( self, input_nc, @@ -155,8 +157,9 @@ class LocalEnhancer(pl.LightningModule): n_blocks_global, norm_layer, ).model - model_global = [model_global[i] for i in range(len(model_global) - 3) - ] # get rid of final convolution layers + model_global = [ + model_global[i] for i in range(len(model_global) - 3) + ] # get rid of final convolution layers self.model = nn.Sequential(*model_global) ###### local enhancer layers ##### @@ -224,17 +227,16 @@ class LocalEnhancer(pl.LightningModule): class GlobalGenerator(pl.LightningModule): - def __init__( - self, - input_nc, - output_nc, - ngf=64, - n_downsampling=3, - n_blocks=9, - norm_layer=nn.BatchNorm2d, - padding_type="reflect", - last_op=nn.Tanh(), + self, + input_nc, + output_nc, + ngf=64, + n_downsampling=3, + n_blocks=9, + norm_layer=nn.BatchNorm2d, + padding_type="reflect", + last_op=nn.Tanh(), ): assert n_blocks >= 0 super(GlobalGenerator, self).__init__() @@ -296,42 +298,49 @@ class GlobalGenerator(pl.LightningModule): # Defines the PatchGAN discriminator with the specified arguments. 
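# --- Editorial note (sketch, not repo code) -----------------------------------------------
# The NLayerDiscriminator defined next stacks 4x4 convolutions whose width doubles per
# layer and is capped at 512, ending in a 1-channel patch map. This helper only tabulates
# that (in, out, stride) progression as a quick sanity check:
def nlayer_disc_channels(input_nc: int, ndf: int = 64, n_layers: int = 3):
    chs = [(input_nc, ndf, 2)]
    nf = ndf
    for _ in range(1, n_layers):
        nf_prev, nf = nf, min(nf * 2, 512)
        chs.append((nf_prev, nf, 2))
    nf_prev, nf = nf, min(nf * 2, 512)
    chs.append((nf_prev, nf, 1))
    chs.append((nf, 1, 1))             # final patch-level logits
    return chs
# nlayer_disc_channels(3) == [(3, 64, 2), (64, 128, 2), (128, 256, 2), (256, 512, 1), (512, 1, 1)]
# ------------------------------------------------------------------------------------------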
class NLayerDiscriminator(nn.Module): - - def __init__(self, - input_nc, - ndf=64, - n_layers=3, - norm_layer=nn.BatchNorm2d, - use_sigmoid=False, - getIntermFeat=False): + def __init__( + self, + input_nc, + ndf=64, + n_layers=3, + norm_layer=nn.BatchNorm2d, + use_sigmoid=False, + getIntermFeat=False + ): super(NLayerDiscriminator, self).__init__() self.getIntermFeat = getIntermFeat self.n_layers = n_layers kw = 4 padw = int(np.ceil((kw - 1.0) / 2)) - sequence = [[ - nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), - nn.LeakyReLU(0.2, True) - ]] + sequence = [ + [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + ] nf = ndf for n in range(1, n_layers): nf_prev = nf nf = min(nf * 2, 512) - sequence += [[ - nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), - norm_layer(nf), - nn.LeakyReLU(0.2, True) - ]] + sequence += [ + [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + ] nf_prev = nf nf = min(nf * 2, 512) - sequence += [[ - nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), - norm_layer(nf), - nn.LeakyReLU(0.2, True) - ]] + sequence += [ + [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + ] sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] @@ -359,27 +368,30 @@ class NLayerDiscriminator(nn.Module): class MultiscaleDiscriminator(pl.LightningModule): - - def __init__(self, - input_nc, - ndf=64, - n_layers=3, - norm_layer=nn.BatchNorm2d, - use_sigmoid=False, - num_D=3, - getIntermFeat=False): + def __init__( + self, + input_nc, + ndf=64, + n_layers=3, + norm_layer=nn.BatchNorm2d, + use_sigmoid=False, + num_D=3, + getIntermFeat=False + ): super(MultiscaleDiscriminator, self).__init__() self.num_D = num_D self.n_layers = n_layers self.getIntermFeat = getIntermFeat for i in range(num_D): - netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, - getIntermFeat) + netD = NLayerDiscriminator( + input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat + ) if getIntermFeat: for j in range(n_layers + 2): - setattr(self, 'scale' + str(i) + '_layer' + str(j), - getattr(netD, 'model' + str(j))) + setattr( + self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)) + ) else: setattr(self, 'layer' + str(i), netD.model) @@ -414,11 +426,11 @@ class MultiscaleDiscriminator(pl.LightningModule): # Define a resnet block class ResnetBlock(pl.LightningModule): - def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False): super(ResnetBlock, self).__init__() - self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, - use_dropout) + self.conv_block = self.build_conv_block( + dim, padding_type, norm_layer, activation, use_dropout + ) def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout): conv_block = [] @@ -459,7 +471,6 @@ class ResnetBlock(pl.LightningModule): class Encoder(pl.LightningModule): - def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d): super(Encoder, self).__init__() self.output_nc = output_nc @@ -510,18 +521,17 @@ class Encoder(pl.LightningModule): inst_list = np.unique(inst.cpu().numpy().astype(int)) for i in inst_list: for b in range(input.size()[0]): - indices = (inst[b:b + 1] == int(i)).nonzero() # n x 4 + indices = (inst[b:b + 1] == int(i)).nonzero() # n x 4 for j in 
range(self.output_nc): output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], - indices[:, 3],] + indices[:, 3], ] mean_feat = torch.mean(output_ins).expand_as(output_ins) outputs_mean[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], - indices[:, 3],] = mean_feat + indices[:, 3], ] = mean_feat return outputs_mean class Vgg19(nn.Module): - def __init__(self, requires_grad=False): super(Vgg19, self).__init__() vgg_pretrained_features = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features @@ -555,7 +565,6 @@ class Vgg19(nn.Module): class VGG19FeatLayer(nn.Module): - def __init__(self): super(VGG19FeatLayer, self).__init__() self.vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval() @@ -593,7 +602,6 @@ class VGG19FeatLayer(nn.Module): class VGGLoss(pl.LightningModule): - def __init__(self): super(VGGLoss, self).__init__() self.vgg = Vgg19().eval() @@ -609,11 +617,7 @@ class VGGLoss(pl.LightningModule): class GANLoss(pl.LightningModule): - - def __init__(self, - use_lsgan=True, - target_real_label=1.0, - target_fake_label=0.0): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): super(GANLoss, self).__init__() self.real_label = target_real_label self.fake_label = target_fake_label @@ -628,16 +632,18 @@ class GANLoss(pl.LightningModule): def get_target_tensor(self, input, target_is_real): target_tensor = None if target_is_real: - create_label = ((self.real_label_var is None) or - (self.real_label_var.numel() != input.numel())) + create_label = ( + (self.real_label_var is None) or (self.real_label_var.numel() != input.numel()) + ) if create_label: real_tensor = self.tensor(input.size()).fill_(self.real_label) self.real_label_var = real_tensor self.real_label_var.requires_grad = False target_tensor = self.real_label_var else: - create_label = ((self.fake_label_var is None) or - (self.fake_label_var.numel() != input.numel())) + create_label = ( + (self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel()) + ) if create_label: fake_tensor = self.tensor(input.size()).fill_(self.fake_label) self.fake_label_var = fake_tensor @@ -659,7 +665,6 @@ class GANLoss(pl.LightningModule): class IDMRFLoss(pl.LightningModule): - def __init__(self, featlayer=VGG19FeatLayer): super(IDMRFLoss, self).__init__() self.featlayer = featlayer() @@ -678,7 +683,8 @@ class IDMRFLoss(pl.LightningModule): patch_size = 1 patch_stride = 1 patches_as_depth_vectors = featmaps.unfold(2, patch_size, patch_stride).unfold( - 3, patch_size, patch_stride) + 3, patch_size, patch_stride + ) self.patches_OIHW = patches_as_depth_vectors.permute(0, 2, 3, 1, 4, 5) dims = self.patches_OIHW.size() self.patches_OIHW = self.patches_OIHW.view(-1, dims[3], dims[4], dims[5]) @@ -743,7 +749,8 @@ class IDMRFLoss(pl.LightningModule): self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer]) for layer in self.feat_content_layers ] - self.content_loss = functools.reduce(lambda x, y: x + y, - content_loss_list) * self.lambda_content + self.content_loss = functools.reduce( + lambda x, y: x + y, content_loss_list + ) * self.lambda_content return self.style_loss + self.content_loss diff --git a/lib/net/GANLoss.py b/lib/net/GANLoss.py index 9be907f5c3f74a3a05fd9a52913325ce54b09a9f..5d6711479980e89a3fc067b5ef579bb382eb29df 100644 --- a/lib/net/GANLoss.py +++ b/lib/net/GANLoss.py @@ -32,13 +32,12 @@ def logistic_loss(fake_pred, real_pred, mode): def r1_loss(real_pred, real_img): - (grad_real,) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, 
create_graph=True) + (grad_real, ) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True) grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() return grad_penalty class GANLoss(nn.Module): - def __init__( self, opt, @@ -64,7 +63,7 @@ class GANLoss(nn.Module): logits_fake = self.discriminator(disc_in_fake) disc_loss = self.disc_loss(fake_pred=logits_fake, real_pred=logits_real, mode='d') - + log = { "disc_loss": disc_loss.detach(), "logits_real": logits_real.mean().detach(), diff --git a/lib/net/IFGeoNet.py b/lib/net/IFGeoNet.py index 195953d0ed91aa7663040dadad3a757bf1086699..a72be083da26093093fb2da46dade7ace3df1bae 100644 --- a/lib/net/IFGeoNet.py +++ b/lib/net/IFGeoNet.py @@ -8,20 +8,17 @@ from lib.dataset.mesh_util import read_smpl_constants, SMPLX class SelfAttention(torch.nn.Module): - def __init__(self, in_channels, out_channels): super().__init__() - self.conv = nn.Conv3d(in_channels, - out_channels, - 3, - padding=1, - padding_mode='replicate') - self.attention = nn.Conv3d(in_channels, - out_channels, - kernel_size=3, - padding=1, - padding_mode='replicate', - bias=False) + self.conv = nn.Conv3d(in_channels, out_channels, 3, padding=1, padding_mode='replicate') + self.attention = nn.Conv3d( + in_channels, + out_channels, + kernel_size=3, + padding=1, + padding_mode='replicate', + bias=False + ) with torch.no_grad(): self.attention.weight.copy_(torch.zeros_like(self.attention.weight)) @@ -32,38 +29,45 @@ class SelfAttention(torch.nn.Module): class IFGeoNet(nn.Module): - def __init__(self, cfg, hidden_dim=256): super(IFGeoNet, self).__init__() - self.conv_in_partial = nn.Conv3d(1, 16, 3, padding=1, - padding_mode='replicate') # out: 256 ->m.p. 128 + self.conv_in_partial = nn.Conv3d( + 1, 16, 3, padding=1, padding_mode='replicate' + ) # out: 256 ->m.p. 128 - self.conv_in_smpl = nn.Conv3d(1, 4, 3, padding=1, - padding_mode='replicate') # out: 256 ->m.p. 128 + self.conv_in_smpl = nn.Conv3d( + 1, 4, 3, padding=1, padding_mode='replicate' + ) # out: 256 ->m.p. 128 self.SA = SelfAttention(4, 4) - self.conv_0_fusion = nn.Conv3d(16 + 4, 32, 3, padding=1, - padding_mode='replicate') # out: 128 - self.conv_0_1_fusion = nn.Conv3d(32, 32, 3, padding=1, - padding_mode='replicate') # out: 128 ->m.p. 64 - - self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate') # out: 128 - self.conv_0_1 = nn.Conv3d(32, 32, 3, padding=1, - padding_mode='replicate') # out: 128 ->m.p. 64 - - self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate') # out: 64 - self.conv_1_1 = nn.Conv3d(64, 64, 3, padding=1, - padding_mode='replicate') # out: 64 -> mp 32 - - self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate') # out: 32 - self.conv_2_1 = nn.Conv3d(128, 128, 3, padding=1, - padding_mode='replicate') # out: 32 -> mp 16 - self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 16 - self.conv_3_1 = nn.Conv3d(128, 128, 3, padding=1, - padding_mode='replicate') # out: 16 -> mp 8 - self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8 - self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8 + self.conv_0_fusion = nn.Conv3d( + 16 + 4, 32, 3, padding=1, padding_mode='replicate' + ) # out: 128 + self.conv_0_1_fusion = nn.Conv3d( + 32, 32, 3, padding=1, padding_mode='replicate' + ) # out: 128 ->m.p. 
64 + + self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate') # out: 128 + self.conv_0_1 = nn.Conv3d( + 32, 32, 3, padding=1, padding_mode='replicate' + ) # out: 128 ->m.p. 64 + + self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate') # out: 64 + self.conv_1_1 = nn.Conv3d( + 64, 64, 3, padding=1, padding_mode='replicate' + ) # out: 64 -> mp 32 + + self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate') # out: 32 + self.conv_2_1 = nn.Conv3d( + 128, 128, 3, padding=1, padding_mode='replicate' + ) # out: 32 -> mp 16 + self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 16 + self.conv_3_1 = nn.Conv3d( + 128, 128, 3, padding=1, padding_mode='replicate' + ) # out: 16 -> mp 8 + self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8 + self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8 feature_size = (1 + 32 + 32 + 64 + 128 + 128 + 128) + 3 self.fc_0 = nn.Conv1d(feature_size, hidden_dim * 2, 1) @@ -97,21 +101,21 @@ class IFGeoNet(nn.Module): smooth_kernel_size=7, batch_size=cfg.batch_size, ) - + self.l1_loss = nn.SmoothL1Loss() def forward(self, batch): - + if "body_voxels" in batch.keys(): x_smpl = batch["body_voxels"] else: with torch.no_grad(): self.voxelization.update_param(batch["voxel_faces"]) - x_smpl = self.voxelization(batch["voxel_verts"])[:, 0] #[B, 128, 128, 128] - + x_smpl = self.voxelization(batch["voxel_verts"])[:, 0] #[B, 128, 128, 128] + p = orthogonal(batch["samples_geo"].permute(0, 2, 1), - batch["calib"]).permute(0, 2, 1) #[2, 60000, 3] - x = batch["depth_voxels"] #[B, 128, 128, 128] + batch["calib"]).permute(0, 2, 1) #[2, 60000, 3] + x = batch["depth_voxels"] #[B, 128, 128, 128] x = x.unsqueeze(1) x_smpl = x_smpl.unsqueeze(1) @@ -119,63 +123,67 @@ class IFGeoNet(nn.Module): p = p.unsqueeze(1).unsqueeze(1) # partial inputs feature extraction - feature_0_partial = F.grid_sample(x, p, padding_mode='border', align_corners = True) + feature_0_partial = F.grid_sample(x, p, padding_mode='border', align_corners=True) net_partial = self.actvn(self.conv_in_partial(x)) net_partial = self.partial_conv_in_bn(net_partial) - net_partial = self.maxpool(net_partial) # out 64 + net_partial = self.maxpool(net_partial) # out 64 # smpl inputs feature extraction # feature_0_smpl = F.grid_sample(x_smpl, p, padding_mode='border', align_corners = True) net_smpl = self.actvn(self.conv_in_smpl(x_smpl)) net_smpl = self.smpl_conv_in_bn(net_smpl) - net_smpl = self.maxpool(net_smpl) # out 64 + net_smpl = self.maxpool(net_smpl) # out 64 net_smpl = self.SA(net_smpl) - + # Feature fusion net = self.actvn(self.conv_0_fusion(torch.concat([net_partial, net_smpl], dim=1))) net = self.actvn(self.conv_0_1_fusion(net)) net = self.conv0_1_bn_fusion(net) - feature_1_fused = F.grid_sample(net, p, padding_mode='border', align_corners = True) + feature_1_fused = F.grid_sample(net, p, padding_mode='border', align_corners=True) # net = self.maxpool(net) # out 64 net = self.actvn(self.conv_0(net)) net = self.actvn(self.conv_0_1(net)) net = self.conv0_1_bn(net) - feature_2 = F.grid_sample(net, p, padding_mode='border', align_corners = True) - net = self.maxpool(net) # out 32 + feature_2 = F.grid_sample(net, p, padding_mode='border', align_corners=True) + net = self.maxpool(net) # out 32 net = self.actvn(self.conv_1(net)) net = self.actvn(self.conv_1_1(net)) net = self.conv1_1_bn(net) - feature_3 = F.grid_sample(net, p, padding_mode='border', align_corners = True) - net = 
self.maxpool(net) # out 16 + feature_3 = F.grid_sample(net, p, padding_mode='border', align_corners=True) + net = self.maxpool(net) # out 16 net = self.actvn(self.conv_2(net)) net = self.actvn(self.conv_2_1(net)) net = self.conv2_1_bn(net) - feature_4 = F.grid_sample(net, p, padding_mode='border', align_corners = True) - net = self.maxpool(net) # out 8 + feature_4 = F.grid_sample(net, p, padding_mode='border', align_corners=True) + net = self.maxpool(net) # out 8 net = self.actvn(self.conv_3(net)) net = self.actvn(self.conv_3_1(net)) net = self.conv3_1_bn(net) - feature_5 = F.grid_sample(net, p, padding_mode='border', align_corners = True) - net = self.maxpool(net) # out 4 + feature_5 = F.grid_sample(net, p, padding_mode='border', align_corners=True) + net = self.maxpool(net) # out 4 net = self.actvn(self.conv_4(net)) net = self.actvn(self.conv_4_1(net)) net = self.conv4_1_bn(net) - feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners = True) # out 2 + feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners=True) # out 2 # here every channel corresponse to one feature. - features = torch.cat((feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, - feature_5, feature_6), - dim=1) # (B, features, 1,7,sample_num) + features = torch.cat( + ( + feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, feature_5, + feature_6 + ), + dim=1 + ) # (B, features, 1,7,sample_num) shape = features.shape features = torch.reshape( - features, - (shape[0], shape[1] * shape[3], shape[4])) # (B, featues_per_sample, samples_num) + features, (shape[0], shape[1] * shape[3], shape[4]) + ) # (B, featues_per_sample, samples_num) # (B, featue_size, samples_num) features = torch.cat((features, p_features), dim=1) @@ -183,7 +191,7 @@ class IFGeoNet(nn.Module): net = self.actvn(self.fc_1(net)) net = self.actvn(self.fc_2(net)) net = self.fc_out(net).squeeze(1) - + return net def compute_loss(self, prds, tgts): diff --git a/lib/net/IFGeoNet_nobody.py b/lib/net/IFGeoNet_nobody.py index bf83b5c09557294a025f975241068f3cf03d19b6..ceda5dedfcf09167f670a66a91b152f6181631cc 100644 --- a/lib/net/IFGeoNet_nobody.py +++ b/lib/net/IFGeoNet_nobody.py @@ -8,16 +8,17 @@ from lib.dataset.mesh_util import read_smpl_constants, SMPLX class SelfAttention(torch.nn.Module): - def __init__(self, in_channels, out_channels): super().__init__() self.conv = nn.Conv3d(in_channels, out_channels, 3, padding=1, padding_mode='replicate') - self.attention = nn.Conv3d(in_channels, - out_channels, - kernel_size=3, - padding=1, - padding_mode='replicate', - bias=False) + self.attention = nn.Conv3d( + in_channels, + out_channels, + kernel_size=3, + padding=1, + padding_mode='replicate', + bias=False + ) with torch.no_grad(): self.attention.weight.copy_(torch.zeros_like(self.attention.weight)) @@ -28,34 +29,39 @@ class SelfAttention(torch.nn.Module): class IFGeoNet(nn.Module): - def __init__(self, cfg, hidden_dim=256): super(IFGeoNet, self).__init__() - self.conv_in_partial = nn.Conv3d(1, 16, 3, padding=1, - padding_mode='replicate') # out: 256 ->m.p. 128 + self.conv_in_partial = nn.Conv3d( + 1, 16, 3, padding=1, padding_mode='replicate' + ) # out: 256 ->m.p. 128 self.SA = SelfAttention(4, 4) - self.conv_0_fusion = nn.Conv3d(16, 32, 3, padding=1, padding_mode='replicate') # out: 128 - self.conv_0_1_fusion = nn.Conv3d(32, 32, 3, padding=1, - padding_mode='replicate') # out: 128 ->m.p. 
64 - - self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate') # out: 128 - self.conv_0_1 = nn.Conv3d(32, 32, 3, padding=1, - padding_mode='replicate') # out: 128 ->m.p. 64 - - self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate') # out: 64 - self.conv_1_1 = nn.Conv3d(64, 64, 3, padding=1, - padding_mode='replicate') # out: 64 -> mp 32 - - self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate') # out: 32 - self.conv_2_1 = nn.Conv3d(128, 128, 3, padding=1, - padding_mode='replicate') # out: 32 -> mp 16 - self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 16 - self.conv_3_1 = nn.Conv3d(128, 128, 3, padding=1, - padding_mode='replicate') # out: 16 -> mp 8 - self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8 - self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8 + self.conv_0_fusion = nn.Conv3d(16, 32, 3, padding=1, padding_mode='replicate') # out: 128 + self.conv_0_1_fusion = nn.Conv3d( + 32, 32, 3, padding=1, padding_mode='replicate' + ) # out: 128 ->m.p. 64 + + self.conv_0 = nn.Conv3d(32, 32, 3, padding=1, padding_mode='replicate') # out: 128 + self.conv_0_1 = nn.Conv3d( + 32, 32, 3, padding=1, padding_mode='replicate' + ) # out: 128 ->m.p. 64 + + self.conv_1 = nn.Conv3d(32, 64, 3, padding=1, padding_mode='replicate') # out: 64 + self.conv_1_1 = nn.Conv3d( + 64, 64, 3, padding=1, padding_mode='replicate' + ) # out: 64 -> mp 32 + + self.conv_2 = nn.Conv3d(64, 128, 3, padding=1, padding_mode='replicate') # out: 32 + self.conv_2_1 = nn.Conv3d( + 128, 128, 3, padding=1, padding_mode='replicate' + ) # out: 32 -> mp 16 + self.conv_3 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 16 + self.conv_3_1 = nn.Conv3d( + 128, 128, 3, padding=1, padding_mode='replicate' + ) # out: 16 -> mp 8 + self.conv_4 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8 + self.conv_4_1 = nn.Conv3d(128, 128, 3, padding=1, padding_mode='replicate') # out: 8 feature_size = (1 + 32 + 32 + 64 + 128 + 128 + 128) + 3 self.fc_0 = nn.Conv1d(feature_size, hidden_dim * 2, 1) @@ -95,8 +101,8 @@ class IFGeoNet(nn.Module): def forward(self, batch): p = orthogonal(batch["samples_geo"].permute(0, 2, 1), - batch["calib"]).permute(0, 2, 1) #[2, 60000, 3] - x = batch["depth_voxels"] #[B, 128, 128, 128] + batch["calib"]).permute(0, 2, 1) #[2, 60000, 3] + x = batch["depth_voxels"] #[B, 128, 128, 128] x = x.unsqueeze(1) p_features = p.transpose(1, -1) @@ -106,7 +112,7 @@ class IFGeoNet(nn.Module): feature_0_partial = F.grid_sample(x, p, padding_mode='border', align_corners=True) net_partial = self.actvn(self.conv_in_partial(x)) net_partial = self.partial_conv_in_bn(net_partial) - net_partial = self.maxpool(net_partial) # out 64 + net_partial = self.maxpool(net_partial) # out 64 # Feature fusion net = self.actvn(self.conv_0_fusion(net_partial)) @@ -119,40 +125,44 @@ class IFGeoNet(nn.Module): net = self.actvn(self.conv_0_1(net)) net = self.conv0_1_bn(net) feature_2 = F.grid_sample(net, p, padding_mode='border', align_corners=True) - net = self.maxpool(net) # out 32 + net = self.maxpool(net) # out 32 net = self.actvn(self.conv_1(net)) net = self.actvn(self.conv_1_1(net)) net = self.conv1_1_bn(net) feature_3 = F.grid_sample(net, p, padding_mode='border', align_corners=True) - net = self.maxpool(net) # out 16 + net = self.maxpool(net) # out 16 net = self.actvn(self.conv_2(net)) net = self.actvn(self.conv_2_1(net)) net = self.conv2_1_bn(net) feature_4 = 
F.grid_sample(net, p, padding_mode='border', align_corners=True) - net = self.maxpool(net) # out 8 + net = self.maxpool(net) # out 8 net = self.actvn(self.conv_3(net)) net = self.actvn(self.conv_3_1(net)) net = self.conv3_1_bn(net) feature_5 = F.grid_sample(net, p, padding_mode='border', align_corners=True) - net = self.maxpool(net) # out 4 + net = self.maxpool(net) # out 4 net = self.actvn(self.conv_4(net)) net = self.actvn(self.conv_4_1(net)) net = self.conv4_1_bn(net) - feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners=True) # out 2 + feature_6 = F.grid_sample(net, p, padding_mode='border', align_corners=True) # out 2 # here every channel corresponse to one feature. - features = torch.cat((feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, - feature_5, feature_6), - dim=1) # (B, features, 1,7,sample_num) + features = torch.cat( + ( + feature_0_partial, feature_1_fused, feature_2, feature_3, feature_4, feature_5, + feature_6 + ), + dim=1 + ) # (B, features, 1,7,sample_num) shape = features.shape features = torch.reshape( - features, - (shape[0], shape[1] * shape[3], shape[4])) # (B, featues_per_sample, samples_num) + features, (shape[0], shape[1] * shape[3], shape[4]) + ) # (B, featues_per_sample, samples_num) # (B, featue_size, samples_num) features = torch.cat((features, p_features), dim=1) @@ -167,4 +177,4 @@ class IFGeoNet(nn.Module): loss = self.l1_loss(prds, tgts) - return loss \ No newline at end of file + return loss diff --git a/lib/net/NormalNet.py b/lib/net/NormalNet.py index 71620076ace11b3542763f9b5285ec58a86bb949..a065840ed859137e72ba1e37a40c636da7c32e6f 100644 --- a/lib/net/NormalNet.py +++ b/lib/net/NormalNet.py @@ -35,7 +35,6 @@ class NormalNet(BasePIFuNet): 4. Classification. 5. During training, error is calculated on all stacks. 
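# --- Editorial note (illustrative only; names and dims below are made-up placeholders) ----
# The channel bookkeeping a few lines below sums the input dims of every entry in
# opt.in_nml whose name contains "_F" / "_B" or equals "image". For a hypothetical list:
in_nml = [("image", 3), ("T_normal_F", 3), ("T_normal_B", 3)]
in_nmlF_dim = sum(d for n, d in in_nml if "_F" in n or n == "image")    # -> 6
in_nmlB_dim = sum(d for n, d in in_nml if "_B" in n or n == "image")    # -> 6
# ------------------------------------------------------------------------------------------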
""" - def __init__(self, cfg): super(NormalNet, self).__init__() @@ -65,9 +64,11 @@ class NormalNet(BasePIFuNet): item[0] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image" ] self.in_nmlF_dim = sum( - [item[1] for item in self.opt.in_nml if "_F" in item[0] or item[0] == "image"]) + [item[1] for item in self.opt.in_nml if "_F" in item[0] or item[0] == "image"] + ) self.in_nmlB_dim = sum( - [item[1] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"]) + [item[1] for item in self.opt.in_nml if "_B" in item[0] or item[0] == "image"] + ) self.netF = define_G(self.in_nmlF_dim, 3, 64, "global", 4, 9, 1, 3, "instance") self.netB = define_G(self.in_nmlB_dim, 3, 64, "global", 4, 9, 1, 3, "instance") @@ -134,18 +135,20 @@ class NormalNet(BasePIFuNet): if 'mrf' in self.F_losses: mrf_F_loss = self.mrf_loss( F.interpolate(prd_F, scale_factor=scale_factor, mode='bicubic', align_corners=True), - F.interpolate(tgt_F, scale_factor=scale_factor, mode='bicubic', align_corners=True)) + F.interpolate(tgt_F, scale_factor=scale_factor, mode='bicubic', align_corners=True) + ) total_loss["netF"] += self.F_losses_ratio[self.F_losses.index('mrf')] * mrf_F_loss total_loss["mrf_F"] = self.F_losses_ratio[self.F_losses.index('mrf')] * mrf_F_loss if 'mrf' in self.B_losses: mrf_B_loss = self.mrf_loss( F.interpolate(prd_B, scale_factor=scale_factor, mode='bicubic', align_corners=True), - F.interpolate(tgt_B, scale_factor=scale_factor, mode='bicubic', align_corners=True)) + F.interpolate(tgt_B, scale_factor=scale_factor, mode='bicubic', align_corners=True) + ) total_loss["netB"] += self.B_losses_ratio[self.B_losses.index('mrf')] * mrf_B_loss total_loss["mrf_B"] = self.B_losses_ratio[self.B_losses.index('mrf')] * mrf_B_loss if 'gan' in self.ALL_losses: - + total_loss["netD"] = 0.0 pred_fake = self.netD.forward(prd_B) @@ -154,8 +157,8 @@ class NormalNet(BasePIFuNet): loss_D_real = self.gan_loss(pred_real, True) loss_G_fake = self.gan_loss(pred_fake, True) - total_loss["netD"] += 0.5 * ( - loss_D_fake + loss_D_real) * self.B_losses_ratio[self.B_losses.index('gan')] + total_loss["netD"] += 0.5 * (loss_D_fake + loss_D_real + ) * self.B_losses_ratio[self.B_losses.index('gan')] total_loss["D_fake"] = loss_D_fake * self.B_losses_ratio[self.B_losses.index('gan')] total_loss["D_real"] = loss_D_real * self.B_losses_ratio[self.B_losses.index('gan')] @@ -167,8 +170,8 @@ class NormalNet(BasePIFuNet): for i in range(2): for j in range(len(pred_fake[i]) - 1): loss_G_GAN_Feat += self.l1_loss(pred_fake[i][j], pred_real[i][j].detach()) - total_loss["netB"] += loss_G_GAN_Feat * self.B_losses_ratio[self.B_losses.index( - 'gan_feat')] + total_loss["netB"] += loss_G_GAN_Feat * self.B_losses_ratio[ + self.B_losses.index('gan_feat')] total_loss["G_GAN_Feat"] = loss_G_GAN_Feat * self.B_losses_ratio[ self.B_losses.index('gan_feat')] diff --git a/lib/net/geometry.py b/lib/net/geometry.py index af6bf154723addf0565820468f32d8a2efa980a1..6d7d82d2cb6b760596d1bbf70804e542999f802e 100644 --- a/lib/net/geometry.py +++ b/lib/net/geometry.py @@ -19,12 +19,12 @@ import numpy as np import numbers from torch.nn import functional as F from einops.einops import rearrange - """ Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR """ + def quaternion_to_rotation_matrix(quat): """Convert quaternion coefficients to rotation matrix. 
Args: @@ -42,11 +42,13 @@ def quaternion_to_rotation_matrix(quat): wx, wy, wz = w * x, w * y, w * z xy, xz, yz = x * y, x * z, y * z - rotMat = torch.stack([ - w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2, - 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2 - ], - dim=1).view(B, 3, 3) + rotMat = torch.stack( + [ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2, + 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2 + ], + dim=1 + ).view(B, 3, 3) return rotMat @@ -56,7 +58,7 @@ def index(feat, uv): :param uv: [B, 2, N] uv coordinates in the image plane, range [0, 1] :return: [B, C, N] image features at the uv coordinates """ - uv = uv.transpose(1, 2) # [B, N, 2] + uv = uv.transpose(1, 2) # [B, N, 2] (B, N, _) = uv.shape C = feat.shape[1] @@ -64,14 +66,14 @@ def index(feat, uv): if uv.shape[-1] == 3: # uv = uv[:,:,[2,1,0]] # uv = uv * torch.tensor([1.0,-1.0,1.0]).type_as(uv)[None,None,...] - uv = uv.unsqueeze(2).unsqueeze(3) # [B, N, 1, 1, 3] + uv = uv.unsqueeze(2).unsqueeze(3) # [B, N, 1, 1, 3] else: - uv = uv.unsqueeze(2) # [B, N, 1, 2] + uv = uv.unsqueeze(2) # [B, N, 1, 2] # NOTE: for newer PyTorch, it seems that training results are degraded due to implementation diff in F.grid_sample # for old versions, simply remove the aligned_corners argument. - samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True) # [B, C, N, 1] - return samples.view(B, C, N) # [B, C, N] + samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True) # [B, C, N, 1] + return samples.view(B, C, N) # [B, C, N] def orthogonal(points, calibrations, transforms=None): @@ -84,7 +86,7 @@ def orthogonal(points, calibrations, transforms=None): """ rot = calibrations[:, :3, :3] trans = calibrations[:, :3, 3:4] - pts = torch.baddbmm(trans, rot, points) # [B, 3, N] + pts = torch.baddbmm(trans, rot, points) # [B, 3, N] if transforms is not None: scale = transforms[:2, :2] shift = transforms[:2, 2:3] @@ -102,7 +104,7 @@ def perspective(points, calibrations, transforms=None): """ rot = calibrations[:, :3, :3] trans = calibrations[:, :3, 3:4] - homo = torch.baddbmm(trans, rot, points) # [B, 3, N] + homo = torch.baddbmm(trans, rot, points) # [B, 3, N] xy = homo[:, :2, :] / homo[:, 2:3, :] if transforms is not None: scale = transforms[:2, :2] @@ -187,7 +189,8 @@ def rotation_matrix_to_angle_axis(rotation_matrix): if rotation_matrix.shape[1:] == (3, 3): rot_mat = rotation_matrix.reshape(-1, 3, 3) hom = torch.tensor([0, 0, 1], dtype=torch.float32, device=rotation_matrix.device).reshape( - 1, 3, 1).expand(rot_mat.shape[0], -1, -1) + 1, 3, 1 + ).expand(rot_mat.shape[0], -1, -1) rotation_matrix = torch.cat([rot_mat, hom], dim=-1) quaternion = rotation_matrix_to_quaternion(rotation_matrix) @@ -222,8 +225,9 @@ def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor: raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion))) if not quaternion.shape[-1] == 4: - raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}".format( - quaternion.shape)) + raise ValueError( + "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape) + ) # unpack input and compute conversion q1: torch.Tensor = quaternion[..., 1] q2: torch.Tensor = quaternion[..., 2] @@ -276,11 +280,13 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): raise TypeError("Input type is not a torch.Tensor. 
Got {}".format(type(rotation_matrix))) if len(rotation_matrix.shape) > 3: - raise ValueError("Input size must be a three dimensional tensor. Got {}".format( - rotation_matrix.shape)) + raise ValueError( + "Input size must be a three dimensional tensor. Got {}".format(rotation_matrix.shape) + ) if not rotation_matrix.shape[-2:] == (3, 4): - raise ValueError("Input size must be a N x 3 x 4 tensor. Got {}".format( - rotation_matrix.shape)) + raise ValueError( + "Input size must be a N x 3 x 4 tensor. Got {}".format(rotation_matrix.shape) + ) rmat_t = torch.transpose(rotation_matrix, 1, 2) @@ -347,8 +353,10 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): mask_c3 = mask_c3.view(-1, 1).type_as(q3) q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 - q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2 # noqa - + t3_rep * mask_c3) # noqa + q /= torch.sqrt( + t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2 # noqa + + t3_rep * mask_c3 + ) # noqa q *= 0.5 return q @@ -389,6 +397,7 @@ def rot6d_to_rotmat(x): mat = torch.stack((b1, b2, b3), dim=-1) return mat + def rotmat_to_rot6d(x): """Convert 3x3 rotation matrix to 6D rotation representation. Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 @@ -402,6 +411,7 @@ def rotmat_to_rot6d(x): x = x.reshape(batch_size, 6) return x + def rotmat_to_angle(x): """Convert rotation to one-D angle. Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 @@ -440,12 +450,9 @@ def projection(pred_joints, pred_camera, retain_z=False): return pred_keypoints_2d -def perspective_projection(points, - rotation, - translation, - focal_length, - camera_center, - retain_z=False): +def perspective_projection( + points, rotation, translation, focal_length, camera_center, retain_z=False +): """ This function computes the perspective projection of a set of points. 
Input: @@ -501,10 +508,12 @@ def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000, img_si weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1) # least squares - Q = np.array([ - F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints), - O - np.reshape(joints_2d, -1) - ]).T + Q = np.array( + [ + F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints), + O - np.reshape(joints_2d, -1) + ] + ).T c = (np.reshape(joints_2d, -1) - O) * Z - F * XY # weighted least squares @@ -558,15 +567,12 @@ def estimate_translation(S, joints_2d, focal_length=5000., img_size=224., use_al S_i = S[i] joints_i = joints_2d[i] conf_i = joints_conf[i] - trans[i] = estimate_translation_np(S_i, - joints_i, - conf_i, - focal_length=focal_length[i], - img_size=img_size[i]) + trans[i] = estimate_translation_np( + S_i, joints_i, conf_i, focal_length=focal_length[i], img_size=img_size[i] + ) return torch.from_numpy(trans).to(device) - def Rot_y(angle, category="torch", prepend_dim=True, device=None): """Rotate around y-axis by angle Args: @@ -574,11 +580,13 @@ def Rot_y(angle, category="torch", prepend_dim=True, device=None): prepend_dim: prepend an extra dimension Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True) """ - m = np.array([ - [np.cos(angle), 0.0, np.sin(angle)], - [0.0, 1.0, 0.0], - [-np.sin(angle), 0.0, np.cos(angle)], - ]) + m = np.array( + [ + [np.cos(angle), 0.0, np.sin(angle)], + [0.0, 1.0, 0.0], + [-np.sin(angle), 0.0, np.cos(angle)], + ] + ) if category == "torch": if prepend_dim: return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0) @@ -600,11 +608,13 @@ def Rot_x(angle, category="torch", prepend_dim=True, device=None): prepend_dim: prepend an extra dimension Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True) """ - m = np.array([ - [1.0, 0.0, 0.0], - [0.0, np.cos(angle), -np.sin(angle)], - [0.0, np.sin(angle), np.cos(angle)], - ]) + m = np.array( + [ + [1.0, 0.0, 0.0], + [0.0, np.cos(angle), -np.sin(angle)], + [0.0, np.sin(angle), np.cos(angle)], + ] + ) if category == "torch": if prepend_dim: return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0) @@ -626,11 +636,13 @@ def Rot_z(angle, category="torch", prepend_dim=True, device=None): prepend_dim: prepend an extra dimension Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True) """ - m = np.array([ - [np.cos(angle), -np.sin(angle), 0.0], - [np.sin(angle), np.cos(angle), 0.0], - [0.0, 0.0, 1.0], - ]) + m = np.array( + [ + [np.cos(angle), -np.sin(angle), 0.0], + [np.sin(angle), np.cos(angle), 0.0], + [0.0, 0.0, 1.0], + ] + ) if category == "torch": if prepend_dim: return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0) @@ -672,7 +684,7 @@ def compute_twist_rotation(rotation_matrix, twist_axis): twist_rotation = quaternion_to_rotation_matrix(twist_quaternion) twist_aa = quaternion_to_angle_axis(twist_quaternion) - twist_angle = torch.sum(twist_aa, dim=1, keepdim=True) / torch.sum( - twist_axis, dim=1, keepdim=True) + twist_angle = torch.sum(twist_aa, dim=1, + keepdim=True) / torch.sum(twist_axis, dim=1, keepdim=True) - return twist_rotation, twist_angle \ No newline at end of file + return twist_rotation, twist_angle diff --git a/lib/net/net_util.py b/lib/net/net_util.py index 200a87e5e09094069379f989082d8099d97b75f8..d89fcff5670909cd41c2e917e87b3bdb25870d8a 100644 --- a/lib/net/net_util.py +++ b/lib/net/net_util.py @@ -71,11 +71,10 @@ def init_weights(net, init_type="normal", 
init_gain=0.02): We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might work better for some applications. Feel free to try yourself. """ - - def init_func(m): # define the initialization function + def init_func(m): # define the initialization function classname = m.__class__.__name__ - if hasattr(m, "weight") and (classname.find("Conv") != -1 or - classname.find("Linear") != -1): + if hasattr(m, + "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1): if init_type == "normal": init.normal_(m.weight.data, 0.0, init_gain) elif init_type == "xavier": @@ -85,17 +84,19 @@ def init_weights(net, init_type="normal", init_gain=0.02): elif init_type == "orthogonal": init.orthogonal_(m.weight.data, gain=init_gain) else: - raise NotImplementedError("initialization method [%s] is not implemented" % - init_type) + raise NotImplementedError( + "initialization method [%s] is not implemented" % init_type + ) if hasattr(m, "bias") and m.bias is not None: init.constant_(m.bias.data, 0.0) - elif (classname.find("BatchNorm2d") != - -1): # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + elif ( + classname.find("BatchNorm2d") != -1 + ): # BatchNorm Layer's weight is not a matrix; only normal distribution applies. init.normal_(m.weight.data, 1.0, init_gain) init.constant_(m.bias.data, 0.0) # print('initialize network with %s' % init_type) - net.apply(init_func) # apply the initialization function + net.apply(init_func) # apply the initialization function def init_net(net, init_type="xavier", init_gain=0.02, gpu_ids=[]): @@ -110,7 +111,7 @@ def init_net(net, init_type="xavier", init_gain=0.02, gpu_ids=[]): """ if len(gpu_ids) > 0: assert torch.cuda.is_available() - net = torch.nn.DataParallel(net) # multi-GPUs + net = torch.nn.DataParallel(net) # multi-GPUs init_weights(net, init_type, init_gain=init_gain) return net @@ -127,13 +128,9 @@ def imageSpaceRotation(xy, rot): return (disp * xy).sum(dim=1) -def cal_gradient_penalty(netD, - real_data, - fake_data, - device, - type="mixed", - constant=1.0, - lambda_gp=10.0): +def cal_gradient_penalty( + netD, real_data, fake_data, device, type="mixed", constant=1.0, lambda_gp=10.0 +): """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 Arguments: @@ -155,9 +152,11 @@ def cal_gradient_penalty(netD, interpolatesv = fake_data elif type == "mixed": alpha = torch.rand(real_data.shape[0], 1) - alpha = (alpha.expand(real_data.shape[0], - real_data.nelement() // - real_data.shape[0]).contiguous().view(*real_data.shape)) + alpha = ( + alpha.expand(real_data.shape[0], + real_data.nelement() // + real_data.shape[0]).contiguous().view(*real_data.shape) + ) alpha = alpha.to(device) interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) else: @@ -172,9 +171,9 @@ def cal_gradient_penalty(netD, retain_graph=True, only_inputs=True, ) - gradients = gradients[0].view(real_data.size(0), -1) # flat the data - gradient_penalty = (( - (gradients + 1e-16).norm(2, dim=1) - constant)**2).mean() * lambda_gp # added eps + gradients = gradients[0].view(real_data.size(0), -1) # flat the data + gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant)** + 2).mean() * lambda_gp # added eps return gradient_penalty, gradients else: return 0.0, None @@ -201,13 +200,11 @@ def get_norm_layer(norm_type="instance"): class Flatten(nn.Module): - def forward(self, input): return input.view(input.size(0), -1) class ConvBlock(nn.Module): - def __init__(self, 
in_planes, out_planes, opt): super(ConvBlock, self).__init__() [k, s, d, p] = opt.conv3x3 @@ -258,5 +255,3 @@ class ConvBlock(nn.Module): out3 += residual return out3 - - diff --git a/lib/net/voxelize.py b/lib/net/voxelize.py index ba341ce905455e14fa21456acfe899ba19e6f783..394b40e6eeeb158bb691c1e518b6b1f7a889b8d8 100644 --- a/lib/net/voxelize.py +++ b/lib/net/voxelize.py @@ -13,7 +13,6 @@ class VoxelizationFunction(Function): Definition of differentiable voxelization function Currently implemented only for cuda Tensors """ - @staticmethod def forward( ctx, @@ -48,12 +47,15 @@ class VoxelizationFunction(Function): smpl_face_code = smpl_face_code.contiguous() smpl_tetrahedrons = smpl_tetrahedrons.contiguous() - occ_volume = torch.cuda.FloatTensor(ctx.batch_size, ctx.volume_res, ctx.volume_res, - ctx.volume_res).fill_(0.0) - semantic_volume = torch.cuda.FloatTensor(ctx.batch_size, ctx.volume_res, ctx.volume_res, - ctx.volume_res, 3).fill_(0.0) - weight_sum_volume = torch.cuda.FloatTensor(ctx.batch_size, ctx.volume_res, ctx.volume_res, - ctx.volume_res).fill_(1e-3) + occ_volume = torch.cuda.FloatTensor( + ctx.batch_size, ctx.volume_res, ctx.volume_res, ctx.volume_res + ).fill_(0.0) + semantic_volume = torch.cuda.FloatTensor( + ctx.batch_size, ctx.volume_res, ctx.volume_res, ctx.volume_res, 3 + ).fill_(0.0) + weight_sum_volume = torch.cuda.FloatTensor( + ctx.batch_size, ctx.volume_res, ctx.volume_res, ctx.volume_res + ).fill_(1e-3) # occ_volume [B, volume_res, volume_res, volume_res] # semantic_volume [B, volume_res, volume_res, volume_res, 3] @@ -80,7 +82,6 @@ class Voxelization(nn.Module): """ Wrapper around the autograd function VoxelizationFunction """ - def __init__( self, smpl_vertex_code, @@ -151,21 +152,25 @@ class Voxelization(nn.Module): self.sigma, self.smooth_kernel_size, ) - return vol.permute((0, 4, 1, 2, 3)) # (bzyxc --> bcdhw) + return vol.permute((0, 4, 1, 2, 3)) # (bzyxc --> bcdhw) def vertices_to_faces(self, vertices): assert vertices.ndimension() == 3 bs, nv = vertices.shape[:2] - face = (self.smpl_face_indices_batch + - (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None]) + face = ( + self.smpl_face_indices_batch + + (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None] + ) vertices_ = vertices.reshape((bs * nv, 3)) return vertices_[face.long()] def vertices_to_tetrahedrons(self, vertices): assert vertices.ndimension() == 3 bs, nv = vertices.shape[:2] - tets = (self.smpl_tetraderon_indices_batch + - (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None]) + tets = ( + self.smpl_tetraderon_indices_batch + + (torch.arange(bs, dtype=torch.int32).to(self.device) * nv)[:, None, None] + ) vertices_ = vertices.reshape((bs * nv, 3)) return vertices_[tets.long()] @@ -174,8 +179,9 @@ class Voxelization(nn.Module): assert face_verts.shape[2] == 3 assert face_verts.shape[3] == 3 bs, nf = face_verts.shape[:2] - face_centers = (face_verts[:, :, 0, :] + face_verts[:, :, 1, :] + - face_verts[:, :, 2, :]) / 3.0 + face_centers = ( + face_verts[:, :, 0, :] + face_verts[:, :, 1, :] + face_verts[:, :, 2, :] + ) / 3.0 face_centers = face_centers.reshape((bs, nf, 3)) return face_centers diff --git a/lib/pixielib/models/FLAME.py b/lib/pixielib/models/FLAME.py index 6c0b78d5fd8efa12651ac5cb689b4e5c4d790636..b62b1069b6083685e8ff1511e57c48ccf79bc927 100755 --- a/lib/pixielib/models/FLAME.py +++ b/lib/pixielib/models/FLAME.py @@ -27,7 +27,6 @@ class FLAMETex(nn.Module): FLAME texture converted from BFM: 
https://github.com/TimoBolkart/BFM_to_FLAME """ - def __init__(self, config): super(FLAMETex, self).__init__() if config.tex_type == "BFM": @@ -54,8 +53,7 @@ class FLAMETex(nn.Module): n_tex = config.n_tex num_components = texture_basis.shape[1] texture_mean = torch.from_numpy(texture_mean).float()[None, ...] - texture_basis = torch.from_numpy( - texture_basis[:, :n_tex]).float()[None, ...] + texture_basis = torch.from_numpy(texture_basis[:, :n_tex]).float()[None, ...] self.register_buffer("texture_mean", texture_mean) self.register_buffer("texture_basis", texture_basis) @@ -64,10 +62,8 @@ class FLAMETex(nn.Module): texcode: [batchsize, n_tex] texture: [bz, 3, 256, 256], range: 0-1 """ - texture = self.texture_mean + (self.texture_basis * - texcode[:, None, :]).sum(-1) - texture = texture.reshape(texcode.shape[0], 512, 512, - 3).permute(0, 3, 1, 2) + texture = self.texture_mean + (self.texture_basis * texcode[:, None, :]).sum(-1) + texture = texture.reshape(texcode.shape[0], 512, 512, 3).permute(0, 3, 1, 2) texture = F.interpolate(texture, [256, 256]) texture = texture[:, [2, 1, 0], :, :] return texture @@ -78,13 +74,13 @@ def texture_flame2smplx(cached_data, flame_texture, smplx_texture): TODO: pytorch version ==> grid sample """ if smplx_texture.shape[0] != smplx_texture.shape[1]: - print("SMPL-X texture not squared (%d != %d)" % - (smplx_texture[0], smplx_texture[1])) + print("SMPL-X texture not squared (%d != %d)" % (smplx_texture[0], smplx_texture[1])) return if smplx_texture.shape[0] != cached_data["target_resolution"]: print( - "SMPL-X texture size does not match cached image resolution (%d != %d)" - % (smplx_texture.shape[0], cached_data["target_resolution"])) + "SMPL-X texture size does not match cached image resolution (%d != %d)" % + (smplx_texture.shape[0], cached_data["target_resolution"]) + ) return x_coords = cached_data["x_coords"] y_coords = cached_data["y_coords"] @@ -98,11 +94,13 @@ def texture_flame2smplx(cached_data, flame_texture, smplx_texture): flame_texture.shape[0], ).astype(int) source_tex_coords[:, 1] = np.clip( - flame_texture.shape[1] * (source_uv_points[:, 0]), 0.0, - flame_texture.shape[1]).astype(int) + flame_texture.shape[1] * (source_uv_points[:, 0]), 0.0, flame_texture.shape[1] + ).astype(int) smplx_texture[y_coords[target_pixel_ids].astype(int), - x_coords[target_pixel_ids].astype(int), :, ] = flame_texture[ - source_tex_coords[:, 0], source_tex_coords[:, 1]] + x_coords[target_pixel_ids].astype(int), :, ] = flame_texture[source_tex_coords[:, + 0], + source_tex_coords[:, + 1]] return smplx_texture diff --git a/lib/pixielib/models/SMPLX.py b/lib/pixielib/models/SMPLX.py index 0e940d2d42bd07484943f13519df97e6cf3e0fa8..beb672facebc7fa9e61eee7f8e7f3f185ac6cdad 100644 --- a/lib/pixielib/models/SMPLX.py +++ b/lib/pixielib/models/SMPLX.py @@ -209,452 +209,468 @@ extra_names = [ SMPLX_names += extra_names part_indices = {} -part_indices["body"] = np.array([ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 123, - 124, - 125, - 126, - 127, - 132, - 134, - 135, - 136, - 137, - 138, - 143, -]) -part_indices["torso"] = np.array([ - 0, - 1, - 2, - 3, - 6, - 9, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 22, - 23, - 24, - 55, - 56, - 57, - 58, - 59, - 76, - 77, - 78, - 79, - 80, - 81, - 82, - 83, - 84, - 85, - 86, - 87, - 88, - 89, - 90, - 91, - 92, - 93, - 94, - 95, - 96, - 97, - 98, - 99, - 100, - 101, - 102, - 103, - 104, - 105, - 106, - 107, - 108, - 109, - 110, - 
111, - 112, - 113, - 114, - 115, - 116, - 117, - 118, - 119, - 120, - 121, - 122, - 123, - 124, - 125, - 126, - 127, - 128, - 129, - 130, - 131, - 132, - 133, - 134, - 135, - 136, - 137, - 138, - 139, - 140, - 141, - 142, - 143, - 144, -]) -part_indices["head"] = np.array([ - 12, - 15, - 22, - 23, - 24, - 55, - 56, - 57, - 58, - 59, - 60, - 61, - 62, - 63, - 64, - 65, - 66, - 67, - 68, - 69, - 70, - 71, - 72, - 73, - 74, - 75, - 76, - 77, - 78, - 79, - 80, - 81, - 82, - 83, - 84, - 85, - 86, - 87, - 88, - 89, - 90, - 91, - 92, - 93, - 94, - 95, - 96, - 97, - 98, - 99, - 100, - 101, - 102, - 103, - 104, - 105, - 106, - 107, - 108, - 109, - 110, - 111, - 112, - 113, - 114, - 115, - 116, - 117, - 118, - 119, - 120, - 121, - 122, - 123, - 125, - 126, - 134, - 136, - 137, -]) -part_indices["face"] = np.array([ - 55, - 56, - 57, - 58, - 59, - 60, - 61, - 62, - 63, - 64, - 65, - 66, - 67, - 68, - 69, - 70, - 71, - 72, - 73, - 74, - 75, - 76, - 77, - 78, - 79, - 80, - 81, - 82, - 83, - 84, - 85, - 86, - 87, - 88, - 89, - 90, - 91, - 92, - 93, - 94, - 95, - 96, - 97, - 98, - 99, - 100, - 101, - 102, - 103, - 104, - 105, - 106, - 107, - 108, - 109, - 110, - 111, - 112, - 113, - 114, - 115, - 116, - 117, - 118, - 119, - 120, - 121, - 122, -]) -part_indices["upper"] = np.array([ - 12, - 13, - 14, - 55, - 56, - 57, - 58, - 59, - 60, - 61, - 62, - 63, - 64, - 65, - 66, - 67, - 68, - 69, - 70, - 71, - 72, - 73, - 74, - 75, - 76, - 77, - 78, - 79, - 80, - 81, - 82, - 83, - 84, - 85, - 86, - 87, - 88, - 89, - 90, - 91, - 92, - 93, - 94, - 95, - 96, - 97, - 98, - 99, - 100, - 101, - 102, - 103, - 104, - 105, - 106, - 107, - 108, - 109, - 110, - 111, - 112, - 113, - 114, - 115, - 116, - 117, - 118, - 119, - 120, - 121, - 122, -]) -part_indices["hand"] = np.array([ - 20, - 21, - 25, - 26, - 27, - 28, - 29, - 30, - 31, - 32, - 33, - 34, - 35, - 36, - 37, - 38, - 39, - 40, - 41, - 42, - 43, - 44, - 45, - 46, - 47, - 48, - 49, - 50, - 51, - 52, - 53, - 54, - 128, - 129, - 130, - 131, - 133, - 139, - 140, - 141, - 142, - 144, -]) -part_indices["left_hand"] = np.array([ - 20, - 25, - 26, - 27, - 28, - 29, - 30, - 31, - 32, - 33, - 34, - 35, - 36, - 37, - 38, - 39, - 128, - 129, - 130, - 131, - 133, -]) -part_indices["right_hand"] = np.array([ - 21, - 40, - 41, - 42, - 43, - 44, - 45, - 46, - 47, - 48, - 49, - 50, - 51, - 52, - 53, - 54, - 139, - 140, - 141, - 142, - 144, -]) +part_indices["body"] = np.array( + [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 123, + 124, + 125, + 126, + 127, + 132, + 134, + 135, + 136, + 137, + 138, + 143, + ] +) +part_indices["torso"] = np.array( + [ + 0, + 1, + 2, + 3, + 6, + 9, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 22, + 23, + 24, + 55, + 56, + 57, + 58, + 59, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 141, + 142, + 143, + 144, + ] +) +part_indices["head"] = np.array( + [ + 12, + 15, + 22, + 23, + 24, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 
87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 125, + 126, + 134, + 136, + 137, + ] +) +part_indices["face"] = np.array( + [ + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + ] +) +part_indices["upper"] = np.array( + [ + 12, + 13, + 14, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + ] +) +part_indices["hand"] = np.array( + [ + 20, + 21, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 128, + 129, + 130, + 131, + 133, + 139, + 140, + 141, + 142, + 144, + ] +) +part_indices["left_hand"] = np.array( + [ + 20, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 128, + 129, + 130, + 131, + 133, + ] +) +part_indices["right_hand"] = np.array( + [ + 21, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 139, + 140, + 141, + 142, + 144, + ] +) # kinematic tree head_kin_chain = [15, 12, 9, 6, 3, 0] @@ -691,13 +707,12 @@ class SMPLX(nn.Module): Given smplx parameters, this class generates a differentiable SMPLX function which outputs a mesh and 3D joints """ - def __init__(self, config): super(SMPLX, self).__init__() # print("creating the SMPLX Decoder") ss = np.load(config.smplx_model_path, allow_pickle=True) smplx_model = Struct(**ss) - + self.dtype = torch.float32 self.register_buffer( "faces_tensor", @@ -705,8 +720,8 @@ class SMPLX(nn.Module): ) # The vertices of the template model self.register_buffer( - "v_template", - to_tensor(to_np(smplx_model.v_template), dtype=self.dtype)) + "v_template", to_tensor(to_np(smplx_model.v_template), dtype=self.dtype) + ) # The shape components and expression # expression space is the same as FLAME shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype) @@ -721,21 +736,18 @@ class SMPLX(nn.Module): # The pose components num_pose_basis = smplx_model.posedirs.shape[-1] posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T - self.register_buffer("posedirs", - to_tensor(to_np(posedirs), dtype=self.dtype)) + self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=self.dtype)) self.register_buffer( - "J_regressor", - to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype)) + "J_regressor", to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype) + ) parents = to_tensor(to_np(smplx_model.kintree_table[0])).long() parents[0] = -1 self.register_buffer("parents", parents) - self.register_buffer( - "lbs_weights", - to_tensor(to_np(smplx_model.weights), 
dtype=self.dtype)) + self.register_buffer("lbs_weights", to_tensor(to_np(smplx_model.weights), dtype=self.dtype)) # for face keypoints self.register_buffer( - "lmk_faces_idx", - torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long)) + "lmk_faces_idx", torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long) + ) self.register_buffer( "lmk_bary_coords", torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype), @@ -746,24 +758,20 @@ class SMPLX(nn.Module): ) self.register_buffer( "dynamic_lmk_bary_coords", - torch.tensor(smplx_model.dynamic_lmk_bary_coords, - dtype=self.dtype), + torch.tensor(smplx_model.dynamic_lmk_bary_coords, dtype=self.dtype), ) # pelvis to head, to calculate head yaw angle, then find the dynamic landmarks - self.register_buffer("head_kin_chain", - torch.tensor(head_kin_chain, dtype=torch.long)) + self.register_buffer("head_kin_chain", torch.tensor(head_kin_chain, dtype=torch.long)) # -- initialize parameters # shape and expression self.register_buffer( "shape_params", - nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype), - requires_grad=False), + nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype), requires_grad=False), ) self.register_buffer( "expression_params", - nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype), - requires_grad=False), + nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype), requires_grad=False), ) # pose: represented as rotation matrx [number of joints, 3, 3] self.register_buffer( @@ -824,8 +832,7 @@ class SMPLX(nn.Module): ) if config.extra_joint_path: - self.extra_joint_selector = JointsFromVerticesSelector( - fname=config.extra_joint_path) + self.extra_joint_selector = JointsFromVerticesSelector(fname=config.extra_joint_path) self.use_joint_regressor = True self.keypoint_names = SMPLX_names if self.use_joint_regressor: @@ -843,7 +850,8 @@ class SMPLX(nn.Module): self.register_buffer("target_idxs", torch.from_numpy(target)) self.register_buffer( "extra_joint_regressor", - torch.from_numpy(j14_regressor).to(torch.float32)) + torch.from_numpy(j14_regressor).to(torch.float32) + ) self.part_indices = part_indices def forward( @@ -880,23 +888,17 @@ class SMPLX(nn.Module): if expression_params is None: expression_params = self.expression_params.expand(batch_size, -1) if global_pose is None: - global_pose = self.global_pose.unsqueeze(0).expand( - batch_size, -1, -1, -1) + global_pose = self.global_pose.unsqueeze(0).expand(batch_size, -1, -1, -1) if body_pose is None: - body_pose = self.body_pose.unsqueeze(0).expand( - batch_size, -1, -1, -1) + body_pose = self.body_pose.unsqueeze(0).expand(batch_size, -1, -1, -1) if jaw_pose is None: - jaw_pose = self.jaw_pose.unsqueeze(0).expand( - batch_size, -1, -1, -1) + jaw_pose = self.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1) if eye_pose is None: - eye_pose = self.eye_pose.unsqueeze(0).expand( - batch_size, -1, -1, -1) + eye_pose = self.eye_pose.unsqueeze(0).expand(batch_size, -1, -1, -1) if left_hand_pose is None: - left_hand_pose = self.left_hand_pose.unsqueeze(0).expand( - batch_size, -1, -1, -1) + left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1) if right_hand_pose is None: - right_hand_pose = self.right_hand_pose.unsqueeze(0).expand( - batch_size, -1, -1, -1) + right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1) shape_components = torch.cat([shape_params, expression_params], dim=1) full_pose = torch.cat( @@ -910,8 +912,7 @@ class SMPLX(nn.Module): ], dim=1, ) - template_vertices = 
self.v_template.unsqueeze(0).expand( - batch_size, -1, -1) + template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1) # smplx vertices, joints = lbs( shape_components, @@ -926,10 +927,8 @@ class SMPLX(nn.Module): pose2rot=False, ) # face dynamic landmarks - lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand( - batch_size, -1) - lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand( - batch_size, -1, -1) + lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1) + lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1) dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords( vertices, full_pose, @@ -939,14 +938,12 @@ class SMPLX(nn.Module): ) lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1) - landmarks = vertices2landmarks(vertices, self.faces_tensor, - lmk_faces_idx, lmk_bary_coords) + landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords) final_joint_set = [joints, landmarks] if hasattr(self, "extra_joint_selector"): # Add any extra joints that might be needed - extra_joints = self.extra_joint_selector(vertices, - self.faces_tensor) + extra_joints = self.extra_joint_selector(vertices, self.faces_tensor) final_joint_set.append(extra_joints) # Create the final joint set joints = torch.cat(final_joint_set, dim=1) @@ -978,16 +975,15 @@ class SMPLX(nn.Module): # -> Left elbow -> Left wrist kin_chain = [20, 18, 16, 13, 9, 6, 3, 0] else: - raise NotImplementedError( - f"pose_abs2rel does not support: {abs_joint}") + raise NotImplementedError(f"pose_abs2rel does not support: {abs_joint}") batch_size = global_pose.shape[0] dtype = global_pose.dtype device = global_pose.device full_pose = torch.cat([global_pose, body_pose], dim=1) - rel_rot_mat = (torch.eye(3, device=device, - dtype=dtype).unsqueeze_(dim=0).repeat( - batch_size, 1, 1)) + rel_rot_mat = ( + torch.eye(3, device=device, dtype=dtype).unsqueeze_(dim=0).repeat(batch_size, 1, 1) + ) for idx in kin_chain[1:]: rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat) @@ -1027,11 +1023,8 @@ class SMPLX(nn.Module): # -> Left elbow -> Left wrist kin_chain = [20, 18, 16, 13, 9, 6, 3, 0] else: - raise NotImplementedError( - f"pose_rel2abs does not support: {abs_joint}") - rel_rot_mat = torch.eye(3, - device=full_pose.device, - dtype=full_pose.dtype).unsqueeze_(dim=0) + raise NotImplementedError(f"pose_rel2abs does not support: {abs_joint}") + rel_rot_mat = torch.eye(3, device=full_pose.device, dtype=full_pose.dtype).unsqueeze_(dim=0) for idx in kin_chain: rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat) abs_pose = rel_rot_mat[:, None, :, :] diff --git a/lib/pixielib/models/encoders.py b/lib/pixielib/models/encoders.py index 6b0d0e17cf1ca9dc7c87aeba5d8dc3df97f04011..0783c9265ab442a259fd693a55039026cc7608db 100755 --- a/lib/pixielib/models/encoders.py +++ b/lib/pixielib/models/encoders.py @@ -5,14 +5,13 @@ import torch.nn.functional as F class ResnetEncoder(nn.Module): - def __init__(self, append_layers=None): super(ResnetEncoder, self).__init__() from . 
import resnet # feature_size = 2048 self.feature_dim = 2048 - self.encoder = resnet.load_ResNet50Model() # out: 2048 + self.encoder = resnet.load_ResNet50Model() # out: 2048 # regressor self.append_layers = append_layers @@ -25,7 +24,6 @@ class ResnetEncoder(nn.Module): class MLP(nn.Module): - def __init__(self, channels=[2048, 1024, 1], last_op=None): super(MLP, self).__init__() layers = [] @@ -45,13 +43,12 @@ class MLP(nn.Module): class HRNEncoder(nn.Module): - def __init__(self, append_layers=None): super(HRNEncoder, self).__init__() from . import hrnet self.feature_dim = 2048 - self.encoder = hrnet.load_HRNet(pretrained=True) # out: 2048 + self.encoder = hrnet.load_HRNet(pretrained=True) # out: 2048 # regressor self.append_layers = append_layers diff --git a/lib/pixielib/models/hrnet.py b/lib/pixielib/models/hrnet.py index 158c3cc31189d488877f6d2884fab7dc65bc8815..c1fd871abf8ae79dd87f96e30d14d726c913db05 100644 --- a/lib/pixielib/models/hrnet.py +++ b/lib/pixielib/models/hrnet.py @@ -15,38 +15,42 @@ def load_HRNet(pretrained=False): hr_net_cfg_dict = { "use_old_impl": False, "pretrained_layers": ["*"], - "stage1": { - "num_modules": 1, - "num_branches": 1, - "num_blocks": [4], - "num_channels": [64], - "block": "BOTTLENECK", - "fuse_method": "SUM", - }, - "stage2": { - "num_modules": 1, - "num_branches": 2, - "num_blocks": [4, 4], - "num_channels": [48, 96], - "block": "BASIC", - "fuse_method": "SUM", - }, - "stage3": { - "num_modules": 4, - "num_branches": 3, - "num_blocks": [4, 4, 4], - "num_channels": [48, 96, 192], - "block": "BASIC", - "fuse_method": "SUM", - }, - "stage4": { - "num_modules": 3, - "num_branches": 4, - "num_blocks": [4, 4, 4, 4], - "num_channels": [48, 96, 192, 384], - "block": "BASIC", - "fuse_method": "SUM", - }, + "stage1": + { + "num_modules": 1, + "num_branches": 1, + "num_blocks": [4], + "num_channels": [64], + "block": "BOTTLENECK", + "fuse_method": "SUM", + }, + "stage2": + { + "num_modules": 1, + "num_branches": 2, + "num_blocks": [4, 4], + "num_channels": [48, 96], + "block": "BASIC", + "fuse_method": "SUM", + }, + "stage3": + { + "num_modules": 4, + "num_branches": 3, + "num_blocks": [4, 4, 4], + "num_channels": [48, 96, 192], + "block": "BASIC", + "fuse_method": "SUM", + }, + "stage4": + { + "num_modules": 3, + "num_branches": 4, + "num_blocks": [4, 4, 4, 4], + "num_channels": [48, 96, 192, 384], + "block": "BASIC", + "fuse_method": "SUM", + }, } hr_net_cfg = hr_net_cfg_dict model = HighResolutionNet(hr_net_cfg) @@ -55,7 +59,6 @@ def load_HRNet(pretrained=False): class HighResolutionModule(nn.Module): - def __init__( self, num_branches, @@ -67,8 +70,7 @@ class HighResolutionModule(nn.Module): multi_scale_output=True, ): super(HighResolutionModule, self).__init__() - self._check_branches(num_branches, blocks, num_blocks, num_inchannels, - num_channels) + self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels) self.num_inchannels = num_inchannels self.fuse_method = fuse_method @@ -76,37 +78,33 @@ class HighResolutionModule(nn.Module): self.multi_scale_output = multi_scale_output - self.branches = self._make_branches(num_branches, blocks, num_blocks, - num_channels) + self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels) self.fuse_layers = self._make_fuse_layers() self.relu = nn.ReLU(True) - def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, - num_channels): + def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels): if num_branches != 
len(num_blocks): - error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format( - num_branches, len(num_blocks)) + error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(num_branches, len(num_blocks)) raise ValueError(error_msg) if num_branches != len(num_channels): error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format( - num_branches, len(num_channels)) + num_branches, len(num_channels) + ) raise ValueError(error_msg) if num_branches != len(num_inchannels): error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format( - num_branches, len(num_inchannels)) + num_branches, len(num_inchannels) + ) raise ValueError(error_msg) - def _make_one_branch(self, - branch_index, - block, - num_blocks, - num_channels, - stride=1): + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): downsample = None - if (stride != 1 or self.num_inchannels[branch_index] != - num_channels[branch_index] * block.expansion): + if ( + stride != 1 or + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion + ): downsample = nn.Sequential( nn.Conv2d( self.num_inchannels[branch_index], @@ -115,8 +113,7 @@ class HighResolutionModule(nn.Module): stride=stride, bias=False, ), - nn.BatchNorm2d(num_channels[branch_index] * block.expansion, - momentum=BN_MOMENTUM), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM), ) layers = [] @@ -126,13 +123,11 @@ class HighResolutionModule(nn.Module): num_channels[branch_index], stride, downsample, - )) - self.num_inchannels[ - branch_index] = num_channels[branch_index] * block.expansion + ) + ) + self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion for i in range(1, num_blocks[branch_index]): - layers.append( - block(self.num_inchannels[branch_index], - num_channels[branch_index])) + layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) return nn.Sequential(*layers) @@ -140,8 +135,7 @@ class HighResolutionModule(nn.Module): branches = [] for i in range(num_branches): - branches.append( - self._make_one_branch(i, block, num_blocks, num_channels)) + branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) return nn.ModuleList(branches) @@ -167,9 +161,9 @@ class HighResolutionModule(nn.Module): bias=False, ), nn.BatchNorm2d(num_inchannels[i]), - nn.Upsample(scale_factor=2**(j - i), - mode="nearest"), - )) + nn.Upsample(scale_factor=2**(j - i), mode="nearest"), + ) + ) elif j == i: fuse_layer.append(None) else: @@ -188,7 +182,8 @@ class HighResolutionModule(nn.Module): bias=False, ), nn.BatchNorm2d(num_outchannels_conv3x3), - )) + ) + ) else: num_outchannels_conv3x3 = num_inchannels[j] conv3x3s.append( @@ -203,7 +198,8 @@ class HighResolutionModule(nn.Module): ), nn.BatchNorm2d(num_outchannels_conv3x3), nn.ReLU(True), - )) + ) + ) fuse_layer.append(nn.Sequential(*conv3x3s)) fuse_layers.append(nn.ModuleList(fuse_layer)) @@ -237,7 +233,6 @@ blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck} class HighResolutionNet(nn.Module): - def __init__(self, cfg, **kwargs): self.inplanes = 64 super(HighResolutionNet, self).__init__() @@ -245,19 +240,9 @@ class HighResolutionNet(nn.Module): self.use_old_impl = use_old_impl # stem net - self.conv1 = nn.Conv2d(3, - 64, - kernel_size=3, - stride=2, - padding=1, - bias=False) + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) - self.conv2 = nn.Conv2d(64, - 64, - kernel_size=3, - stride=2, - padding=1, - 
bias=False) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) self.relu = nn.ReLU(inplace=True) @@ -271,41 +256,29 @@ class HighResolutionNet(nn.Module): self.stage2_cfg = cfg.get("stage2", {}) num_channels = self.stage2_cfg.get("num_channels", (32, 64)) block = blocks_dict[self.stage2_cfg.get("block")] - num_channels = [ - num_channels[i] * block.expansion for i in range(len(num_channels)) - ] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] stage2_num_channels = num_channels - self.transition1 = self._make_transition_layer([stage1_out_channel], - num_channels) - self.stage2, pre_stage_channels = self._make_stage( - self.stage2_cfg, num_channels) + self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels) + self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels) self.stage3_cfg = cfg.get("stage3") num_channels = self.stage3_cfg["num_channels"] block = blocks_dict[self.stage3_cfg["block"]] - num_channels = [ - num_channels[i] * block.expansion for i in range(len(num_channels)) - ] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] stage3_num_channels = num_channels - self.transition2 = self._make_transition_layer(pre_stage_channels, - num_channels) - self.stage3, pre_stage_channels = self._make_stage( - self.stage3_cfg, num_channels) + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels) self.stage4_cfg = cfg.get("stage4") num_channels = self.stage4_cfg["num_channels"] block = blocks_dict[self.stage4_cfg["block"]] - num_channels = [ - num_channels[i] * block.expansion for i in range(len(num_channels)) - ] - self.transition3 = self._make_transition_layer(pre_stage_channels, - num_channels) + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) stage_4_out_channels = num_channels self.stage4, pre_stage_channels = self._make_stage( - self.stage4_cfg, - num_channels, - multi_scale_output=not self.use_old_impl) + self.stage4_cfg, num_channels, multi_scale_output=not self.use_old_impl + ) stage4_num_channels = num_channels self.output_channels_dim = pre_stage_channels @@ -316,35 +289,34 @@ class HighResolutionNet(nn.Module): self.avg_pooling = nn.AdaptiveAvgPool2d(1) if use_old_impl: - in_dims = (2**2 * stage2_num_channels[-1] + - 2**1 * stage3_num_channels[-1] + - stage_4_out_channels[-1]) + in_dims = ( + 2**2 * stage2_num_channels[-1] + 2**1 * stage3_num_channels[-1] + + stage_4_out_channels[-1] + ) else: # TODO: Replace with parameters in_dims = 4 * 384 self.subsample_4 = self._make_subsample_layer( - in_channels=stage4_num_channels[0], num_layers=3) + in_channels=stage4_num_channels[0], num_layers=3 + ) self.subsample_3 = self._make_subsample_layer( - in_channels=stage2_num_channels[-1], num_layers=2) + in_channels=stage2_num_channels[-1], num_layers=2 + ) self.subsample_2 = self._make_subsample_layer( - in_channels=stage3_num_channels[-1], num_layers=1) - self.conv_layers = self._make_conv_layer(in_channels=in_dims, - num_layers=5) + in_channels=stage3_num_channels[-1], num_layers=1 + ) + self.conv_layers = self._make_conv_layer(in_channels=in_dims, num_layers=5) def get_output_dim(self): - base_output = { - f"layer{idx + 1}": val - for idx, val in 
enumerate(self.output_channels_dim) - } + base_output = {f"layer{idx + 1}": val for idx, val in enumerate(self.output_channels_dim)} output = base_output.copy() for key in base_output: output[f"{key}_avg_pooling"] = output[key] output["concat"] = 2048 return output - def _make_transition_layer(self, num_channels_pre_layer, - num_channels_cur_layer): + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): num_branches_cur = len(num_channels_cur_layer) num_branches_pre = len(num_channels_pre_layer) @@ -364,26 +336,24 @@ class HighResolutionNet(nn.Module): ), nn.BatchNorm2d(num_channels_cur_layer[i]), nn.ReLU(inplace=True), - )) + ) + ) else: transition_layers.append(None) else: conv3x3s = [] for j in range(i + 1 - num_branches_pre): inchannels = num_channels_pre_layer[-1] - outchannels = (num_channels_cur_layer[i] if j == i - - num_branches_pre else inchannels) + outchannels = ( + num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels + ) conv3x3s.append( nn.Sequential( - nn.Conv2d(inchannels, - outchannels, - 3, - 2, - 1, - bias=False), + nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), nn.BatchNorm2d(outchannels), nn.ReLU(inplace=True), - )) + ) + ) transition_layers.append(nn.Sequential(*conv3x3s)) return nn.ModuleList(transition_layers) @@ -410,24 +380,13 @@ class HighResolutionNet(nn.Module): return nn.Sequential(*layers) - def _make_conv_layer(self, - in_channels=2048, - num_layers=3, - num_filters=2048, - stride=1): + def _make_conv_layer(self, in_channels=2048, num_layers=3, num_filters=2048, stride=1): layers = [] for i in range(num_layers): - downsample = nn.Conv2d(in_channels, - num_filters, - stride=1, - kernel_size=1, - bias=False) - layers.append( - Bottleneck(in_channels, - num_filters // 4, - downsample=downsample)) + downsample = nn.Conv2d(in_channels, num_filters, stride=1, kernel_size=1, bias=False) + layers.append(Bottleneck(in_channels, num_filters // 4, downsample=downsample)) in_channels = num_filters return nn.Sequential(*layers) @@ -444,18 +403,15 @@ class HighResolutionNet(nn.Module): kernel_size=3, stride=stride, padding=1, - )) + ) + ) in_channels = 2 * in_channels layers.append(nn.BatchNorm2d(in_channels, momentum=BN_MOMENTUM)) layers.append(nn.ReLU(inplace=True)) return nn.Sequential(*layers) - def _make_stage(self, - layer_config, - num_inchannels, - multi_scale_output=True, - log=False): + def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True, log=False): num_modules = layer_config["num_modules"] num_branches = layer_config["num_branches"] num_blocks = layer_config["num_blocks"] @@ -480,7 +436,8 @@ class HighResolutionNet(nn.Module): num_channels, fuse_method, reset_multi_scale_output, - )) + ) + ) modules[-1].log = log num_inchannels = modules[-1].get_num_inchannels() @@ -580,15 +537,14 @@ class HighResolutionNet(nn.Module): def load_weights(self, pretrained=""): pretrained = osp.expandvars(pretrained) if osp.isfile(pretrained): - pretrained_state_dict = torch.load( - pretrained, map_location=torch.device("cpu")) + pretrained_state_dict = torch.load(pretrained, map_location=torch.device("cpu")) need_init_state_dict = {} for name, m in pretrained_state_dict.items(): - if (name.split(".")[0] in self.pretrained_layers - or self.pretrained_layers[0] == "*"): + if ( + name.split(".")[0] in self.pretrained_layers or self.pretrained_layers[0] == "*" + ): need_init_state_dict[name] = m - missing, unexpected = self.load_state_dict(need_init_state_dict, - strict=False) + missing, unexpected = 
self.load_state_dict(need_init_state_dict, strict=False) elif pretrained: raise ValueError("{} is not exist!".format(pretrained)) diff --git a/lib/pixielib/models/lbs.py b/lib/pixielib/models/lbs.py index 2b5f9a648408f3a83670b9bce94f7a7a08de37ae..a2252a9a81c7e9ca3633a02cc08f3fafd5bd22cc 100755 --- a/lib/pixielib/models/lbs.py +++ b/lib/pixielib/models/lbs.py @@ -30,8 +30,7 @@ def rot_mat_to_euler(rot_mats): # Calculates rotation matrix to euler angles # Careful for extreme cases of eular angles like [0.0, pi, 0.0] - sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + - rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) + sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) return torch.atan2(-rot_mats[:, 2, 0], sy) @@ -86,15 +85,13 @@ def find_dynamic_lmk_idx_and_bcoords( # aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) rot_mats = torch.index_select(pose, 1, head_kin_chain) - rel_rot_mat = torch.eye(3, device=vertices.device, - dtype=dtype).unsqueeze_(dim=0) + rel_rot_mat = torch.eye(3, device=vertices.device, dtype=dtype).unsqueeze_(dim=0) for idx in range(len(head_kin_chain)): # rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) rel_rot_mat = torch.matmul(rot_mats[:, idx], rel_rot_mat) - y_rot_angle = torch.round( - torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, - max=39)).to(dtype=torch.long) + y_rot_angle = torch.round(torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, + max=39)).to(dtype=torch.long) # print(y_rot_angle[0]) neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) mask = y_rot_angle.lt(-39).to(dtype=torch.long) @@ -102,8 +99,7 @@ def find_dynamic_lmk_idx_and_bcoords( y_rot_angle = neg_mask * neg_vals + (1 - neg_mask) * y_rot_angle # print(y_rot_angle[0]) - dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0, - y_rot_angle) + dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0, y_rot_angle) dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 0, y_rot_angle) return dyn_lmk_faces_idx, dyn_lmk_b_coords @@ -135,11 +131,11 @@ def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords): batch_size, num_verts = vertices.shape[:2] device = vertices.device - lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( - batch_size, -1, 3) + lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(batch_size, -1, 3) - lmk_faces += (torch.arange(batch_size, dtype=torch.long, - device=device).view(-1, 1, 1) * num_verts) + lmk_faces += ( + torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts + ) lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3) @@ -211,13 +207,11 @@ def lbs( # N x J x 3 x 3 ident = torch.eye(3, dtype=dtype, device=device) if pose2rot: - rot_mats = batch_rodrigues(pose.view(-1, 3), - dtype=dtype).view([batch_size, -1, 3, 3]) + rot_mats = batch_rodrigues(pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3]) pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) # (N x P) x (P, V * 3) -> N x V x 3 - pose_offsets = torch.matmul(pose_feature, - posedirs).view(batch_size, -1, 3) + pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3) else: pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident rot_mats = pose.view(batch_size, -1, 3, 3) @@ -234,12 +228,9 @@ def lbs( W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) # (N x V x (J + 1)) x (N x (J + 1) x 16) num_joints = J_regressor.shape[0] - T = torch.matmul(W, 
A.view(batch_size, num_joints, - 16)).view(batch_size, -1, 4, 4) + T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4) - homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], - dtype=dtype, - device=device) + homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], dtype=dtype, device=device) v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) @@ -318,8 +309,7 @@ def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) - K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], - dim=1).view((batch_size, 3, 3)) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3)) ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) @@ -335,9 +325,7 @@ def transform_mat(R, t): - T: Bx4x4 Transformation matrix """ # No padding left or right, only add an extra row - return torch.cat([F.pad(R, [0, 0, 0, 1]), - F.pad(t, [0, 0, 0, 1], value=1)], - dim=2) + return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2) def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): @@ -370,15 +358,13 @@ def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): rel_joints[:, 1:] -= joints[:, parents[1:]] transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3), - rel_joints.reshape(-1, 3, 1)).reshape( - -1, joints.shape[1], 4, 4) + rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) transform_chain = [transforms_mat[:, 0]] for i in range(1, parents.shape[0]): # Subtract the joint location at the rest pose # No need for rotation, since it's identity when at rest - curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, - i]) + curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, i]) transform_chain.append(curr_res) transforms = torch.stack(transform_chain, dim=1) @@ -392,21 +378,22 @@ def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): joints_homogen = F.pad(joints, [0, 0, 0, 1]) rel_transforms = transforms - F.pad( - torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) + torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0] + ) return posed_joints, rel_transforms class JointsFromVerticesSelector(nn.Module): - def __init__(self, fname): """Selects extra joints from vertices""" super(JointsFromVerticesSelector, self).__init__() err_msg = ("Either pass a filename or triangle face ids, names and" " barycentrics") - assert fname is not None or (face_ids is not None and bcs is not None - and names is not None), err_msg + assert fname is not None or ( + face_ids is not None and bcs is not None and names is not None + ), err_msg if fname is not None: fname = os.path.expanduser(os.path.expandvars(fname)) with open(fname, "r") as f: @@ -422,13 +409,11 @@ class JointsFromVerticesSelector(nn.Module): assert len(bcs) == len( face_ids ), "The number of barycentric coordinates must be equal to the faces" - assert len(names) == len( - face_ids), "The number of names must be equal to the number of " + assert len(names) == len(face_ids), "The number of names must be equal to the number of " self.names = names self.register_buffer("bcs", torch.tensor(bcs, dtype=torch.float32)) - self.register_buffer("face_ids", - 
torch.tensor(face_ids, dtype=torch.long)) + self.register_buffer("face_ids", torch.tensor(face_ids, dtype=torch.long)) def extra_joint_names(self): """Returns the names of the extra joints""" @@ -439,8 +424,7 @@ class JointsFromVerticesSelector(nn.Module): return [] vertex_ids = faces[self.face_ids].reshape(-1) # Should be BxNx3x3 - triangles = torch.index_select(vertices, 1, vertex_ids).reshape( - -1, len(self.bcs), 3, 3) + triangles = torch.index_select(vertices, 1, vertex_ids).reshape(-1, len(self.bcs), 3, 3) return (triangles * self.bcs[None, :, :, None]).sum(dim=2) @@ -463,7 +447,6 @@ def to_np(array, dtype=np.float32): class Struct(object): - def __init__(self, **kwargs): for key, val in kwargs.items(): setattr(self, key, val) diff --git a/lib/pixielib/models/moderators.py b/lib/pixielib/models/moderators.py index 8a14c472530787be97045a4e620e28cae051df65..3ab139ac2ad3e0cbd99c8e40dbf6136a37e53cb5 100644 --- a/lib/pixielib/models/moderators.py +++ b/lib/pixielib/models/moderators.py @@ -12,11 +12,7 @@ import torch.nn.functional as F class TempSoftmaxFusion(nn.Module): - - def __init__(self, - channels=[2048 * 2, 1024, 1], - detach_inputs=False, - detach_feature=False): + def __init__(self, channels=[2048 * 2, 1024, 1], detach_inputs=False, detach_feature=False): super(TempSoftmaxFusion, self).__init__() self.detach_inputs = detach_inputs self.detach_feature = detach_feature @@ -63,11 +59,7 @@ class TempSoftmaxFusion(nn.Module): class GumbelSoftmaxFusion(nn.Module): - - def __init__(self, - channels=[2048 * 2, 1024, 1], - detach_inputs=False, - detach_feature=False): + def __init__(self, channels=[2048 * 2, 1024, 1], detach_inputs=False, detach_feature=False): super(GumbelSoftmaxFusion, self).__init__() self.detach_inputs = detach_inputs self.detach_feature = detach_feature diff --git a/lib/pixielib/models/resnet.py b/lib/pixielib/models/resnet.py index 72bbf174a9b0ff9a9d75010e70e8059326fb72e3..162bc655bff1bd3ca2058334de2e15660de8f5f5 100755 --- a/lib/pixielib/models/resnet.py +++ b/lib/pixielib/models/resnet.py @@ -22,16 +22,10 @@ from torchvision import models class ResNet(nn.Module): - def __init__(self, block, layers, num_classes=1000): self.inplanes = 64 super(ResNet, self).__init__() - self.conv1 = nn.Conv2d(3, - 64, - kernel_size=7, - stride=2, - padding=3, - bias=False) + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -98,12 +92,7 @@ class Bottleneck(nn.Module): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, - planes, - kernel_size=3, - stride=stride, - padding=1, - bias=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) @@ -136,12 +125,7 @@ class Bottleneck(nn.Module): def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" - return nn.Conv2d(in_planes, - out_planes, - kernel_size=3, - stride=stride, - padding=1, - bias=False) + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class BasicBlock(nn.Module): @@ -196,8 +180,7 @@ def load_ResNet50Model(): model = ResNet(Bottleneck, [3, 4, 6, 3]) copy_parameter_from_resnet( model, - 
torchvision.models.resnet50( - weights=models.ResNet50_Weights.DEFAULT).state_dict(), + torchvision.models.resnet50(weights=models.ResNet50_Weights.DEFAULT).state_dict(), ) return model @@ -206,8 +189,7 @@ def load_ResNet101Model(): model = ResNet(Bottleneck, [3, 4, 23, 3]) copy_parameter_from_resnet( model, - torchvision.models.resnet101( - weights=models.ResNet101_Weights.DEFAULT).state_dict(), + torchvision.models.resnet101(weights=models.ResNet101_Weights.DEFAULT).state_dict(), ) return model @@ -216,8 +198,7 @@ def load_ResNet152Model(): model = ResNet(Bottleneck, [3, 8, 36, 3]) copy_parameter_from_resnet( model, - torchvision.models.resnet152( - weights=models.ResNet152_Weights.DEFAULT).state_dict(), + torchvision.models.resnet152(weights=models.ResNet152_Weights.DEFAULT).state_dict(), ) return model @@ -229,7 +210,6 @@ def load_ResNet152Model(): class DoubleConv(nn.Module): """(convolution => [BN] => ReLU) * 2""" - def __init__(self, in_channels, out_channels): super().__init__() self.double_conv = nn.Sequential( @@ -247,11 +227,9 @@ class DoubleConv(nn.Module): class Down(nn.Module): """Downscaling with maxpool then double conv""" - def __init__(self, in_channels, out_channels): super().__init__() - self.maxpool_conv = nn.Sequential( - nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)) + self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)) def forward(self, x): return self.maxpool_conv(x) @@ -259,20 +237,16 @@ class Down(nn.Module): class Up(nn.Module): """Upscaling then double conv""" - def __init__(self, in_channels, out_channels, bilinear=True): super().__init__() # if bilinear, use the normal convolutions to reduce the number of channels if bilinear: - self.up = nn.Upsample(scale_factor=2, - mode="bilinear", - align_corners=True) + self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) else: - self.up = nn.ConvTranspose2d(in_channels // 2, - in_channels // 2, - kernel_size=2, - stride=2) + self.up = nn.ConvTranspose2d( + in_channels // 2, in_channels // 2, kernel_size=2, stride=2 + ) self.conv = DoubleConv(in_channels, out_channels) @@ -282,9 +256,7 @@ class Up(nn.Module): diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] - x1 = F.pad( - x1, - [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) # if you have padding issues, see # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd @@ -293,7 +265,6 @@ class Up(nn.Module): class OutConv(nn.Module): - def __init__(self, in_channels, out_channels): super(OutConv, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) @@ -303,7 +274,6 @@ class OutConv(nn.Module): class UNet(nn.Module): - def __init__(self, n_channels, n_classes, bilinear=True): super(UNet, self).__init__() self.n_channels = n_channels diff --git a/lib/pixielib/pixie.py b/lib/pixielib/pixie.py index 575ec41079d5867ea4e3cce8968eb6b5e0bb4e95..545bc46f92b73aff4037da7ff3c6ebeba2b4c361 100644 --- a/lib/pixielib/pixie.py +++ b/lib/pixielib/pixie.py @@ -33,7 +33,6 @@ from .utils.config import cfg class PIXIE(object): - def __init__(self, config=None, device="cuda:0"): if config is None: self.cfg = cfg @@ -45,10 +44,7 @@ class PIXIE(object): self.param_list_dict = {} for lst in self.cfg.params.keys(): param_list = 
cfg.params.get(lst) - self.param_list_dict[lst] = { - i: cfg.model.get("n_" + i) - for i in param_list - } + self.param_list_dict[lst] = {i: cfg.model.get("n_" + i) for i in param_list} # Build the models self._create_model() @@ -97,24 +93,19 @@ class PIXIE(object): self.Regressor = {} for key in self.cfg.network.regressor.keys(): n_output = sum(self.param_list_dict[f"{key}_list"].values()) - channels = ([2048] + self.cfg.network.regressor.get(key).channels + - [n_output]) + channels = ([2048] + self.cfg.network.regressor.get(key).channels + [n_output]) if self.cfg.network.regressor.get(key).type == "mlp": self.Regressor[key] = MLP(channels=channels).to(self.device) - self.model_dict[f"Regressor_{key}"] = self.Regressor[ - key].state_dict() + self.model_dict[f"Regressor_{key}"] = self.Regressor[key].state_dict() # Build the extractors # to extract separate head/left hand/right hand feature from body feature self.Extractor = {} for key in self.cfg.network.extractor.keys(): - channels = [ - 2048 - ] + self.cfg.network.extractor.get(key).channels + [2048] + channels = [2048] + self.cfg.network.extractor.get(key).channels + [2048] if self.cfg.network.extractor.get(key).type == "mlp": self.Extractor[key] = MLP(channels=channels).to(self.device) - self.model_dict[f"Extractor_{key}"] = self.Extractor[ - key].state_dict() + self.model_dict[f"Extractor_{key}"] = self.Extractor[key].state_dict() # Build the moderators self.Moderator = {} @@ -122,15 +113,13 @@ class PIXIE(object): share_part = key.split("_")[0] detach_inputs = self.cfg.network.moderator.get(key).detach_inputs detach_feature = self.cfg.network.moderator.get(key).detach_feature - channels = [2048 * 2 - ] + self.cfg.network.moderator.get(key).channels + [2] + channels = [2048 * 2] + self.cfg.network.moderator.get(key).channels + [2] self.Moderator[key] = TempSoftmaxFusion( detach_inputs=detach_inputs, detach_feature=detach_feature, channels=channels, ).to(self.device) - self.model_dict[f"Moderator_{key}"] = self.Moderator[ - key].state_dict() + self.model_dict[f"Moderator_{key}"] = self.Moderator[key].state_dict() # Build the SMPL-X body model, which we also use to represent faces and # hands, using the relevant parts only @@ -147,9 +136,7 @@ class PIXIE(object): print(f"pixie trained model path: {model_path} does not exist!") exit() # eval mode - for module in [ - self.Encoder, self.Regressor, self.Moderator, self.Extractor - ]: + for module in [self.Encoder, self.Regressor, self.Moderator, self.Extractor]: for net in module.values(): net.eval() @@ -185,14 +172,14 @@ class PIXIE(object): # crop cropper_key = "hand" if "hand" in part_key else part_key points_scale = image.shape[-2:] - cropped_image, tform = self.Cropper[cropper_key].crop( - image, points_for_crop, points_scale) + cropped_image, tform = self.Cropper[cropper_key].crop(image, points_for_crop, points_scale) # transform points(must be normalized to [-1.1]) accordingly cropped_points_dict = {} for points_key in points_dict.keys(): points = points_dict[points_key] cropped_points = self.Cropper[cropper_key].transform_points( - points, tform, points_scale, normalize=True) + points, tform, points_scale, normalize=True + ) cropped_points_dict[points_key] = cropped_points return cropped_image, cropped_points_dict @@ -244,8 +231,7 @@ class PIXIE(object): # then predict share parameters feature[key][f"{key}_share"] = feature[key][key] share_dict = self.decompose_code( - self.Regressor[f"{part}_share"]( - feature[key][f"{part}_share"]), + 
self.Regressor[f"{part}_share"](feature[key][f"{part}_share"]), self.param_list_dict[f"{part}_share_list"], ) # compose parameters @@ -257,13 +243,16 @@ class PIXIE(object): f_body = feature["body"]["body"] # extract part feature for part_name in ["head", "left_hand", "right_hand"]: - feature["body"][f"{part_name}_share"] = self.Extractor[ - f"{part_name}_share"](f_body) + feature["body"][f"{part_name}_share"] = self.Extractor[f"{part_name}_share"]( + f_body + ) # -- check if part crops are given, if not, crop parts by coarse body estimation - if ("head_image" not in data[key].keys() - or "left_hand_image" not in data[key].keys() - or "right_hand_image" not in data[key].keys()): + if ( + "head_image" not in data[key].keys() or + "left_hand_image" not in data[key].keys() or + "right_hand_image" not in data[key].keys() + ): # - run without fusion to get coarse estimation, for cropping parts # body only body_dict = self.decompose_code( @@ -272,29 +261,26 @@ class PIXIE(object): ) # head share head_share_dict = self.decompose_code( - self.Regressor["head" + "_share"]( - feature[key]["head" + "_share"]), + self.Regressor["head" + "_share"](feature[key]["head" + "_share"]), self.param_list_dict["head" + "_share_list"], ) # right hand share right_hand_share_dict = self.decompose_code( - self.Regressor["hand" + "_share"]( - feature[key]["right_hand" + "_share"]), + self.Regressor["hand" + "_share"](feature[key]["right_hand" + "_share"]), self.param_list_dict["hand" + "_share_list"], ) # left hand share left_hand_share_dict = self.decompose_code( - self.Regressor["hand" + "_share"]( - feature[key]["left_hand" + "_share"]), + self.Regressor["hand" + "_share"](feature[key]["left_hand" + "_share"]), self.param_list_dict["hand" + "_share_list"], ) # change the dict name from right to left - left_hand_share_dict[ - "left_hand_pose"] = left_hand_share_dict.pop( - "right_hand_pose") - left_hand_share_dict[ - "left_wrist_pose"] = left_hand_share_dict.pop( - "right_wrist_pose") + left_hand_share_dict["left_hand_pose"] = left_hand_share_dict.pop( + "right_hand_pose" + ) + left_hand_share_dict["left_wrist_pose"] = left_hand_share_dict.pop( + "right_wrist_pose" + ) param_dict[key] = { **body_dict, **head_share_dict, @@ -304,21 +290,18 @@ class PIXIE(object): if body_only: param_dict["moderator_weight"] = None return param_dict - prediction_body_only = self.decode(param_dict[key], - param_type="body") + prediction_body_only = self.decode(param_dict[key], param_type="body") # crop for part_name in ["head", "left_hand", "right_hand"]: part = part_name.split("_")[-1] points_dict = { - "smplx_kpt": - prediction_body_only["smplx_kpt"], - "trans_verts": - prediction_body_only["transformed_vertices"], + "smplx_kpt": prediction_body_only["smplx_kpt"], + "trans_verts": prediction_body_only["transformed_vertices"], } - image_hd = torchvision.transforms.Resize(1024)( - data["body"]["image"]) + image_hd = torchvision.transforms.Resize(1024)(data["body"]["image"]) cropped_image, cropped_joints_dict = self.part_from_body( - image_hd, part_name, points_dict) + image_hd, part_name, points_dict + ) data[key][part_name + "_image"] = cropped_image # -- encode features from part crops, then fuse feature using the weight from moderator @@ -338,16 +321,12 @@ class PIXIE(object): self.Regressor[f"{part}_share"](f_part), self.param_list_dict[f"{part}_share_list"], ) - param_dict["body_" + part_name] = { - **part_dict, - **part_share_dict - } + param_dict["body_" + part_name] = {**part_dict, **part_share_dict} # moderator to assign 
weight, then integrate features - f_body_out, f_part_out, f_weight = self.Moderator[ - f"{part}_share"](feature["body"][f"{part_name}_share"], - f_part, - work=True) + f_body_out, f_part_out, f_weight = self.Moderator[f"{part}_share"]( + feature["body"][f"{part_name}_share"], f_part, work=True + ) if copy_and_paste: # copy and paste strategy always trusts the results from part feature["body"][f"{part_name}_share"] = f_part @@ -355,8 +334,9 @@ class PIXIE(object): # for hand, if part weight > 0.7 (very confident, then fully trust part) part_w = f_weight[:, [1]] part_w[part_w > 0.7] = 1.0 - f_body_out = (feature["body"][f"{part_name}_share"] * - (1.0 - part_w) + f_part * part_w) + f_body_out = ( + feature["body"][f"{part_name}_share"] * (1.0 - part_w) + f_part * part_w + ) feature["body"][f"{part_name}_share"] = f_body_out else: feature["body"][f"{part_name}_share"] = f_body_out @@ -367,29 +347,24 @@ class PIXIE(object): # -- predict parameters from fused body feature # head share head_share_dict = self.decompose_code( - self.Regressor["head" + "_share"](feature[key]["head" + - "_share"]), + self.Regressor["head" + "_share"](feature[key]["head" + "_share"]), self.param_list_dict["head" + "_share_list"], ) # right hand share right_hand_share_dict = self.decompose_code( - self.Regressor["hand" + "_share"]( - feature[key]["right_hand" + "_share"]), + self.Regressor["hand" + "_share"](feature[key]["right_hand" + "_share"]), self.param_list_dict["hand" + "_share_list"], ) # left hand share left_hand_share_dict = self.decompose_code( - self.Regressor["hand" + "_share"]( - feature[key]["left_hand" + "_share"]), + self.Regressor["hand" + "_share"](feature[key]["left_hand" + "_share"]), self.param_list_dict["hand" + "_share_list"], ) # change the dict name from right to left - left_hand_share_dict[ - "left_hand_pose"] = left_hand_share_dict.pop( - "right_hand_pose") - left_hand_share_dict[ - "left_wrist_pose"] = left_hand_share_dict.pop( - "right_wrist_pose") + left_hand_share_dict["left_hand_pose"] = left_hand_share_dict.pop("right_hand_pose") + left_hand_share_dict["left_wrist_pose"] = left_hand_share_dict.pop( + "right_wrist_pose" + ) param_dict["body"] = { **body_dict, **head_share_dict, @@ -403,10 +378,10 @@ class PIXIE(object): if keep_local: # for local change that will not affect whole body and produce unnatral pose, trust part param_dict[key]["exp"] = param_dict["body_head"]["exp"] - param_dict[key]["right_hand_pose"] = param_dict[ - "body_right_hand"]["right_hand_pose"] - param_dict[key]["left_hand_pose"] = param_dict[ - "body_left_hand"]["right_hand_pose"] + param_dict[key]["right_hand_pose"] = param_dict["body_right_hand"][ + "right_hand_pose"] + param_dict[key]["left_hand_pose"] = param_dict["body_left_hand"][ + "right_hand_pose"] return param_dict @@ -426,75 +401,70 @@ class PIXIE(object): if "pose" in key and "jaw" not in key: param_dict[key] = converter.batch_cont2matrix(param_dict[key]) if param_type == "body" or param_type == "head": - param_dict["jaw_pose"] = converter.batch_euler2matrix( - param_dict["jaw_pose"])[:, None, :, :] + param_dict["jaw_pose"] = converter.batch_euler2matrix(param_dict["jaw_pose"] + )[:, None, :, :] # complement params if it's not in given param dict if param_type == "head": batch_size = param_dict["shape"].shape[0] param_dict["abs_head_pose"] = param_dict["head_pose"].clone() param_dict["global_pose"] = param_dict["head_pose"] - param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze( - 0).expand( - batch_size, -1, -1, - -1)[:, 
:self.param_list_dict["body_list"]["partbody_pose"]] + param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1 + )[:, :self.param_list_dict["body_list"]["partbody_pose"]] param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand( - batch_size, -1, -1, -1) - param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze( - 0).expand(batch_size, -1, -1, -1) - param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze( - 0).expand(batch_size, -1, -1, -1) - param_dict["right_wrist_pose"] = self.smplx.neck_pose.unsqueeze( - 0).expand(batch_size, -1, -1, -1) - param_dict[ - "right_hand_pose"] = self.smplx.right_hand_pose.unsqueeze( - 0).expand(batch_size, -1, -1, -1) + batch_size, -1, -1, -1 + ) + param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1 + ) + param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1 + ) + param_dict["right_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1 + ) + param_dict["right_hand_pose"] = self.smplx.right_hand_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1 + ) elif param_type == "hand": batch_size = param_dict["right_hand_pose"].shape[0] - param_dict["abs_right_wrist_pose"] = param_dict[ - "right_wrist_pose"].clone() + param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone() dtype = param_dict["right_hand_pose"].dtype device = param_dict["right_hand_pose"].device - x_180_pose = (torch.eye(3, dtype=dtype, - device=device).unsqueeze(0).repeat( - 1, 1, 1)) + x_180_pose = (torch.eye(3, dtype=dtype, device=device).unsqueeze(0).repeat(1, 1, 1)) x_180_pose[0, 2, 2] = -1.0 x_180_pose[0, 1, 1] = -1.0 - param_dict["global_pose"] = x_180_pose.unsqueeze(0).expand( - batch_size, -1, -1, -1) - param_dict["shape"] = self.smplx.shape_params.expand( - batch_size, -1) - param_dict["exp"] = self.smplx.expression_params.expand( - batch_size, -1) + param_dict["global_pose"] = x_180_pose.unsqueeze(0).expand(batch_size, -1, -1, -1) + param_dict["shape"] = self.smplx.shape_params.expand(batch_size, -1) + param_dict["exp"] = self.smplx.expression_params.expand(batch_size, -1) param_dict["head_pose"] = self.smplx.head_pose.unsqueeze(0).expand( - batch_size, -1, -1, -1) + batch_size, -1, -1, -1 + ) param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand( - batch_size, -1, -1, -1) - param_dict["jaw_pose"] = self.smplx.jaw_pose.unsqueeze(0).expand( - batch_size, -1, -1, -1) - param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze( - 0).expand( - batch_size, -1, -1, - -1)[:, :self.param_list_dict["body_list"]["partbody_pose"]] - param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze( - 0).expand(batch_size, -1, -1, -1) - param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze( - 0).expand(batch_size, -1, -1, -1) + batch_size, -1, -1, -1 + ) + param_dict["jaw_pose"] = self.smplx.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1) + param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1 + )[:, :self.param_list_dict["body_list"]["partbody_pose"]] + param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1 + ) + param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(0).expand( + batch_size, -1, -1, -1 + ) elif param_type == "body": # the predcition from the head and hand share regressor is always absolute pose batch_size = param_dict["shape"].shape[0] 
param_dict["abs_head_pose"] = param_dict["head_pose"].clone() - param_dict["abs_right_wrist_pose"] = param_dict[ - "right_wrist_pose"].clone() - param_dict["abs_left_wrist_pose"] = param_dict[ - "left_wrist_pose"].clone() + param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone() + param_dict["abs_left_wrist_pose"] = param_dict["left_wrist_pose"].clone() # the body-hand share regressor is working for right hand # so we assume body network get the flipped feature for the left hand. then get the parameters # then we need to flip it back to left, which matches the input left hand - param_dict["left_wrist_pose"] = util.flip_pose( - param_dict["left_wrist_pose"]) - param_dict["left_hand_pose"] = util.flip_pose( - param_dict["left_hand_pose"]) + param_dict["left_wrist_pose"] = util.flip_pose(param_dict["left_wrist_pose"]) + param_dict["left_hand_pose"] = util.flip_pose(param_dict["left_hand_pose"]) else: exit() @@ -508,8 +478,7 @@ class PIXIE(object): Returns: predictions: smplx predictions """ - if "jaw_pose" in param_dict.keys() and len( - param_dict["jaw_pose"].shape) == 2: + if "jaw_pose" in param_dict.keys() and len(param_dict["jaw_pose"].shape) == 2: self.convert_pose(param_dict, param_type) elif param_dict["right_wrist_pose"].shape[-1] == 6: self.convert_pose(param_dict, param_type) @@ -532,9 +501,8 @@ class PIXIE(object): # change absolute head&hand pose to relative pose according to rest body pose if param_type == "head" or param_type == "body": param_dict["body_pose"] = self.smplx.pose_abs2rel( - param_dict["global_pose"], - param_dict["body_pose"], - abs_joint="head") + param_dict["global_pose"], param_dict["body_pose"], abs_joint="head" + ) if param_type == "hand" or param_type == "body": param_dict["body_pose"] = self.smplx.pose_abs2rel( param_dict["global_pose"], @@ -550,7 +518,7 @@ class PIXIE(object): if self.cfg.model.check_pose: # check if pose is natural (relative rotation), if not, set relative to 0 (especially for head pose) # xyz: pitch(positive for looking down), yaw(positive for looking left), roll(rolling chin to left) - for pose_ind in [14]: # head [15-1, 20-1, 21-1]: + for pose_ind in [14]: # head [15-1, 20-1, 21-1]: curr_pose = param_dict["body_pose"][:, pose_ind] euler_pose = converter._compute_euler_from_matrix(curr_pose) for i, max_angle in enumerate([20, 70, 10]): @@ -560,9 +528,7 @@ class PIXIE(object): min=-max_angle * np.pi / 180, max=max_angle * np.pi / 180, )] = 0.0 - param_dict[ - "body_pose"][:, pose_ind] = converter.batch_euler2matrix( - euler_pose) + param_dict["body_pose"][:, pose_ind] = converter.batch_euler2matrix(euler_pose) # SMPLX verts, landmarks, joints = self.smplx( @@ -594,8 +560,8 @@ class PIXIE(object): # change the order of face keypoints, to be the same as "standard" 68 keypoints prediction["face_kpt"] = torch.cat( - [prediction["face_kpt"][:, -17:], prediction["face_kpt"][:, :-17]], - dim=1) + [prediction["face_kpt"][:, -17:], prediction["face_kpt"][:, :-17]], dim=1 + ) prediction.update(param_dict) diff --git a/lib/pixielib/utils/array_cropper.py b/lib/pixielib/utils/array_cropper.py index 661146ec42d58c207d00182f162a88f1c594e3df..fbee84b6a6f0f3dcad7fcd6b33bf03faf56be625 100644 --- a/lib/pixielib/utils/array_cropper.py +++ b/lib/pixielib/utils/array_cropper.py @@ -23,15 +23,14 @@ def points2bbox(points, points_scale=None): bottom = np.max(points[:, 1]) size = max(right - left, bottom - top) # + old_size*0.1]) - center = np.array( - [right - (right - left) / 2.0, bottom - (bottom - top) / 2.0]) + center = np.array([right 
- (right - left) / 2.0, bottom - (bottom - top) / 2.0]) return center, size # translate center def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.0): trans_scale = (np.random.rand(2) * 2 - 1) * trans_scale - center = center + trans_scale * bbox_size # 0.5 + center = center + trans_scale * bbox_size # 0.5 scale = np.random.rand() * (scale[1] - scale[0]) + scale[0] size = int(bbox_size * scale) return center, size @@ -48,27 +47,25 @@ def crop_array(image, center, bboxsize, crop_size): tform: 3x3 affine matrix """ # points: top-left, top-right, bottom-right - src_pts = np.array([ - [center[0] - bboxsize / 2, center[1] - bboxsize / 2], - [center[0] + bboxsize / 2, center[1] - bboxsize / 2], - [center[0] + bboxsize / 2, center[1] + bboxsize / 2], - ]) - DST_PTS = np.array([[0, 0], [crop_size - 1, 0], - [crop_size - 1, crop_size - 1]]) + src_pts = np.array( + [ + [center[0] - bboxsize / 2, center[1] - bboxsize / 2], + [center[0] + bboxsize / 2, center[1] - bboxsize / 2], + [center[0] + bboxsize / 2, center[1] + bboxsize / 2], + ] + ) + DST_PTS = np.array([[0, 0], [crop_size - 1, 0], [crop_size - 1, crop_size - 1]]) # estimate transformation between points tform = estimate_transform("similarity", src_pts, DST_PTS) # warp images - cropped_image = warp(image, - tform.inverse, - output_shape=(crop_size, crop_size)) + cropped_image = warp(image, tform.inverse, output_shape=(crop_size, crop_size)) return cropped_image, tform.params.T class Cropper(object): - def __init__(self, crop_size, scale=[1, 1], trans_scale=0.0): self.crop_size = crop_size self.scale = scale @@ -78,11 +75,9 @@ class Cropper(object): # points to bbox center, bbox_size = points2bbox(points, points_scale) # argument bbox. - center, bbox_size = augment_bbox(center, - bbox_size, - scale=self.scale, - trans_scale=self.trans_scale) + center, bbox_size = augment_bbox( + center, bbox_size, scale=self.scale, trans_scale=self.trans_scale + ) # crop - cropped_image, tform = crop_array(image, center, bbox_size, - self.crop_size) + cropped_image, tform = crop_array(image, center, bbox_size, self.crop_size) return cropped_image, tform diff --git a/lib/pixielib/utils/config.py b/lib/pixielib/utils/config.py index 04d8ed809489dac385a1764aae0de0565dfe9d6d..115a38e9c52b7cf025defa4a3d37d9490fc71833 100644 --- a/lib/pixielib/utils/config.py +++ b/lib/pixielib/utils/config.py @@ -8,59 +8,59 @@ import os cfg = CN() -abs_pixie_dir = os.path.abspath( - os.path.join(os.path.dirname(__file__), "..", "..", "..")) +abs_pixie_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) cfg.pixie_dir = abs_pixie_dir cfg.device = "cuda" cfg.device_id = "0" -cfg.pretrained_modelpath = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", - "pixie_model.tar") +cfg.pretrained_modelpath = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", "pixie_model.tar") # smplx parameter settings cfg.params = CN() -cfg.params.body_list = [ - "body_cam", "global_pose", "partbody_pose", "neck_pose" -] +cfg.params.body_list = ["body_cam", "global_pose", "partbody_pose", "neck_pose"] cfg.params.head_list = ["head_cam", "tex", "light"] cfg.params.head_share_list = ["shape", "exp", "head_pose", "jaw_pose"] cfg.params.hand_list = ["hand_cam"] cfg.params.hand_share_list = [ "right_wrist_pose", "right_hand_pose", -] # only for right hand +] # only for right hand # ---------------------------------------------------------------------------- # # Options for Body model # ---------------------------------------------------------------------------- # 
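The array_cropper hunks above wrap a standard three-point similarity fit. The sketch below shows the same bbox-to-crop pipeline in isolation using skimage's estimate_transform and warp; the inputs and names are illustrative toy values, not part of this patch.

# Illustrative sketch of the crop_array() pipeline: bbox corners -> similarity
# transform -> warped square crop.
import numpy as np
from skimage.transform import estimate_transform, warp

def crop_square(image, center, bbox_size, crop_size=224):
    half = bbox_size / 2.0
    # three corners of the source box and of the destination crop
    src_pts = np.array([
        [center[0] - half, center[1] - half],   # top-left
        [center[0] + half, center[1] - half],   # top-right
        [center[0] + half, center[1] + half],   # bottom-right
    ])
    dst_pts = np.array([[0, 0], [crop_size - 1, 0], [crop_size - 1, crop_size - 1]])
    tform = estimate_transform("similarity", src_pts, dst_pts)
    # warp expects the inverse map (output pixel -> input pixel)
    cropped = warp(image, tform.inverse, output_shape=(crop_size, crop_size))
    return cropped, tform.params            # crop plus the 3x3 affine matrix

image = np.random.rand(512, 512, 3)
cropped, M = crop_square(image, center=np.array([256.0, 200.0]), bbox_size=180.0)
print(cropped.shape, M.shape)               # (224, 224, 3) (3, 3)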
cfg.model = CN() -cfg.model.topology_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", - "SMPL_X_template_FLAME_uv.obj") -cfg.model.topology_smplxtex_path = os.path.join(cfg.pixie_dir, - "data/HPS/pixie_data", - "smplx_tex.obj") -cfg.model.topology_smplx_hand_path = os.path.join(cfg.pixie_dir, - "data/HPS/pixie_data", - "smplx_hand.obj") -cfg.model.smplx_model_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", - "SMPLX_NEUTRAL_2020.npz") -cfg.model.face_mask_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", - "uv_face_mask.png") -cfg.model.face_eye_mask_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", - "uv_face_eye_mask.png") -cfg.model.tex_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", - "FLAME_albedo_from_BFM.npz") -cfg.model.extra_joint_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", - "smplx_extra_joints.yaml") -cfg.model.j14_regressor_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", - "SMPLX_to_J14.pkl") -cfg.model.flame2smplx_cached_path = os.path.join(cfg.pixie_dir, - "data/HPS/pixie_data", - "flame2smplx_tex_1024.npy") -cfg.model.smplx_tex_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", - "smplx_tex.png") -cfg.model.mano_ids_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", - "MANO_SMPLX_vertex_ids.pkl") -cfg.model.flame_ids_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", - "SMPL-X__FLAME_vertex_ids.npy") +cfg.model.topology_path = os.path.join( + cfg.pixie_dir, "data/HPS/pixie_data", "SMPL_X_template_FLAME_uv.obj" +) +cfg.model.topology_smplxtex_path = os.path.join( + cfg.pixie_dir, "data/HPS/pixie_data", "smplx_tex.obj" +) +cfg.model.topology_smplx_hand_path = os.path.join( + cfg.pixie_dir, "data/HPS/pixie_data", "smplx_hand.obj" +) +cfg.model.smplx_model_path = os.path.join( + cfg.pixie_dir, "data/HPS/pixie_data", "SMPLX_NEUTRAL_2020.npz" +) +cfg.model.face_mask_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", "uv_face_mask.png") +cfg.model.face_eye_mask_path = os.path.join( + cfg.pixie_dir, "data/HPS/pixie_data", "uv_face_eye_mask.png" +) +cfg.model.tex_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", "FLAME_albedo_from_BFM.npz") +cfg.model.extra_joint_path = os.path.join( + cfg.pixie_dir, "data/HPS/pixie_data", "smplx_extra_joints.yaml" +) +cfg.model.j14_regressor_path = os.path.join( + cfg.pixie_dir, "data/HPS/pixie_data", "SMPLX_to_J14.pkl" +) +cfg.model.flame2smplx_cached_path = os.path.join( + cfg.pixie_dir, "data/HPS/pixie_data", "flame2smplx_tex_1024.npy" +) +cfg.model.smplx_tex_path = os.path.join(cfg.pixie_dir, "data/HPS/pixie_data", "smplx_tex.png") +cfg.model.mano_ids_path = os.path.join( + cfg.pixie_dir, "data/HPS/pixie_data", "MANO_SMPLX_vertex_ids.pkl" +) +cfg.model.flame_ids_path = os.path.join( + cfg.pixie_dir, "data/HPS/pixie_data", "SMPL-X__FLAME_vertex_ids.npy" +) cfg.model.uv_size = 256 cfg.model.n_shape = 200 cfg.model.n_tex = 50 @@ -68,16 +68,16 @@ cfg.model.n_exp = 50 cfg.model.n_body_cam = 3 cfg.model.n_head_cam = 3 cfg.model.n_hand_cam = 3 -cfg.model.tex_type = "BFM" # BFM, FLAME, albedoMM -cfg.model.uvtex_type = "SMPLX" # FLAME or SMPLX -cfg.model.use_tex = False # whether to use flame texture model +cfg.model.tex_type = "BFM" # BFM, FLAME, albedoMM +cfg.model.uvtex_type = "SMPLX" # FLAME or SMPLX +cfg.model.use_tex = False # whether to use flame texture model cfg.model.flame_tex_path = "" # pose cfg.model.n_global_pose = 3 * 2 cfg.model.n_head_pose = 3 * 2 cfg.model.n_neck_pose = 3 * 2 -cfg.model.n_jaw_pose = 3 # euler angle 
+cfg.model.n_jaw_pose = 3 # euler angle cfg.model.n_body_pose = 21 * 3 * 2 cfg.model.n_partbody_pose = (21 - 4) * 3 * 2 cfg.model.n_left_hand_pose = 15 * 3 * 2 diff --git a/lib/pixielib/utils/renderer.py b/lib/pixielib/utils/renderer.py index d45e9ae0adefdc5dceab89e9bb6e71cca58a3630..eb2dc795e01b3e5c78a4ce848777d6cbc5558401 100755 --- a/lib/pixielib/utils/renderer.py +++ b/lib/pixielib/utils/renderer.py @@ -36,7 +36,7 @@ def set_rasterizer(type="pytorch3d"): f"{curr_dir}/rasterizer/standard_rasterize_cuda_kernel.cu", ], extra_cuda_cflags=["-std=c++14", "-ccbin=$$(which gcc-7)"], - ) # cuda10.2 is not compatible with gcc9. Specify gcc 7 + ) # cuda10.2 is not compatible with gcc9. Specify gcc 7 from standard_rasterize_cuda import standard_rasterize # If JIT does not work, try manually installation first @@ -51,7 +51,6 @@ class StandardRasterizer(nn.Module): can render non-squared image not differentiable """ - def __init__(self, height, width=None): """ use fixed raster_settings for rendering faces @@ -80,15 +79,15 @@ class StandardRasterizer(nn.Module): vertices[..., 2] = vertices[..., 2] * w / 2 f_vs = util.face_vertices(vertices, faces) - standard_rasterize(f_vs, depth_buffer, triangle_buffer, baryw_buffer, - h, w) + standard_rasterize(f_vs, depth_buffer, triangle_buffer, baryw_buffer, h, w) pix_to_face = triangle_buffer[:, :, :, None].long() bary_coords = baryw_buffer[:, :, :, None, :] vismask = (pix_to_face > -1).float() D = attributes.shape[-1] attributes = attributes.clone() - attributes = attributes.view(attributes.shape[0] * attributes.shape[1], - 3, attributes.shape[-1]) + attributes = attributes.view( + attributes.shape[0] * attributes.shape[1], 3, attributes.shape[-1] + ) N, H, W, K, _ = bary_coords.shape mask = pix_to_face == -1 pix_to_face = pix_to_face.clone() @@ -96,10 +95,9 @@ class StandardRasterizer(nn.Module): idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D) pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2) - pixel_vals[mask] = 0 # Replace masked values in output. + pixel_vals[mask] = 0 # Replace masked values in output. 
pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2) - pixel_vals = torch.cat( - [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1) + pixel_vals = torch.cat([pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1) return pixel_vals @@ -110,7 +108,6 @@ class Pytorch3dRasterizer(nn.Module): x,y,z are in image space, normalized can only render squared image now """ - def __init__(self, image_size=224): """ use fixed raster_settings for rendering faces @@ -130,8 +127,7 @@ class Pytorch3dRasterizer(nn.Module): def forward(self, vertices, faces, attributes=None, h=None, w=None): fixed_vertices = vertices.clone() fixed_vertices[..., :2] = -fixed_vertices[..., :2] - meshes_screen = Meshes(verts=fixed_vertices.float(), - faces=faces.long()) + meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long()) raster_settings = self.raster_settings pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( meshes_screen, @@ -145,8 +141,9 @@ class Pytorch3dRasterizer(nn.Module): vismask = (pix_to_face > -1).float() D = attributes.shape[-1] attributes = attributes.clone() - attributes = attributes.view(attributes.shape[0] * attributes.shape[1], - 3, attributes.shape[-1]) + attributes = attributes.view( + attributes.shape[0] * attributes.shape[1], 3, attributes.shape[-1] + ) N, H, W, K, _ = bary_coords.shape mask = pix_to_face == -1 pix_to_face = pix_to_face.clone() @@ -154,20 +151,14 @@ class Pytorch3dRasterizer(nn.Module): idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D) pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2) - pixel_vals[mask] = 0 # Replace masked values in output. + pixel_vals[mask] = 0 # Replace masked values in output. pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2) - pixel_vals = torch.cat( - [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1) + pixel_vals = torch.cat([pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1) return pixel_vals class SRenderY(nn.Module): - - def __init__(self, - image_size, - obj_filename, - uv_size=256, - rasterizer_type="standard"): + def __init__(self, image_size, obj_filename, uv_size=256, rasterizer_type="standard"): super(SRenderY, self).__init__() self.image_size = image_size self.uv_size = uv_size @@ -176,8 +167,8 @@ class SRenderY(nn.Module): self.rasterizer = Pytorch3dRasterizer(image_size) self.uv_rasterizer = Pytorch3dRasterizer(uv_size) verts, faces, aux = load_obj(obj_filename) - uvcoords = aux.verts_uvs[None, ...] # (N, V, 2) - uvfaces = faces.textures_idx[None, ...] # (N, F, 3) + uvcoords = aux.verts_uvs[None, ...] # (N, V, 2) + uvfaces = faces.textures_idx[None, ...] # (N, F, 3) faces = faces.verts_idx[None, ...] 
elif rasterizer_type == "standard": self.rasterizer = StandardRasterizer(image_size) @@ -192,15 +183,12 @@ class SRenderY(nn.Module): # faces dense_triangles = util.generate_triangles(uv_size, uv_size) - self.register_buffer( - "dense_faces", - torch.from_numpy(dense_triangles).long()[None, :, :]) + self.register_buffer("dense_faces", torch.from_numpy(dense_triangles).long()[None, :, :]) self.register_buffer("faces", faces) self.register_buffer("raw_uvcoords", uvcoords) # uv coords - uvcoords = torch.cat([uvcoords, uvcoords[:, :, 0:1] * 0.0 + 1.0], - -1) # [bz, ntv, 3] + uvcoords = torch.cat([uvcoords, uvcoords[:, :, 0:1] * 0.0 + 1.0], -1) # [bz, ntv, 3] uvcoords = uvcoords * 2 - 1 uvcoords[..., 1] = -uvcoords[..., 1] face_uvcoords = util.face_vertices(uvcoords, uvfaces) @@ -209,26 +197,29 @@ class SRenderY(nn.Module): self.register_buffer("face_uvcoords", face_uvcoords) # shape colors, for rendering shape overlay - colors = (torch.tensor([180, 180, 180])[None, None, :].repeat( - 1, - faces.max() + 1, 1).float() / 255.0) + colors = ( + torch.tensor([180, 180, 180])[None, None, :].repeat(1, + faces.max() + 1, 1).float() / 255.0 + ) face_colors = util.face_vertices(colors, faces) self.register_buffer("vertex_colors", colors) self.register_buffer("face_colors", face_colors) # SH factors for lighting pi = np.pi - constant_factor = torch.tensor([ - 1 / np.sqrt(4 * pi), - ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), - ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), - ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), - (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), - (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), - (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), - (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))), - (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi))), - ]).float() + constant_factor = torch.tensor( + [ + 1 / np.sqrt(4 * pi), + ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), + ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), + ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))), + (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), + (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), + (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))), + (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))), + (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi))), + ] + ).float() self.register_buffer("constant_factor", constant_factor) def forward( @@ -256,23 +247,24 @@ class SRenderY(nn.Module): batch_size = vertices.shape[0] # normalize z to 10-90 for raterization (in pytorch3d, near far: 0-100) transformed_vertices = transformed_vertices.clone() - transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] - - transformed_vertices[:, :, 2].min()) - transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] / - transformed_vertices[:, :, 2].max()) + transformed_vertices[:, :, 2] = ( + transformed_vertices[:, :, 2] - transformed_vertices[:, :, 2].min() + ) + transformed_vertices[:, :, 2] = ( + transformed_vertices[:, :, 2] / transformed_vertices[:, :, 2].max() + ) transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] * 80 + 10 # attributes - face_vertices = util.face_vertices( - vertices, self.faces.expand(batch_size, -1, -1)) - normals = util.vertex_normals(vertices, - self.faces.expand(batch_size, -1, -1)) - face_normals = util.face_vertices( - normals, self.faces.expand(batch_size, -1, -1)) + face_vertices = util.face_vertices(vertices, self.faces.expand(batch_size, -1, -1)) + normals = util.vertex_normals(vertices, self.faces.expand(batch_size, -1, -1)) + face_normals = util.face_vertices(normals, self.faces.expand(batch_size, -1, -1)) transformed_normals = util.vertex_normals( - 
transformed_vertices, self.faces.expand(batch_size, -1, -1)) + transformed_vertices, self.faces.expand(batch_size, -1, -1) + ) transformed_face_normals = util.face_vertices( - transformed_normals, self.faces.expand(batch_size, -1, -1)) + transformed_normals, self.faces.expand(batch_size, -1, -1) + ) attributes = torch.cat( [ self.face_uvcoords.expand(batch_size, -1, -1, -1), @@ -314,38 +306,32 @@ class SRenderY(nn.Module): if light_type == "point": vertice_images = rendering[:, 6:9, :, :].detach() shading = self.add_pointlight( - vertice_images.permute(0, 2, 3, - 1).reshape([batch_size, -1, 3]), - normal_images.permute(0, 2, 3, - 1).reshape([batch_size, -1, 3]), + vertice_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), + normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), lights, ) - shading_images = shading.reshape([ - batch_size, albedo_images.shape[2], - albedo_images.shape[3], 3 - ]).permute(0, 3, 1, 2) + shading_images = shading.reshape( + [batch_size, albedo_images.shape[2], albedo_images.shape[3], 3] + ).permute(0, 3, 1, 2) else: shading = self.add_directionlight( - normal_images.permute(0, 2, 3, - 1).reshape([batch_size, -1, 3]), + normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), lights, ) - shading_images = shading.reshape([ - batch_size, albedo_images.shape[2], - albedo_images.shape[3], 3 - ]).permute(0, 3, 1, 2) + shading_images = shading.reshape( + [batch_size, albedo_images.shape[2], albedo_images.shape[3], 3] + ).permute(0, 3, 1, 2) images = albedo_images * shading_images else: images = albedo_images shading_images = images.detach() * 0.0 if background is None: - images = images * alpha_images + torch.ones_like(images).to( - vertices.device) * (1 - alpha_images) + images = images * alpha_images + torch.ones_like(images).to(vertices.device + ) * (1 - alpha_images) else: # background = F.interpolate(background, [self.image_size, self.image_size]) - images = images * alpha_images + background.contiguous() * ( - 1 - alpha_images) + images = images * alpha_images + background.contiguous() * (1 - alpha_images) outputs = { "images": images, @@ -379,11 +365,10 @@ class SRenderY(nn.Module): 3 * (N[:, 2]**2) - 1, ], 1, - ) # [bz, 9, h, w] + ) # [bz, 9, h, w] sh = sh * self.constant_factor[None, :, None, None] # [bz, 9, 3, h, w] - shading = torch.sum( - sh_coeff[:, :, :, None, None] * sh[:, :, None, :, :], 1) + shading = torch.sum(sh_coeff[:, :, :, None, None] * sh[:, :, None, :, :], 1) return shading def add_pointlight(self, vertices, normals, lights): @@ -395,14 +380,12 @@ class SRenderY(nn.Module): """ light_positions = lights[:, :, :3] light_intensities = lights[:, :, 3:] - directions_to_lights = F.normalize(light_positions[:, :, None, :] - - vertices[:, None, :, :], - dim=3) + directions_to_lights = F.normalize( + light_positions[:, :, None, :] - vertices[:, None, :, :], dim=3 + ) # normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.) 
- normals_dot_lights = (normals[:, None, :, :] * - directions_to_lights).sum(dim=3) - shading = normals_dot_lights[:, :, :, - None] * light_intensities[:, :, None, :] + normals_dot_lights = (normals[:, None, :, :] * directions_to_lights).sum(dim=3) + shading = normals_dot_lights[:, :, :, None] * light_intensities[:, :, None, :] return shading.mean(1) def add_directionlight(self, normals, lights): @@ -415,16 +398,14 @@ class SRenderY(nn.Module): light_direction = lights[:, :, :3] light_intensities = lights[:, :, 3:] directions_to_lights = F.normalize( - light_direction[:, :, None, :].expand(-1, -1, normals.shape[1], - -1), - dim=3) + light_direction[:, :, None, :].expand(-1, -1, normals.shape[1], -1), dim=3 + ) # normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.) # normals_dot_lights = (normals[:,None,:,:]*directions_to_lights).sum(dim=3) normals_dot_lights = torch.clamp( - (normals[:, None, :, :] * directions_to_lights).sum(dim=3), 0.0, - 1.0) - shading = normals_dot_lights[:, :, :, - None] * light_intensities[:, :, None, :] + (normals[:, None, :, :] * directions_to_lights).sum(dim=3), 0.0, 1.0 + ) + shading = normals_dot_lights[:, :, :, None] * light_intensities[:, :, None, :] return shading.mean(1) def render_shape( @@ -445,36 +426,38 @@ class SRenderY(nn.Module): """ batch_size = vertices.shape[0] if lights is None: - light_positions = (torch.tensor([ - [-5, 5, -5], - [5, 5, -5], - [-5, -5, -5], - [5, -5, -5], - [0, 0, -5], - ])[None, :, :].expand(batch_size, -1, -1).float()) + light_positions = ( + torch.tensor([ + [-5, 5, -5], + [5, 5, -5], + [-5, -5, -5], + [5, -5, -5], + [0, 0, -5], + ])[None, :, :].expand(batch_size, -1, -1).float() + ) light_intensities = torch.ones_like(light_positions).float() * 1.7 - lights = torch.cat((light_positions, light_intensities), - 2).to(vertices.device) + lights = torch.cat((light_positions, light_intensities), 2).to(vertices.device) # normalize z to 10-90 for raterization (in pytorch3d, near far: 0-100) transformed_vertices = transformed_vertices.clone() - transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] - - transformed_vertices[:, :, 2].min()) - transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] / - transformed_vertices[:, :, 2].max()) + transformed_vertices[:, :, 2] = ( + transformed_vertices[:, :, 2] - transformed_vertices[:, :, 2].min() + ) + transformed_vertices[:, :, 2] = ( + transformed_vertices[:, :, 2] / transformed_vertices[:, :, 2].max() + ) transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] * 80 + 10 # Attributes - face_vertices = util.face_vertices( - vertices, self.faces.expand(batch_size, -1, -1)) - normals = util.vertex_normals(vertices, - self.faces.expand(batch_size, -1, -1)) - face_normals = util.face_vertices( - normals, self.faces.expand(batch_size, -1, -1)) + face_vertices = util.face_vertices(vertices, self.faces.expand(batch_size, -1, -1)) + normals = util.vertex_normals(vertices, self.faces.expand(batch_size, -1, -1)) + face_normals = util.face_vertices(normals, self.faces.expand(batch_size, -1, -1)) transformed_normals = util.vertex_normals( - transformed_vertices, self.faces.expand(batch_size, -1, -1)) + transformed_vertices, self.faces.expand(batch_size, -1, -1) + ) transformed_face_normals = util.face_vertices( - transformed_normals, self.faces.expand(batch_size, -1, -1)) + transformed_normals, self.faces.expand(batch_size, -1, -1) + ) if colors is None: colors = self.face_colors.expand(batch_size, -1, -1, -1) attributes = torch.cat( @@ 
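The add_directionlight() hunks above reduce to a clamped Lambertian dot product averaged over the lights. A minimal standalone version, with toy tensors and illustrative names, looks like this.

# Illustrative sketch of directional-light Lambertian shading as in
# add_directionlight() above.
import torch
import torch.nn.functional as F

def directional_shading(normals, lights):
    """normals: (B, N, 3) unit vertex normals, lights: (B, L, 6) = [direction | rgb intensity]."""
    light_dir = F.normalize(lights[:, :, :3], dim=-1)                  # (B, L, 3)
    intensity = lights[:, :, 3:]                                       # (B, L, 3)
    # cosine term per (light, vertex), clamped so back-facing light adds nothing
    n_dot_l = torch.clamp((normals[:, None, :, :] * light_dir[:, :, None, :]).sum(-1), 0.0, 1.0)
    shading = n_dot_l[..., None] * intensity[:, :, None, :]            # (B, L, N, 3)
    return shading.mean(1)                                             # average over lights

normals = F.normalize(torch.rand(1, 100, 3) - 0.5, dim=-1)
lights = torch.cat([torch.tensor([[[0.0, 0.0, -1.0]]]), torch.ones(1, 1, 3)], dim=-1)
print(directional_shading(normals, lights).shape)                     # torch.Size([1, 100, 3])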
-513,22 +496,22 @@ class SRenderY(nn.Module): if uv_detail_normals is not None: uvcoords_images = rendering[:, 12:15, :, :] grid = (uvcoords_images).permute(0, 2, 3, 1)[:, :, :, :2] - detail_normal_images = F.grid_sample(uv_detail_normals, - grid, - align_corners=False) + detail_normal_images = F.grid_sample(uv_detail_normals, grid, align_corners=False) normal_images = detail_normal_images shading = self.add_directionlight( - normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), - lights) - shading_images = (shading.reshape( - [batch_size, albedo_images.shape[2], albedo_images.shape[3], - 3]).permute(0, 3, 1, 2).contiguous()) + normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]), lights + ) + shading_images = ( + shading.reshape([batch_size, albedo_images.shape[2], albedo_images.shape[3], + 3]).permute(0, 3, 1, 2).contiguous() + ) shaded_images = albedo_images * shading_images if background is None: - shape_images = shaded_images * alpha_images + torch.ones_like( - shaded_images).to(vertices.device) * (1 - alpha_images) + shape_images = shaded_images * alpha_images + torch.ones_like(shaded_images).to( + vertices.device + ) * (1 - alpha_images) else: # background = F.interpolate(background, [self.image_size, self.image_size]) shape_images = shaded_images * alpha_images + background.contiguous( @@ -548,18 +531,18 @@ class SRenderY(nn.Module): transformed_vertices = transformed_vertices.clone() batch_size = transformed_vertices.shape[0] - transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] - - transformed_vertices[:, :, 2].min()) + transformed_vertices[:, :, 2] = ( + transformed_vertices[:, :, 2] - transformed_vertices[:, :, 2].min() + ) z = -transformed_vertices[:, :, 2:].repeat(1, 1, 3) z = z - z.min() z = z / z.max() # Attributes - attributes = util.face_vertices(z, - self.faces.expand(batch_size, -1, -1)) + attributes = util.face_vertices(z, self.faces.expand(batch_size, -1, -1)) # rasterize - rendering = self.rasterizer(transformed_vertices, - self.faces.expand(batch_size, -1, -1), - attributes) + rendering = self.rasterizer( + transformed_vertices, self.faces.expand(batch_size, -1, -1), attributes + ) #### alpha_images = rendering[:, -1, :, :][:, None, :, :].detach() @@ -574,14 +557,15 @@ class SRenderY(nn.Module): transformed_vertices = transformed_vertices.clone() batch_size = colors.shape[0] # normalize z to 10-90 for raterization (in pytorch3d, near far: 0-100) - transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] - - transformed_vertices[:, :, 2].min()) - transformed_vertices[:, :, 2] = (transformed_vertices[:, :, 2] / - transformed_vertices[:, :, 2].max()) + transformed_vertices[:, :, 2] = ( + transformed_vertices[:, :, 2] - transformed_vertices[:, :, 2].min() + ) + transformed_vertices[:, :, 2] = ( + transformed_vertices[:, :, 2] / transformed_vertices[:, :, 2].max() + ) transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] * 80 + 10 # Attributes - attributes = util.face_vertices(colors, - self.faces.expand(batch_size, -1, -1)) + attributes = util.face_vertices(colors, self.faces.expand(batch_size, -1, -1)) # rasterize rendering = self.rasterizer( transformed_vertices, @@ -602,8 +586,7 @@ class SRenderY(nn.Module): uv_vertices: [bz, 3, h, w] """ batch_size = vertices.shape[0] - face_vertices = util.face_vertices( - vertices, self.faces.expand(batch_size, -1, -1)) + face_vertices = util.face_vertices(vertices, self.faces.expand(batch_size, -1, -1)) uv_vertices = self.uv_rasterizer( self.uvcoords.expand(batch_size, -1, -1), 
self.uvfaces.expand(batch_size, -1, -1), diff --git a/lib/pixielib/utils/rotation_converter.py b/lib/pixielib/utils/rotation_converter.py index 257e4eb2c12c242657d0275825717c48c56b5948..f8057cab4e0f84d035a0b8f964823bd61e91dae4 100644 --- a/lib/pixielib/utils/rotation_converter.py +++ b/lib/pixielib/utils/rotation_converter.py @@ -27,8 +27,7 @@ def rad2deg(tensor): >>> output = tgm.rad2deg(input) """ if not torch.is_tensor(tensor): - raise TypeError("Input type is not a torch.Tensor. Got {}".format( - type(tensor))) + raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(tensor))) return 180.0 * tensor / pi.to(tensor.device).type(tensor.dtype) @@ -50,8 +49,7 @@ def deg2rad(tensor): >>> output = tgm.deg2rad(input) """ if not torch.is_tensor(tensor): - raise TypeError("Input type is not a torch.Tensor. Got {}".format( - type(tensor))) + raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(tensor))) return tensor * pi.to(tensor.device).type(tensor.dtype) / 180.0 @@ -102,13 +100,12 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 """ if not torch.is_tensor(rotation_matrix): - raise TypeError("Input type is not a torch.Tensor. Got {}".format( - type(rotation_matrix))) + raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix))) if len(rotation_matrix.shape) > 3: raise ValueError( - "Input size must be a three dimensional tensor. Got {}".format( - rotation_matrix.shape)) + "Input size must be a three dimensional tensor. Got {}".format(rotation_matrix.shape) + ) # if not rotation_matrix.shape[-2:] == (3, 4): # raise ValueError( # "Input size must be a N x 3 x 4 tensor. Got {}".format( @@ -179,9 +176,10 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): mask_c3 = mask_c3.view(-1, 1).type_as(q3) q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 - q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + - t2_rep * mask_c2 # noqa - + t3_rep * mask_c3) # noqa + q /= torch.sqrt( + t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2 # noqa + + t3_rep * mask_c3 + ) # noqa q *= 0.5 return q @@ -206,13 +204,12 @@ def angle_axis_to_quaternion(angle_axis: torch.Tensor) -> torch.Tensor: >>> quaternion = tgm.angle_axis_to_quaternion(angle_axis) # Nx3 """ if not torch.is_tensor(angle_axis): - raise TypeError("Input type is not a torch.Tensor. Got {}".format( - type(angle_axis))) + raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(angle_axis))) if not angle_axis.shape[-1] == 3: raise ValueError( - "Input must be a tensor of shape Nx3 or 3. Got {}".format( - angle_axis.shape)) + "Input must be a tensor of shape Nx3 or 3. Got {}".format(angle_axis.shape) + ) # unpack input and compute conversion a0: torch.Tensor = angle_axis[..., 0:1] a1: torch.Tensor = angle_axis[..., 1:2] @@ -249,9 +246,7 @@ def quaternion_to_rotation_matrix(quat): """ norm_quat = quat norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) - w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, - 2], norm_quat[:, - 3] + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] B = quat.size(0) @@ -296,13 +291,12 @@ def quaternion_to_angle_axis(quaternion: torch.Tensor): >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 """ if not torch.is_tensor(quaternion): - raise TypeError("Input type is not a torch.Tensor. Got {}".format( - type(quaternion))) + raise TypeError("Input type is not a torch.Tensor. 
Got {}".format(type(quaternion))) if not quaternion.shape[-1] == 4: raise ValueError( - "Input must be a tensor of shape Nx4 or 4. Got {}".format( - quaternion.shape)) + "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape) + ) # unpack input and compute conversion q1: torch.Tensor = quaternion[..., 1] q2: torch.Tensor = quaternion[..., 2] @@ -318,12 +312,10 @@ def quaternion_to_angle_axis(quaternion: torch.Tensor): ) k_pos: torch.Tensor = two_theta / sin_theta - k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta).to( - quaternion.device) + k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta).to(quaternion.device) k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) - angle_axis: torch.Tensor = torch.zeros_like(quaternion).to( - quaternion.device)[..., :3] + angle_axis: torch.Tensor = torch.zeros_like(quaternion).to(quaternion.device)[..., :3] angle_axis[..., 0] += q1 * k angle_axis[..., 1] += q2 * k angle_axis[..., 2] += q3 * k @@ -408,10 +400,10 @@ def _compute_euler_from_matrix(dcm, seq="xyz", extrinsic=False): # 5b safe_mask = torch.logical_and(safe1, safe2) - angles[safe_mask, 0] = torch.atan2(dcm_transformed[safe_mask, 0, 2], - -dcm_transformed[safe_mask, 1, 2]) - angles[safe_mask, 2] = torch.atan2(dcm_transformed[safe_mask, 2, 0], - dcm_transformed[safe_mask, 2, 1]) + angles[safe_mask, + 0] = torch.atan2(dcm_transformed[safe_mask, 0, 2], -dcm_transformed[safe_mask, 1, 2]) + angles[safe_mask, + 2] = torch.atan2(dcm_transformed[safe_mask, 2, 0], dcm_transformed[safe_mask, 2, 1]) if extrinsic: # For extrinsic, set first angle to zero so that after reversal we # ensure that third angle is zero @@ -448,8 +440,7 @@ def _compute_euler_from_matrix(dcm, seq="xyz", extrinsic=False): adjust_mask = torch.logical_or(angles[:, 1] < 0, angles[:, 1] > np.pi) else: # lambda = + or - pi/2, so we can ensure angle2 -> [-pi/2, pi/2] - adjust_mask = torch.logical_or(angles[:, 1] < -np.pi / 2, - angles[:, 1] > np.pi / 2) + adjust_mask = torch.logical_or(angles[:, 1] < -np.pi / 2, angles[:, 1] > np.pi / 2) # Dont adjust gimbal locked angle sequences adjust_mask = torch.logical_and(adjust_mask, safe_mask) @@ -463,8 +454,10 @@ def _compute_euler_from_matrix(dcm, seq="xyz", extrinsic=False): # Step 8 if not torch.all(safe_mask): - print("Gimbal lock detected. Setting third angle to zero since" - "it is not possible to uniquely determine all angles.") + print( + "Gimbal lock detected. Setting third angle to zero since" + "it is not possible to uniquely determine all angles." 
+ ) # Reverse role of extrinsic and intrinsic rotations, but let third angle be # zero for gimbal locked cases @@ -497,8 +490,7 @@ def batch_matrix2euler(rot_mats): # Careful for extreme cases of eular angles like [0.0, pi, 0.0] # only y biw # TODO: add x, z - sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + - rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) + sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) return torch.atan2(-rot_mats[:, 2, 0], sy) @@ -550,8 +542,7 @@ def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) - K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], - dim=1).view((batch_size, 3, 3)) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3)) ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) @@ -571,9 +562,7 @@ def batch_cont2matrix(module_input): # Normalize the first vector b1 = F.normalize(reshaped_input[:, :, 0].clone(), dim=1) - dot_prod = torch.sum(b1 * reshaped_input[:, :, 1].clone(), - dim=1, - keepdim=True) + dot_prod = torch.sum(b1 * reshaped_input[:, :, 1].clone(), dim=1, keepdim=True) # Compute the second vector by finding the orthogonal complement to it b2 = F.normalize(reshaped_input[:, :, 1] - dot_prod * b1, dim=1) # Finish building the basis by taking the cross product diff --git a/lib/pixielib/utils/tensor_cropper.py b/lib/pixielib/utils/tensor_cropper.py index 21520901c79d4b690ab3747df6a7c820bf5de951..c486f7709ad9216080102ee275f7165d276eb0ce 100644 --- a/lib/pixielib/utils/tensor_cropper.py +++ b/lib/pixielib/utils/tensor_cropper.py @@ -34,21 +34,14 @@ def points2bbox(points, points_scale=None): def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.0): batch_size = center.shape[0] - trans_scale = (torch.rand([batch_size, 2], device=center.device) * 2.0 - - 1.0) * trans_scale - center = center + trans_scale * bbox_size # 0.5 - scale = (torch.rand([batch_size, 1], device=center.device) * - (scale[1] - scale[0]) + scale[0]) + trans_scale = (torch.rand([batch_size, 2], device=center.device) * 2.0 - 1.0) * trans_scale + center = center + trans_scale * bbox_size # 0.5 + scale = (torch.rand([batch_size, 1], device=center.device) * (scale[1] - scale[0]) + scale[0]) size = bbox_size * scale return center, size -def crop_tensor(image, - center, - bbox_size, - crop_size, - interpolation="bilinear", - align_corners=False): +def crop_tensor(image, center, bbox_size, crop_size, interpolation="bilinear", align_corners=False): """for batch image Args: image (torch.Tensor): the reference tensor of shape BXHxWXC. 
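batch_cont2matrix() above converts the 6D continuous rotation representation into a rotation matrix via Gram-Schmidt. The simplified sketch below (one 6D vector per batch element, illustrative names, not the patch's code) shows that construction together with a quick orthonormality check.

# Illustrative sketch of 6D-continuous-rotation -> rotation matrix (Gram-Schmidt).
import torch
import torch.nn.functional as F

def cont6d_to_matrix(x6d):
    """x6d: (B, 6) -> (B, 3, 3) rotation matrices."""
    a = x6d.reshape(-1, 3, 2)
    b1 = F.normalize(a[:, :, 0], dim=1)
    dot = (b1 * a[:, :, 1]).sum(dim=1, keepdim=True)
    b2 = F.normalize(a[:, :, 1] - dot * b1, dim=1)   # orthogonal complement of b1
    b3 = torch.cross(b1, b2, dim=1)                  # completes the right-handed basis
    return torch.stack([b1, b2, b3], dim=-1)

R = cont6d_to_matrix(torch.randn(4, 6))
# rotation check: R^T R should be (numerically) the identity
print(torch.allclose(torch.matmul(R.transpose(1, 2), R),
                     torch.eye(3).expand(4, -1, -1), atol=1e-5))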
@@ -66,11 +59,12 @@ def crop_tensor(image, device = image.device batch_size = image.shape[0] # points: top-left, top-right, bottom-right, bottom-left - src_pts = (torch.zeros([4, 2], dtype=dtype, - device=device).unsqueeze(0).expand( - batch_size, -1, -1).contiguous()) + src_pts = ( + torch.zeros([4, 2], dtype=dtype, device=device).unsqueeze(0).expand(batch_size, -1, + -1).contiguous() + ) - src_pts[:, 0, :] = center - bbox_size * 0.5 # / (self.crop_size - 1) + src_pts[:, 0, :] = center - bbox_size * 0.5 # / (self.crop_size - 1) src_pts[:, 1, 0] = center[:, 0] + bbox_size[:, 0] * 0.5 src_pts[:, 1, 1] = center[:, 1] - bbox_size[:, 0] * 0.5 src_pts[:, 2, :] = center + bbox_size * 0.5 @@ -107,7 +101,6 @@ def crop_tensor(image, class Cropper(object): - def __init__(self, crop_size, scale=[1, 1], trans_scale=0.0): self.crop_size = crop_size self.scale = scale @@ -116,21 +109,14 @@ class Cropper(object): def crop(self, image, points, points_scale=None): # points to bbox center, bbox_size = points2bbox(points.clone(), points_scale) - # argument bbox. TODO: add rotation? - center, bbox_size = augment_bbox(center, - bbox_size, - scale=self.scale, - trans_scale=self.trans_scale) + center, bbox_size = augment_bbox( + center, bbox_size, scale=self.scale, trans_scale=self.trans_scale + ) # crop - cropped_image, tform = crop_tensor(image, center, bbox_size, - self.crop_size) + cropped_image, tform = crop_tensor(image, center, bbox_size, self.crop_size) return cropped_image, tform - def transform_points(self, - points, - tform, - points_scale=None, - normalize=True): + def transform_points(self, points, tform, points_scale=None, normalize=True): points_2d = points[:, :, :2] #'input points must use original range' @@ -153,11 +139,9 @@ class Cropper(object): ), tform, ) - trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]], - dim=-1) + trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]], dim=-1) if normalize: - trans_points[:, :, : - 2] = trans_points[:, :, :2] / self.crop_size * 2 - 1 + trans_points[:, :, :2] = trans_points[:, :, :2] / self.crop_size * 2 - 1 return trans_points @@ -174,14 +158,11 @@ def transform_points(points, tform, points_scale=None): torch.cat( [ points_2d, - torch.ones([batch_size, n_points, 1], - device=points.device, - dtype=points.dtype), + torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype), ], dim=-1, ), tform, ) - trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]], - dim=-1) + trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]], dim=-1) return trans_points diff --git a/lib/pixielib/utils/util.py b/lib/pixielib/utils/util.py index 1e8ec90836ce198186b0842d3eadee1b02318a2d..566eda3a6e6ddf7f236bf4e20bf7220b39981ce3 100755 --- a/lib/pixielib/utils/util.py +++ b/lib/pixielib/utils/util.py @@ -46,8 +46,7 @@ def face_vertices(vertices, faces): bs, nv = vertices.shape[:2] bs, nf = faces.shape[:2] device = vertices.device - faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * - nv)[:, None, None] + faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] vertices = vertices.reshape((bs * nv, 3)) # pytorch only supports long and byte tensors for indexing return vertices[faces.long()] @@ -71,9 +70,8 @@ def vertex_normals(vertices, faces): normals = torch.zeros(bs * nv, 3).to(device) faces = ( - faces + - (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] - ) # expanded faces + faces + (torch.arange(bs, 
dtype=torch.int32).to(device) * nv)[:, None, None] + ) # expanded faces vertices_faces = vertices.reshape((bs * nv, 3))[faces.long()] faces = faces.reshape(-1, 3) @@ -145,12 +143,10 @@ def flip_pose(pose_vector, pose_format="rot-mat"): # -------------------------------------- image processing # ref: https://torchgeometry.readthedocs.io/en/latest/_modules/kornia/filters def gaussian(window_size, sigma): - def gauss_fcn(x): return -((x - window_size // 2)**2) / float(2 * sigma**2) - gauss = torch.stack( - [torch.exp(torch.tensor(gauss_fcn(x))) for x in range(window_size)]) + gauss = torch.stack([torch.exp(torch.tensor(gauss_fcn(x))) for x in range(window_size)]) return gauss / gauss.sum() @@ -175,10 +171,11 @@ def get_gaussian_kernel(kernel_size: int, sigma: float): >>> kornia.image.get_gaussian_kernel(5, 1.5) tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201]) """ - if not isinstance(kernel_size, - int) or kernel_size % 2 == 0 or kernel_size <= 0: - raise TypeError("kernel_size must be an odd positive integer. " - "Got {}".format(kernel_size)) + if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0: + raise TypeError( + "kernel_size must be an odd positive integer. " + "Got {}".format(kernel_size) + ) window_1d = gaussian(kernel_size, sigma) return window_1d @@ -211,18 +208,14 @@ def get_gaussian_kernel2d(kernel_size, sigma): [0.0370, 0.0720, 0.0899, 0.0720, 0.0370]]) """ if not isinstance(kernel_size, tuple) or len(kernel_size) != 2: - raise TypeError( - "kernel_size must be a tuple of length two. Got {}".format( - kernel_size)) + raise TypeError("kernel_size must be a tuple of length two. Got {}".format(kernel_size)) if not isinstance(sigma, tuple) or len(sigma) != 2: - raise TypeError( - "sigma must be a tuple of length two. Got {}".format(sigma)) + raise TypeError("sigma must be a tuple of length two. Got {}".format(sigma)) ksize_x, ksize_y = kernel_size sigma_x, sigma_y = sigma kernel_x = get_gaussian_kernel(ksize_x, sigma_x) kernel_y = get_gaussian_kernel(ksize_y, sigma_y) - kernel_2d = torch.matmul(kernel_x.unsqueeze(-1), - kernel_y.unsqueeze(-1).t()) + kernel_2d = torch.matmul(kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t()) return kernel_2d @@ -283,10 +276,8 @@ def get_laplacian_kernel2d(kernel_size: int): [ 1., 1., 1., 1., 1.]]) """ - if not isinstance(kernel_size, - int) or kernel_size % 2 == 0 or kernel_size <= 0: - raise TypeError("ksize must be an odd positive integer. Got {}".format( - kernel_size)) + if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0: + raise TypeError("ksize must be an odd positive integer. 
Got {}".format(kernel_size)) kernel = torch.ones((kernel_size, kernel_size)) mid = kernel_size // 2 @@ -309,7 +300,6 @@ def laplacian(x): def copy_state_dict(cur_state_dict, pre_state_dict, prefix="", load_name=None): - def _get_params(key): key = prefix + key if key in pre_state_dict: @@ -353,7 +343,7 @@ def remove_module(state_dict): # create new OrderedDict that does not contain `module.` new_state_dict = OrderedDict() for k, v in state_dict.items(): - name = k[7:] # remove `module.` + name = k[7:] # remove `module.` new_state_dict[name] = v return new_state_dict @@ -433,24 +423,24 @@ def write_obj( # write vertices if colors is None: for i in range(vertices.shape[0]): - f.write("v {} {} {}\n".format(vertices[i, 0], vertices[i, 1], - vertices[i, 2])) + f.write("v {} {} {}\n".format(vertices[i, 0], vertices[i, 1], vertices[i, 2])) else: for i in range(vertices.shape[0]): - f.write("v {} {} {} {} {} {}\n".format( - vertices[i, 0], - vertices[i, 1], - vertices[i, 2], - colors[i, 0], - colors[i, 1], - colors[i, 2], - )) + f.write( + "v {} {} {} {} {} {}\n".format( + vertices[i, 0], + vertices[i, 1], + vertices[i, 2], + colors[i, 0], + colors[i, 1], + colors[i, 2], + ) + ) # write uv coords if texture is None: for i in range(faces.shape[0]): - f.write("f {} {} {}\n".format(faces[i, 0], faces[i, 1], - faces[i, 2])) + f.write("f {} {} {}\n".format(faces[i, 0], faces[i, 1], faces[i, 2])) else: for i in range(uvcoords.shape[0]): f.write("vt {} {}\n".format(uvcoords[i, 0], uvcoords[i, 1])) @@ -458,37 +448,37 @@ def write_obj( # write f: ver ind/ uv ind uvfaces = uvfaces + 1 for i in range(faces.shape[0]): - f.write("f {}/{} {}/{} {}/{}\n".format( - faces[i, 0], - uvfaces[i, 0], - faces[i, 1], - uvfaces[i, 1], - faces[i, 2], - uvfaces[i, 2], - )) + f.write( + "f {}/{} {}/{} {}/{}\n".format( + faces[i, 0], + uvfaces[i, 0], + faces[i, 1], + uvfaces[i, 1], + faces[i, 2], + uvfaces[i, 2], + ) + ) # write mtl with open(mtl_name, "w") as f: f.write("newmtl %s\n" % material_name) - s = "map_Kd {}\n".format( - os.path.basename(texture_name)) # map to image + s = "map_Kd {}\n".format(os.path.basename(texture_name)) # map to image f.write(s) if normal_map is not None: if torch.is_tensor(normal_map): - normal_map = normal_map.detach().cpu().numpy().squeeze( - ) + normal_map = normal_map.detach().cpu().numpy().squeeze() normal_map = np.transpose(normal_map, (1, 2, 0)) name, _ = os.path.splitext(obj_name) normal_name = f"{name}_normals.png" f.write(f"disp {normal_name}") - out_normal_map = normal_map / (np.linalg.norm( - normal_map, axis=-1, keepdims=True) + 1e-9) + out_normal_map = normal_map / ( + np.linalg.norm(normal_map, axis=-1, keepdims=True) + 1e-9 + ) out_normal_map = (out_normal_map + 1) * 0.5 - cv2.imwrite(normal_name, (out_normal_map * 255).astype( - np.uint8)[:, :, ::-1]) + cv2.imwrite(normal_name, (out_normal_map * 255).astype(np.uint8)[:, :, ::-1]) cv2.imwrite(texture_name, texture) @@ -523,20 +513,20 @@ def load_obj(obj_filename): for line in lines: tokens = line.strip().split() - if line.startswith("v "): # Line is a vertex. + if line.startswith("v "): # Line is a vertex. vert = [float(x) for x in tokens[1:4]] if len(vert) != 3: msg = "Vertex %s does not have 3 values. Line: %s" raise ValueError(msg % (str(vert), str(line))) verts.append(vert) - elif line.startswith("vt "): # Line is a texture. + elif line.startswith("vt "): # Line is a texture. tx = [float(x) for x in tokens[1:3]] if len(tx) != 2: raise ValueError( - "Texture %s does not have 2 values. 
Line: %s" % - (str(tx), str(line))) + "Texture %s does not have 2 values. Line: %s" % (str(tx), str(line)) + ) uvcoords.append(tx) - elif line.startswith("f "): # Line is a face. + elif line.startswith("f "): # Line is a face. # Update face properties info. face = tokens[1:] face_list = [f.split("/") for f in face] @@ -558,12 +548,7 @@ def load_obj(obj_filename): # ---------------------------------- visualization -def draw_rectangle(img, - bbox, - bbox_color=(255, 255, 255), - thickness=3, - is_opaque=False, - alpha=0.5): +def draw_rectangle(img, bbox, bbox_color=(255, 255, 255), thickness=3, is_opaque=False, alpha=0.5): """Draws the rectangle around the object borrowed from: https://bbox-visualizer.readthedocs.io/en/latest/_modules/bbox_visualizer/bbox_visualizer.html Parameters @@ -589,13 +574,11 @@ def draw_rectangle(img, output = img.copy() if not is_opaque: - cv2.rectangle(output, (bbox[0], bbox[1]), (bbox[2], bbox[3]), - bbox_color, thickness) + cv2.rectangle(output, (bbox[0], bbox[1]), (bbox[2], bbox[3]), bbox_color, thickness) else: overlay = img.copy() - cv2.rectangle(overlay, (bbox[0], bbox[1]), (bbox[2], bbox[3]), - bbox_color, -1) + cv2.rectangle(overlay, (bbox[0], bbox[1]), (bbox[2], bbox[3]), bbox_color, -1) # cv2.addWeighted(overlay, alpha, output, 1 - alpha, 0, output) return output @@ -607,9 +590,9 @@ def plot_bbox(image, bbox): image: the input image bbox: [left, top, right, bottom] """ - image = cv2.rectangle(image.copy(), (bbox[1], bbox[0]), (bbox[3], bbox[2]), - [0, 255, 0], - thickness=3) + image = cv2.rectangle( + image.copy(), (bbox[1], bbox[0]), (bbox[3], bbox[2]), [0, 255, 0], thickness=3 + ) # image = draw_rectangle(image, bbox, bbox_color=[0,255,0]) return image @@ -644,8 +627,7 @@ def plot_kpts(image, kpts, color="r"): if i in end_list: continue ed = kpts[i + 1, :2] - image = cv2.line(image, (st[0], st[1]), (ed[0], ed[1]), - (255, 255, 255), 1) + image = cv2.line(image, (st[0], st[1]), (ed[0], ed[1]), (255, 255, 255), 1) return image @@ -674,11 +656,7 @@ def plot_verts(image, kpts, color="r"): return image -def tensor_vis_landmarks(images, - landmarks, - gt_landmarks=None, - color="g", - isScale=True): +def tensor_vis_landmarks(images, landmarks, gt_landmarks=None, color="g", isScale=True): # visualize landmarks vis_landmarks = [] images = images.cpu().numpy() @@ -690,8 +668,7 @@ def tensor_vis_landmarks(images, image = image.transpose(1, 2, 0)[:, :, [2, 1, 0]].copy() image = image * 255 if isScale: - predicted_landmark = (predicted_landmarks[i] * image.shape[0] / 2 + - image.shape[0] / 2) + predicted_landmark = (predicted_landmarks[i] * image.shape[0] / 2 + image.shape[0] / 2) else: predicted_landmark = predicted_landmarks[i] if predicted_landmark.shape[0] == 68: @@ -699,8 +676,7 @@ def tensor_vis_landmarks(images, if gt_landmarks is not None: image_landmarks = plot_verts( image_landmarks, - gt_landmarks_np[i] * image.shape[0] / 2 + - image.shape[0] / 2, + gt_landmarks_np[i] * image.shape[0] / 2 + image.shape[0] / 2, "r", ) else: @@ -708,14 +684,13 @@ def tensor_vis_landmarks(images, if gt_landmarks is not None: image_landmarks = plot_verts( image_landmarks, - gt_landmarks_np[i] * image.shape[0] / 2 + - image.shape[0] / 2, + gt_landmarks_np[i] * image.shape[0] / 2 + image.shape[0] / 2, "r", ) vis_landmarks.append(image_landmarks) vis_landmarks = np.stack(vis_landmarks) - vis_landmarks = (torch.from_numpy( - vis_landmarks[:, :, :, [2, 1, 0]].transpose(0, 3, 1, 2)) / 255.0 - ) # , dtype=torch.float32) + vis_landmarks = ( + torch.from_numpy(vis_landmarks[:, :, :, 
[2, 1, 0]].transpose(0, 3, 1, 2)) / 255.0 + ) # , dtype=torch.float32) return vis_landmarks diff --git a/lib/pymafx/core/cfgs.py b/lib/pymafx/core/cfgs.py index 580643a7cb2ad7caaec29223b372669b17992926..c970c6c0caafe7a4c2f3abbb311adcd0cef42b94 100644 --- a/lib/pymafx/core/cfgs.py +++ b/lib/pymafx/core/cfgs.py @@ -67,6 +67,7 @@ def get_cfg_defaults(): # return cfg.clone() return cfg + def update_cfg(cfg_file): # cfg = get_cfg_defaults() cfg.merge_from_file(cfg_file) @@ -86,6 +87,7 @@ def parse_args(args): return cfg + def parse_args_extend(args): if args.resume: if not os.path.exists(args.log_dir): diff --git a/lib/pymafx/core/constants.py b/lib/pymafx/core/constants.py index 47fd25e31d1ebfb11cd19a2fa8d8c4e61c79cc5d..5354a289f892a764a16221b469fc49794ff54127 100644 --- a/lib/pymafx/core/constants.py +++ b/lib/pymafx/core/constants.py @@ -43,23 +43,23 @@ SPIN_JOINT_NAMES = [ # 24 Ground Truth joints (superset of joints from different datasets) 'Right Ankle', 'Right Knee', - 'Right Hip', # 2 + 'Right Hip', # 2 'Left Hip', - 'Left Knee', # 4 + 'Left Knee', # 4 'Left Ankle', - 'Right Wrist', # 6 + 'Right Wrist', # 6 'Right Elbow', - 'Right Shoulder', # 8 + 'Right Shoulder', # 8 'Left Shoulder', - 'Left Elbow', # 10 + 'Left Elbow', # 10 'Left Wrist', - 'Neck (LSP)', # 12 + 'Neck (LSP)', # 12 'Top of Head (LSP)', - 'Pelvis (MPII)', # 14 + 'Pelvis (MPII)', # 14 'Thorax (MPII)', - 'Spine (H36M)', # 16 + 'Spine (H36M)', # 16 'Jaw (H36M)', - 'Head (H36M)', # 18 + 'Head (H36M)', # 18 'Nose', 'Left Eye', 'Right Eye', @@ -278,8 +278,8 @@ FACIAL_LANDMARKS = [ 'left_mouth_3', 'left_mouth_2', 'left_mouth_1', - 'left_mouth_5', # 59 in OpenPose output - 'left_mouth_4', # 58 in OpenPose output + 'left_mouth_5', # 59 in OpenPose output + 'left_mouth_4', # 58 in OpenPose output 'mouth_bottom', 'right_mouth_4', 'right_mouth_5', diff --git a/lib/pymafx/models/attention.py b/lib/pymafx/models/attention.py index 21bf6d10d907546e5462bd896e8d8bb819e41b24..b0f7d3c5c63ba1471ff15ee1a3cf0d8c94a17699 100644 --- a/lib/pymafx/models/attention.py +++ b/lib/pymafx/models/attention.py @@ -16,6 +16,7 @@ from .transformers.bert.modeling_bert import BertPreTrainedModel, BertEmbeddings # import src.modeling.data.config as cfg # from src.modeling._gcnn import GraphConvolution, GraphResBlock from .transformers.bert.modeling_utils import prune_linear_layer + LayerNormClass = torch.nn.LayerNorm BertLayerNorm = torch.nn.LayerNorm from .transformers.bert import BertConfig @@ -27,7 +28,8 @@ class BertSelfAttention(nn.Module): if config.hidden_size % config.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) self.output_attentions = config.output_attentions self.num_attention_heads = config.num_attention_heads @@ -45,8 +47,7 @@ class BertSelfAttention(nn.Module): x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, attention_mask, head_mask=None, - history_state=None): + def forward(self, hidden_states, attention_mask, head_mask=None, history_state=None): if history_state is not None: raise x_states = torch.cat([history_state, hidden_states], dim=1) @@ -85,12 +86,13 @@ class BertSelfAttention(nn.Module): context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + 
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, ) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) + outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer, ) return outputs + class BertAttention(nn.Module): def __init__(self, config): super(BertAttention, self).__init__() @@ -114,12 +116,10 @@ class BertAttention(nn.Module): self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - def forward(self, input_tensor, attention_mask, head_mask=None, - history_state=None): - self_outputs = self.self(input_tensor, attention_mask, head_mask, - history_state) + def forward(self, input_tensor, attention_mask, head_mask=None, history_state=None): + self_outputs = self.self(input_tensor, attention_mask, head_mask, history_state) attention_output = self.output(self_outputs[0], input_tensor) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them return outputs @@ -131,10 +131,8 @@ class AttLayer(nn.Module): self.intermediate = BertIntermediate(config) self.output = BertOutput(config) - def MHA(self, hidden_states, attention_mask, head_mask=None, - history_state=None): - attention_outputs = self.attention(hidden_states, attention_mask, - head_mask, history_state) + def MHA(self, hidden_states, attention_mask, head_mask=None, history_state=None): + attention_outputs = self.attention(hidden_states, attention_mask, head_mask, history_state) attention_output = attention_outputs[0] # print('attention_output', hidden_states.shape, attention_output.shape) @@ -143,12 +141,11 @@ class AttLayer(nn.Module): # print('intermediate_output', intermediate_output.shape) layer_output = self.output(intermediate_output, attention_output) # print('layer_output', layer_output.shape) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + outputs = (layer_output, ) + attention_outputs[1:] # add attentions if we output them return outputs - def forward(self, hidden_states, attention_mask, head_mask=None, - history_state=None): - return self.MHA(hidden_states, attention_mask, head_mask,history_state) + def forward(self, hidden_states, attention_mask, head_mask=None, history_state=None): + return self.MHA(hidden_states, attention_mask, head_mask, history_state) class AttEncoder(nn.Module): @@ -158,34 +155,32 @@ class AttEncoder(nn.Module): self.output_hidden_states = config.output_hidden_states self.layer = nn.ModuleList([AttLayer(config) for _ in range(config.num_hidden_layers)]) - def forward(self, hidden_states, attention_mask, head_mask=None, - encoder_history_states=None): + def forward(self, hidden_states, attention_mask, head_mask=None, encoder_history_states=None): all_hidden_states = () all_attentions = () for i, layer_module in enumerate(self.layer): if self.output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states, ) history_state = None if encoder_history_states is None else encoder_history_states[i] - layer_outputs = layer_module( - hidden_states, attention_mask, head_mask[i], - history_state) + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], history_state) hidden_states = layer_outputs[0] if 
self.output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_attentions = all_attentions + (layer_outputs[1], ) # Add last layer if self.output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states, ) - outputs = (hidden_states,) + outputs = (hidden_states, ) if self.output_hidden_states: - outputs = outputs + (all_hidden_states,) + outputs = outputs + (all_hidden_states, ) if self.output_attentions: - outputs = outputs + (all_attentions,) + outputs = outputs + (all_attentions, ) + + return outputs # outputs, (hidden states), (attentions) - return outputs # outputs, (hidden states), (attentions) class EncoderBlock(BertPreTrainedModel): def __init__(self, config): @@ -195,7 +190,7 @@ class EncoderBlock(BertPreTrainedModel): self.encoder = AttEncoder(config) # self.pooler = BertPooler(config) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - self.img_dim = config.img_feature_dim + self.img_dim = config.img_feature_dim try: self.use_img_layernorm = config.use_img_layernorm @@ -217,26 +212,32 @@ class EncoderBlock(BertPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, - position_ids=None, head_mask=None): + def forward( + self, + img_feats, + input_ids=None, + token_type_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None + ): batch_size = len(img_feats) seq_length = len(img_feats[0]) - input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).to(img_feats.device) + input_ids = torch.zeros([batch_size, seq_length], dtype=torch.long).to(img_feats.device) if position_ids is None: position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # print('-------------------') # print('position_ids', seq_length, position_ids.shape) - # 494 torch.Size([2, 494]) + # 494 torch.Size([2, 494]) position_embeddings = self.position_embeddings(position_ids) # print('position_embeddings', position_embeddings.shape, self.config.max_position_embeddings, self.config.hidden_size) - # torch.Size([2, 494, 1024]) 512 1024 + # torch.Size([2, 494, 1024]) 512 1024 # torch.Size([2, 494, 256]) 512 256 - if attention_mask is None: attention_mask = torch.ones_like(input_ids) else: @@ -255,7 +256,9 @@ class EncoderBlock(BertPreTrainedModel): raise NotImplementedError # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = extended_attention_mask.to(dtype=img_feats.dtype) # fp16 compatibility + extended_attention_mask = extended_attention_mask.to( + dtype=img_feats.dtype + ) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 if head_mask is not None: @@ -264,15 +267,19 @@ class EncoderBlock(BertPreTrainedModel): head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) elif head_mask.dim() == 2: - head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer - head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze( + -1 + ) # We can specify head_mask for 
each layer + head_mask = head_mask.to( + dtype=next(self.parameters()).dtype + ) # switch to fload if need + fp16 compatibility else: head_mask = [None] * self.config.num_hidden_layers # Project input token features to have spcified hidden size - # print('img_feats', img_feats.shape) # torch.Size([2, 494, 2051]) + # print('img_feats', img_feats.shape) # torch.Size([2, 494, 2051]) img_embedding_output = self.img_embedding(img_feats) - # print('img_embedding_output', img_embedding_output.shape) # torch.Size([2, 494, 1024]) + # print('img_embedding_output', img_embedding_output.shape) # torch.Size([2, 494, 1024]) # We empirically observe that adding an additional learnable position embedding leads to more stable training embeddings = position_embeddings + img_embedding_output @@ -282,21 +289,27 @@ class EncoderBlock(BertPreTrainedModel): # embeddings = self.dropout(embeddings) # print('extended_attention_mask', extended_attention_mask.shape) # torch.Size([2, 1, 1, 494]) - encoder_outputs = self.encoder(embeddings, - extended_attention_mask, head_mask=head_mask) + encoder_outputs = self.encoder(embeddings, extended_attention_mask, head_mask=head_mask) sequence_output = encoder_outputs[0] - outputs = (sequence_output,) + outputs = (sequence_output, ) if self.config.output_hidden_states: all_hidden_states = encoder_outputs[1] - outputs = outputs + (all_hidden_states,) + outputs = outputs + (all_hidden_states, ) if self.config.output_attentions: all_attentions = encoder_outputs[-1] - outputs = outputs + (all_attentions,) + outputs = outputs + (all_attentions, ) return outputs -def get_att_block(img_feature_dim=2048, output_feat_dim=512, hidden_feat_dim=1024, num_attention_heads=4, num_hidden_layers=1): + +def get_att_block( + img_feature_dim=2048, + output_feat_dim=512, + hidden_feat_dim=1024, + num_attention_heads=4, + num_hidden_layers=1 +): config_class = BertConfig config = config_class.from_pretrained('lib/pymafx/models/transformers/bert/bert-base-uncased/') @@ -316,7 +329,7 @@ def get_att_block(img_feature_dim=2048, output_feat_dim=512, hidden_feat_dim=102 # init a transformer encoder and append it to a list assert config.hidden_size % config.num_attention_heads == 0 - att_model = EncoderBlock(config=config) + att_model = EncoderBlock(config=config) return att_model @@ -333,16 +346,31 @@ class Graphormer(BertPreTrainedModel): self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim) self.apply(self.init_weights) - def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None, - next_sentence_label=None, position_ids=None, head_mask=None): + def forward( + self, + img_feats, + input_ids=None, + token_type_ids=None, + attention_mask=None, + masked_lm_labels=None, + next_sentence_label=None, + position_ids=None, + head_mask=None + ): ''' # self.bert has three outputs # predictions[0]: output tokens # predictions[1]: all_hidden_states, if enable "self.config.output_hidden_states" # predictions[2]: attentions, if enable "self.config.output_attentions" ''' - predictions = self.bert(img_feats=img_feats, input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask, head_mask=head_mask) + predictions = self.bert( + img_feats=img_feats, + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) # We use "self.cls_head" to perform dimensionality reduction. We don't use it for classification. 
pred_score = self.cls_head(predictions[0]) @@ -354,5 +382,3 @@ class Graphormer(BertPreTrainedModel): return pred_score, predictions[1], predictions[-1] else: return pred_score - - \ No newline at end of file diff --git a/lib/pymafx/models/hmr.py b/lib/pymafx/models/hmr.py index f91f4a8311b940afca6155d2f31487c6e77fa5ad..da5459d355d3a3f00c53638a376ab3143b23c01e 100755 --- a/lib/pymafx/models/hmr.py +++ b/lib/pymafx/models/hmr.py @@ -8,10 +8,12 @@ import math from lib.net.geometry import rot6d_to_rotmat import logging + logger = logging.getLogger(__name__) BN_MOMENTUM = 0.1 + class Bottleneck(nn.Module): """ Redefinition of Bottleneck residual block Adapted from the official PyTorch implementation @@ -22,8 +24,7 @@ class Bottleneck(nn.Module): super().__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=1, bias=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) @@ -57,18 +58,16 @@ class Bottleneck(nn.Module): class ResNet_Backbone(nn.Module): """ Feature Extrator with ResNet backbone """ - def __init__(self, model='res50', pretrained=True): if model == 'res50': block, layers = Bottleneck, [3, 4, 6, 3] else: - pass # TODO + pass # TODO self.inplanes = 64 super().__init__() npose = 24 * 6 - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, - bias=False) + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -87,8 +86,13 @@ class ResNet_Backbone(nn.Module): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), nn.BatchNorm2d(planes * block.expansion), ) @@ -105,7 +109,7 @@ class ResNet_Backbone(nn.Module): 'ERROR: num_deconv_layers is different len(num_deconv_filters)' assert num_layers == len(num_kernels), \ 'ERROR: num_deconv_layers is different len(num_deconv_filters)' - + def _get_deconv_cfg(deconv_kernel, index): if deconv_kernel == 4: padding = 1 @@ -132,7 +136,9 @@ class ResNet_Backbone(nn.Module): stride=2, padding=padding, output_padding=output_padding, - bias=self.deconv_with_bias)) + bias=self.deconv_with_bias + ) + ) layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) layers.append(nn.ReLU(inplace=True)) self.inplanes = planes @@ -164,13 +170,11 @@ class ResNet_Backbone(nn.Module): class HMR(nn.Module): """ SMPL Iterative Regressor with ResNet50 backbone """ - def __init__(self, block, layers, smpl_mean_params): self.inplanes = 64 super().__init__() npose = 24 * 6 - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, - bias=False) + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -206,13 +210,17 @@ class HMR(nn.Module): self.register_buffer('init_shape', init_shape) self.register_buffer('init_cam', init_cam) - def _make_layer(self, block, planes, 
blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), nn.BatchNorm2d(planes * block.expansion), ) @@ -224,7 +232,6 @@ class HMR(nn.Module): return nn.Sequential(*layers) - def forward(self, x, init_pose=None, init_shape=None, init_cam=None, n_iter=3): batch_size = x.shape[0] @@ -253,7 +260,7 @@ class HMR(nn.Module): pred_shape = init_shape pred_cam = init_cam for i in range(n_iter): - xc = torch.cat([xf, pred_pose, pred_shape, pred_cam],1) + xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1) xc = self.fc1(xc) xc = self.drop1(xc) xc = self.fc2(xc) @@ -266,13 +273,14 @@ class HMR(nn.Module): return pred_rotmat, pred_shape, pred_cam + def hmr(smpl_mean_params, pretrained=True, **kwargs): """ Constructs an HMR model with ResNet50 backbone. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = HMR(Bottleneck, [3, 4, 6, 3], smpl_mean_params, **kwargs) + model = HMR(Bottleneck, [3, 4, 6, 3], smpl_mean_params, **kwargs) if pretrained: resnet_imagenet = resnet.resnet50(pretrained=True) - model.load_state_dict(resnet_imagenet.state_dict(),strict=False) - return model \ No newline at end of file + model.load_state_dict(resnet_imagenet.state_dict(), strict=False) + return model diff --git a/lib/pymafx/models/hr_module.py b/lib/pymafx/models/hr_module.py index 285cd2c56728e439fdcd1a8bccbafca3f5549ef3..7396f1ea59860235db8fdd24434114381c4a7083 100644 --- a/lib/pymafx/models/hr_module.py +++ b/lib/pymafx/models/hr_module.py @@ -7,16 +7,25 @@ import torch.nn.functional as F from .res_module import BasicBlock, Bottleneck import logging + logger = logging.getLogger(__name__) BN_MOMENTUM = 0.1 + class HighResolutionModule(nn.Module): - def __init__(self, num_branches, blocks, num_blocks, num_inchannels, - num_channels, fuse_method, multi_scale_output=True): + def __init__( + self, + num_branches, + blocks, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + multi_scale_output=True + ): super().__init__() - self._check_branches( - num_branches, blocks, num_blocks, num_inchannels, num_channels) + self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels) self.num_inchannels = num_inchannels self.fuse_method = fuse_method @@ -24,33 +33,31 @@ class HighResolutionModule(nn.Module): self.multi_scale_output = multi_scale_output - self.branches = self._make_branches( - num_branches, blocks, num_blocks, num_channels) + self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels) self.fuse_layers = self._make_fuse_layers() self.relu = nn.ReLU(True) - def _check_branches(self, num_branches, blocks, num_blocks, - num_inchannels, num_channels): + def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels): if num_branches != len(num_blocks): - error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( - num_branches, len(num_blocks)) + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks)) logger.error(error_msg) raise ValueError(error_msg) if num_branches != len(num_channels): error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( - num_branches, len(num_channels)) + num_branches, len(num_channels) + ) logger.error(error_msg) raise ValueError(error_msg) if num_branches != 
len(num_inchannels): error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( - num_branches, len(num_inchannels)) + num_branches, len(num_inchannels) + ) logger.error(error_msg) raise ValueError(error_msg) - def _make_one_branch(self, branch_index, block, num_blocks, num_channels, - stride=1): + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): downsample = None if stride != 1 or \ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: @@ -58,32 +65,23 @@ class HighResolutionModule(nn.Module): nn.Conv2d( self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion, - kernel_size=1, stride=stride, bias=False - ), - nn.BatchNorm2d( - num_channels[branch_index] * block.expansion, - momentum=BN_MOMENTUM + kernel_size=1, + stride=stride, + bias=False ), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM), ) layers = [] layers.append( block( - self.num_inchannels[branch_index], - num_channels[branch_index], - stride, - downsample + self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample ) ) self.num_inchannels[branch_index] = \ num_channels[branch_index] * block.expansion for i in range(1, num_blocks[branch_index]): - layers.append( - block( - self.num_inchannels[branch_index], - num_channels[branch_index] - ) - ) + layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) return nn.Sequential(*layers) @@ -91,9 +89,7 @@ class HighResolutionModule(nn.Module): branches = [] for i in range(num_branches): - branches.append( - self._make_one_branch(i, block, num_blocks, num_channels) - ) + branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) return nn.ModuleList(branches) @@ -110,20 +106,16 @@ class HighResolutionModule(nn.Module): if j > i: fuse_layer.append( nn.Sequential( - nn.Conv2d( - num_inchannels[j], - num_inchannels[i], - 1, 1, 0, bias=False - ), + nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), nn.BatchNorm2d(num_inchannels[i]), - nn.Upsample(scale_factor=2**(j-i), mode='nearest') + nn.Upsample(scale_factor=2**(j - i), mode='nearest') ) ) elif j == i: fuse_layer.append(None) else: conv3x3s = [] - for k in range(i-j): + for k in range(i - j): if k == i - j - 1: num_outchannels_conv3x3 = num_inchannels[i] conv3x3s.append( @@ -131,9 +123,11 @@ class HighResolutionModule(nn.Module): nn.Conv2d( num_inchannels[j], num_outchannels_conv3x3, - 3, 2, 1, bias=False - ), - nn.BatchNorm2d(num_outchannels_conv3x3) + 3, + 2, + 1, + bias=False + ), nn.BatchNorm2d(num_outchannels_conv3x3) ) ) else: @@ -143,10 +137,11 @@ class HighResolutionModule(nn.Module): nn.Conv2d( num_inchannels[j], num_outchannels_conv3x3, - 3, 2, 1, bias=False - ), - nn.BatchNorm2d(num_outchannels_conv3x3), - nn.ReLU(True) + 3, + 2, + 1, + bias=False + ), nn.BatchNorm2d(num_outchannels_conv3x3), nn.ReLU(True) ) ) fuse_layer.append(nn.Sequential(*conv3x3s)) @@ -178,25 +173,19 @@ class HighResolutionModule(nn.Module): return x_fuse -blocks_dict = { - 'BASIC': BasicBlock, - 'BOTTLENECK': Bottleneck -} +blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} class PoseHighResolutionNet(nn.Module): - def __init__(self, cfg, pretrained=True, global_mode=False): self.inplanes = 64 extra = cfg.HR_MODEL.EXTRA super().__init__() # stem net - self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, - bias=False) + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = 
nn.BatchNorm2d(64, momentum=BN_MOMENTUM) - self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, - bias=False) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) self.relu = nn.ReLU(inplace=True) self.layer1 = self._make_layer(Bottleneck, self.inplanes, 64, 4) @@ -204,34 +193,25 @@ class PoseHighResolutionNet(nn.Module): self.stage2_cfg = cfg['HR_MODEL']['EXTRA']['STAGE2'] num_channels = self.stage2_cfg['NUM_CHANNELS'] block = blocks_dict[self.stage2_cfg['BLOCK']] - num_channels = [ - num_channels[i] * block.expansion for i in range(len(num_channels)) - ] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] self.transition1 = self._make_transition_layer([256], num_channels) - self.stage2, pre_stage_channels = self._make_stage( - self.stage2_cfg, num_channels) + self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels) self.stage3_cfg = cfg['HR_MODEL']['EXTRA']['STAGE3'] num_channels = self.stage3_cfg['NUM_CHANNELS'] block = blocks_dict[self.stage3_cfg['BLOCK']] - num_channels = [ - num_channels[i] * block.expansion for i in range(len(num_channels)) - ] - self.transition2 = self._make_transition_layer( - pre_stage_channels, num_channels) - self.stage3, pre_stage_channels = self._make_stage( - self.stage3_cfg, num_channels) + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels) self.stage4_cfg = cfg['HR_MODEL']['EXTRA']['STAGE4'] num_channels = self.stage4_cfg['NUM_CHANNELS'] block = blocks_dict[self.stage4_cfg['BLOCK']] - num_channels = [ - num_channels[i] * block.expansion for i in range(len(num_channels)) - ] - self.transition3 = self._make_transition_layer( - pre_stage_channels, num_channels) + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) self.stage4, pre_stage_channels = self._make_stage( - self.stage4_cfg, num_channels, multi_scale_output=True) + self.stage4_cfg, num_channels, multi_scale_output=True + ) # Classification Head self.global_mode = global_mode @@ -249,11 +229,7 @@ class PoseHighResolutionNet(nn.Module): # from C, 2C, 4C, 8C to 128, 256, 512, 1024 incre_modules = [] for i, channels in enumerate(pre_stage_channels): - incre_module = self._make_layer(head_block, - channels, - head_channels[i], - 1, - stride=1) + incre_module = self._make_layer(head_block, channels, head_channels[i], 1, stride=1) incre_modules.append(incre_module) incre_modules = nn.ModuleList(incre_modules) @@ -264,13 +240,13 @@ class PoseHighResolutionNet(nn.Module): out_channels = head_channels[i + 1] * head_block.expansion downsamp_module = nn.Sequential( - nn.Conv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - stride=2, - padding=1), - nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), - nn.ReLU(inplace=True) + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1 + ), nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), nn.ReLU(inplace=True) ) downsamp_modules.append(downsamp_module) @@ -283,15 +259,12 @@ class PoseHighResolutionNet(nn.Module): kernel_size=1, stride=1, padding=0 - ), - nn.BatchNorm2d(2048, momentum=BN_MOMENTUM), - nn.ReLU(inplace=True) + ), 
nn.BatchNorm2d(2048, momentum=BN_MOMENTUM), nn.ReLU(inplace=True) ) return incre_modules, downsamp_modules, final_layer - def _make_transition_layer( - self, num_channels_pre_layer, num_channels_cur_layer): + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): num_branches_cur = len(num_channels_cur_layer) num_branches_pre = len(num_channels_pre_layer) @@ -304,27 +277,25 @@ class PoseHighResolutionNet(nn.Module): nn.Conv2d( num_channels_pre_layer[i], num_channels_cur_layer[i], - 3, 1, 1, bias=False - ), - nn.BatchNorm2d(num_channels_cur_layer[i]), - nn.ReLU(inplace=True) + 3, + 1, + 1, + bias=False + ), nn.BatchNorm2d(num_channels_cur_layer[i]), nn.ReLU(inplace=True) ) ) else: transition_layers.append(None) else: conv3x3s = [] - for j in range(i+1-num_branches_pre): + for j in range(i + 1 - num_branches_pre): inchannels = num_channels_pre_layer[-1] outchannels = num_channels_cur_layer[i] \ if j == i-num_branches_pre else inchannels conv3x3s.append( nn.Sequential( - nn.Conv2d( - inchannels, outchannels, 3, 2, 1, bias=False - ), - nn.BatchNorm2d(outchannels), - nn.ReLU(inplace=True) + nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels), nn.ReLU(inplace=True) ) ) transition_layers.append(nn.Sequential(*conv3x3s)) @@ -336,8 +307,7 @@ class PoseHighResolutionNet(nn.Module): if stride != 1 or inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d( - inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False + inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False ), nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), ) @@ -350,8 +320,7 @@ class PoseHighResolutionNet(nn.Module): return nn.Sequential(*layers) - def _make_stage(self, layer_config, num_inchannels, - multi_scale_output=True): + def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): num_modules = layer_config['NUM_MODULES'] num_branches = layer_config['NUM_BRANCHES'] num_blocks = layer_config['NUM_BLOCKS'] @@ -369,12 +338,7 @@ class PoseHighResolutionNet(nn.Module): modules.append( HighResolutionModule( - num_branches, - block, - num_blocks, - num_inchannels, - num_channels, - fuse_method, + num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output ) ) diff --git a/lib/pymafx/models/maf_extractor.py b/lib/pymafx/models/maf_extractor.py index 1d1af8dac7a4bf9c157f9396ffe32f7811ec50ad..34237bc55663dcbcbd67beb4c5d0b6e693aae266 100644 --- a/lib/pymafx/models/maf_extractor.py +++ b/lib/pymafx/models/maf_extractor.py @@ -10,6 +10,7 @@ from lib.pymafx.core import path_config from lib.pymafx.utils.geometry import projection import logging + logger = logging.getLogger(__name__) from .transformers.net_utils import PosEnSine @@ -19,7 +20,9 @@ from lib.pymafx.utils.imutils import j2d_processing class TransformerDecoderUnit(nn.Module): - def __init__(self, feat_dim, attri_dim=0, n_head=8, pos_en_flag=True, attn_type='softmax', P=None): + def __init__( + self, feat_dim, attri_dim=0, n_head=8, pos_en_flag=True, attn_type='softmax', P=None + ): super(TransformerDecoderUnit, self).__init__() self.feat_dim = feat_dim self.attn_type = attn_type @@ -32,7 +35,9 @@ class TransformerDecoderUnit(nn.Module): self.pos_en = PosEnSine(pe_dim) else: pe_dim = 0 - self.attn = OurMultiheadAttention(feat_dim+attri_dim+pe_dim*3, feat_dim+pe_dim*3, feat_dim, n_head) # cross-attention + self.attn = OurMultiheadAttention( + feat_dim + attri_dim + pe_dim * 3, feat_dim + 
pe_dim * 3, feat_dim, n_head + ) # cross-attention self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) @@ -50,7 +55,7 @@ class TransformerDecoderUnit(nn.Module): # else: # q_pos_embed = 0 # k_pos_embed = 0 - + # cross-multi-head attention out = self.attn(q=q, k=k, v=v, attn_type=self.attn_type, P=self.P)[0] @@ -65,25 +70,28 @@ class TransformerDecoderUnit(nn.Module): class Mesh_Sampler(nn.Module): ''' Mesh Up/Down-sampling ''' - def __init__(self, type='smpl', level=2, device=torch.device('cuda'), option=None): super().__init__() # downsample SMPL mesh and assign part labels if type == 'smpl': # from https://github.com/nkolot/GraphCMR/blob/master/data/mesh_downsampling.npz - smpl_mesh_graph = np.load(path_config.SMPL_DOWNSAMPLING, allow_pickle=True, encoding='latin1') + smpl_mesh_graph = np.load( + path_config.SMPL_DOWNSAMPLING, allow_pickle=True, encoding='latin1' + ) A = smpl_mesh_graph['A'] U = smpl_mesh_graph['U'] - D = smpl_mesh_graph['D'] # shape: (2,) + D = smpl_mesh_graph['D'] # shape: (2,) elif type == 'mano': # from https://github.com/microsoft/MeshGraphormer/blob/main/src/modeling/data/mano_downsampling.npz - mano_mesh_graph = np.load(path_config.MANO_DOWNSAMPLING, allow_pickle=True, encoding='latin1') + mano_mesh_graph = np.load( + path_config.MANO_DOWNSAMPLING, allow_pickle=True, encoding='latin1' + ) A = mano_mesh_graph['A'] U = mano_mesh_graph['U'] - D = mano_mesh_graph['D'] # shape: (2,) + D = mano_mesh_graph['D'] # shape: (2,) # downsampling ptD = [] @@ -92,14 +100,14 @@ class Mesh_Sampler(nn.Module): i = torch.LongTensor(np.array([d.row, d.col])) v = torch.FloatTensor(d.data) ptD.append(torch.sparse.FloatTensor(i, v, d.shape)) - + # downsampling mapping from 6890 points to 431 points # ptD[0].to_dense() - Size: [1723, 6890] , [195, 778] # ptD[1].to_dense() - Size: [431, 1723] , [49, 195] if level == 2: - Dmap = torch.matmul(ptD[1].to_dense(), ptD[0].to_dense()) # 6890 -> 431 + Dmap = torch.matmul(ptD[1].to_dense(), ptD[0].to_dense()) # 6890 -> 431 elif level == 1: - Dmap = ptD[0].to_dense() # + Dmap = ptD[0].to_dense() # self.register_buffer('Dmap', Dmap) # upsampling @@ -109,21 +117,21 @@ class Mesh_Sampler(nn.Module): i = torch.LongTensor(np.array([d.row, d.col])) v = torch.FloatTensor(d.data) ptU.append(torch.sparse.FloatTensor(i, v, d.shape)) - + # upsampling mapping from 431 points to 6890 points # ptU[0].to_dense() - Size: [6890, 1723] # ptU[1].to_dense() - Size: [1723, 431] if level == 2: - Umap = torch.matmul(ptU[0].to_dense(), ptU[1].to_dense()) # 431 -> 6890 + Umap = torch.matmul(ptU[0].to_dense(), ptU[1].to_dense()) # 431 -> 6890 elif level == 1: - Umap = ptU[0].to_dense() # + Umap = ptU[0].to_dense() # self.register_buffer('Umap', Umap) def downsample(self, x): - return torch.matmul(self.Dmap.unsqueeze(0), x) # [B, 431, 3] - + return torch.matmul(self.Dmap.unsqueeze(0), x) # [B, 431, 3] + def upsample(self, x): - return torch.matmul(self.Umap.unsqueeze(0), x) # [B, 6890, 3] + return torch.matmul(self.Umap.unsqueeze(0), x) # [B, 6890, 3] def forward(self, x, mode='downsample'): if mode == 'downsample': @@ -137,8 +145,9 @@ class MAF_Extractor(nn.Module): As discussed in the paper, we extract mesh-aligned features based on 2D projection of the mesh vertices. The features extrated from spatial feature maps will go through a MLP for dimension reduction. 
''' - - def __init__(self, filter_channels, device=torch.device('cuda'), iwp_cam_mode=True, option=None): + def __init__( + self, filter_channels, device=torch.device('cuda'), iwp_cam_mode=True, option=None + ): super().__init__() self.device = device @@ -151,25 +160,22 @@ class MAF_Extractor(nn.Module): for l in range(0, len(filter_channels) - 1): if 0 != l: self.filters.append( - nn.Conv1d( - filter_channels[l] + filter_channels[0], - filter_channels[l + 1], - 1)) + nn.Conv1d(filter_channels[l] + filter_channels[0], filter_channels[l + 1], 1) + ) else: - self.filters.append(nn.Conv1d( - filter_channels[l], - filter_channels[l + 1], - 1)) + self.filters.append(nn.Conv1d(filter_channels[l], filter_channels[l + 1], 1)) self.add_module("conv%d" % l, self.filters[l]) # downsample SMPL mesh and assign part labels # from https://github.com/nkolot/GraphCMR/blob/master/data/mesh_downsampling.npz - smpl_mesh_graph = np.load(path_config.SMPL_DOWNSAMPLING, allow_pickle=True, encoding='latin1') + smpl_mesh_graph = np.load( + path_config.SMPL_DOWNSAMPLING, allow_pickle=True, encoding='latin1' + ) A = smpl_mesh_graph['A'] U = smpl_mesh_graph['U'] - D = smpl_mesh_graph['D'] # shape: (2,) + D = smpl_mesh_graph['D'] # shape: (2,) # downsampling ptD = [] @@ -178,11 +184,11 @@ class MAF_Extractor(nn.Module): i = torch.LongTensor(np.array([d.row, d.col])) v = torch.FloatTensor(d.data) ptD.append(torch.sparse.FloatTensor(i, v, d.shape)) - + # downsampling mapping from 6890 points to 431 points # ptD[0].to_dense() - Size: [1723, 6890] # ptD[1].to_dense() - Size: [431. 1723] - Dmap = torch.matmul(ptD[1].to_dense(), ptD[0].to_dense()) # 6890 -> 431 + Dmap = torch.matmul(ptD[1].to_dense(), ptD[0].to_dense()) # 6890 -> 431 self.register_buffer('Dmap', Dmap) # upsampling @@ -192,14 +198,13 @@ class MAF_Extractor(nn.Module): i = torch.LongTensor(np.array([d.row, d.col])) v = torch.FloatTensor(d.data) ptU.append(torch.sparse.FloatTensor(i, v, d.shape)) - + # upsampling mapping from 431 points to 6890 points # ptU[0].to_dense() - Size: [6890, 1723] # ptU[1].to_dense() - Size: [1723, 431] - Umap = torch.matmul(ptU[0].to_dense(), ptU[1].to_dense()) # 431 -> 6890 + Umap = torch.matmul(ptU[0].to_dense(), ptU[1].to_dense()) # 431 -> 6890 self.register_buffer('Umap', Umap) - def reduce_dim(self, feature): ''' Dimension reduction by multi-layer perceptrons @@ -209,19 +214,13 @@ class MAF_Extractor(nn.Module): y = feature tmpy = feature for i, f in enumerate(self.filters): - y = self._modules['conv' + str(i)]( - y if i == 0 - else torch.cat([y, tmpy], 1) - ) + y = self._modules['conv' + str(i)](y if i == 0 else torch.cat([y, tmpy], 1)) if i != len(self.filters) - 1: y = F.leaky_relu(y) if self.num_views > 1 and i == len(self.filters) // 2: - y = y.view( - -1, self.num_views, y.shape[1], y.shape[2] - ).mean(dim=1) - tmpy = feature.view( - -1, self.num_views, feature.shape[1], feature.shape[2] - ).mean(dim=1) + y = y.view(-1, self.num_views, y.shape[1], y.shape[2]).mean(dim=1) + tmpy = feature.view(-1, self.num_views, feature.shape[1], + feature.shape[2]).mean(dim=1) y = self.last_op(y) @@ -242,7 +241,9 @@ class MAF_Extractor(nn.Module): # im_feat = self.im_feat batch_size = im_feat.shape[0] - point_feat = torch.nn.functional.grid_sample(im_feat, points.unsqueeze(2), align_corners=False)[..., 0] + point_feat = torch.nn.functional.grid_sample( + im_feat, points.unsqueeze(2), align_corners=False + )[..., 0] if reduce_dim: mesh_align_feat = self.reduce_dim(point_feat) @@ -266,6 +267,6 @@ class MAF_Extractor(nn.Module): # Normalize 
keypoints to [-1,1] p_proj_2d = p_proj_2d / (224. / 2.) else: - p_proj_2d = j2d_processing(p_proj_2d, cam['kps_transf']) + p_proj_2d = j2d_processing(p_proj_2d, cam['kps_transf']) mesh_align_feat = self.sampling(p_proj_2d, im_feat, add_att=add_att, reduce_dim=reduce_dim) return mesh_align_feat diff --git a/lib/pymafx/models/pose_resnet.py b/lib/pymafx/models/pose_resnet.py index e9a2f6716c002b2fd9645d1877081b4177730049..d97b6609cf02fd2a94d2951f82f71de2be2356c0 100644 --- a/lib/pymafx/models/pose_resnet.py +++ b/lib/pymafx/models/pose_resnet.py @@ -14,17 +14,13 @@ import logging import torch import torch.nn as nn - BN_MOMENTUM = 0.1 logger = logging.getLogger(__name__) def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" - return nn.Conv2d( - in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False - ) + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class BasicBlock(nn.Module): @@ -66,13 +62,10 @@ class Bottleneck(nn.Module): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=1, bias=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) - self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, - bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion, - momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride @@ -101,7 +94,6 @@ class Bottleneck(nn.Module): class PoseResNet(nn.Module): - def __init__(self, block, layers, cfg, global_mode, **kwargs): self.inplanes = 64 extra = cfg.POSE_RES_MODEL.EXTRA @@ -109,8 +101,7 @@ class PoseResNet(nn.Module): self.deconv_with_bias = extra.DECONV_WITH_BIAS super(PoseResNet, self).__init__() - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, - bias=False) + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -144,8 +135,13 @@ class PoseResNet(nn.Module): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), ) @@ -190,7 +186,9 @@ class PoseResNet(nn.Module): stride=2, padding=padding, output_padding=output_padding, - bias=self.deconv_with_bias)) + bias=self.deconv_with_bias + ) + ) layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) layers.append(nn.ReLU(inplace=True)) self.inplanes = planes @@ -218,7 +216,9 @@ class PoseResNet(nn.Module): else: g_feat = None if self.extra.NUM_DECONV_LAYERS == 3: - deconv_blocks = [self.deconv_layers[0:3], self.deconv_layers[3:6], self.deconv_layers[6:9]] + deconv_blocks = [ + self.deconv_layers[0:3], self.deconv_layers[3:6], self.deconv_layers[6:9] + ] s_feat_list = [] s_feat = x @@ -284,6 
+284,7 @@ resnet_spec = { 152: (Bottleneck, [3, 8, 36, 3]) } + def get_resnet_encoder(cfg, init_weight=True, global_mode=False, **kwargs): num_layers = cfg.POSE_RES_MODEL.EXTRA.NUM_LAYERS diff --git a/lib/pymafx/models/pymaf_net.py b/lib/pymafx/models/pymaf_net.py index 5b6f3587e5c470236a0647550b64309058545d1b..ca57e4b1c8ce971d76ce53d02827f441016a19ab 100644 --- a/lib/pymafx/models/pymaf_net.py +++ b/lib/pymafx/models/pymaf_net.py @@ -23,15 +23,16 @@ BN_MOMENTUM = 0.1 class Regressor(nn.Module): - - def __init__(self, - feat_dim, - smpl_mean_params, - use_cam_feats=False, - feat_dim_hand=0, - feat_dim_face=0, - bhf_names=['body'], - smpl_models={}): + def __init__( + self, + feat_dim, + smpl_mean_params, + use_cam_feats=False, + feat_dim_hand=0, + feat_dim_face=0, + bhf_names=['body'], + smpl_models={} + ): super().__init__() npose = 24 * 6 @@ -96,8 +97,9 @@ class Regressor(nn.Module): rh_cam_dim = 3 rh_orient_dim = 6 rh_shape_dim = 10 - self.fc3_hand = nn.Linear(1024 + rh_orient_dim + rh_shape_dim + rh_cam_dim, - 1024) + self.fc3_hand = nn.Linear( + 1024 + rh_orient_dim + rh_shape_dim + rh_cam_dim, 1024 + ) self.drop3_hand = nn.Dropout() self.decshape_rhand = nn.Linear(1024, 10) @@ -122,8 +124,9 @@ class Regressor(nn.Module): rh_cam_dim = 3 rh_orient_dim = 6 rh_shape_dim = 10 - self.fc3_face = nn.Linear(1024 + rh_orient_dim + rh_shape_dim + rh_cam_dim, - 1024) + self.fc3_face = nn.Linear( + 1024 + rh_orient_dim + rh_shape_dim + rh_cam_dim, 1024 + ) self.drop3_face = nn.Dropout() self.decshape_face = nn.Linear(1024, 10) @@ -167,10 +170,14 @@ class Regressor(nn.Module): if not self.smpl_mode: lhand_mean_rot6d = rotmat_to_rot6d( batch_rodrigues(self.smpl.model.model_neutral.left_hand_mean.view(-1, 3)).view( - [-1, 3, 3])) + [-1, 3, 3] + ) + ) rhand_mean_rot6d = rotmat_to_rot6d( batch_rodrigues(self.smpl.model.model_neutral.right_hand_mean.view(-1, 3)).view( - [-1, 3, 3])) + [-1, 3, 3] + ) + ) init_lhand = lhand_mean_rot6d.reshape(-1).unsqueeze(0) init_rhand = rhand_mean_rot6d.reshape(-1).unsqueeze(0) # init_hand = torch.cat([init_lhand, init_rhand]).unsqueeze(0) @@ -185,14 +192,16 @@ class Regressor(nn.Module): self.register_buffer('init_face', init_face) self.register_buffer('init_exp', init_exp) - def forward(self, - x=None, - n_iter=1, - J_regressor=None, - rw_cam={}, - init_mode=False, - global_iter=-1, - **kwargs): + def forward( + self, + x=None, + n_iter=1, + J_regressor=None, + rw_cam={}, + init_mode=False, + global_iter=-1, + **kwargs + ): if x is not None: batch_size = x.shape[0] else: @@ -215,8 +224,9 @@ class Regressor(nn.Module): if self.full_body_mode or self.body_hand_mode: if cfg.MODEL.PyMAF.OPT_WRIST: - pred_rotmat_body = rot6d_to_rotmat(pred_pose.reshape( - batch_size, -1, 6)) # .view(batch_size, 24, 3, 3) + pred_rotmat_body = rot6d_to_rotmat( + pred_pose.reshape(batch_size, -1, 6) + ) # .view(batch_size, 24, 3, 3) if cfg.MODEL.PyMAF.PRED_VIS_H: pred_vis_hands = None @@ -291,7 +301,8 @@ class Regressor(nn.Module): vfov = rw_cam['vfov'][:, None] crop_ratio = rw_cam['crop_ratio'][:, None] crop_center = rw_cam['bbox_center'] / torch.cat( - [rw_cam['img_w'][:, None], rw_cam['img_h'][:, None]], 1) + [rw_cam['img_w'][:, None], rw_cam['img_h'][:, None]], 1 + ) xc = torch.cat([xc, vfov, crop_ratio, crop_center], 1) xc = self.fc1(xc) @@ -315,8 +326,8 @@ class Regressor(nn.Module): xc_lhand = torch.cat([xc_lhand, pred_lhand], 1) xc_rhand = torch.cat([xc_rhand, pred_rhand], 1) elif self.full_body_mode: - xc_lhand, xc_rhand, xc_face = kwargs['xc_lhand'], kwargs[ - 'xc_rhand'], 
kwargs['xc_face'] + xc_lhand, xc_rhand, xc_face = kwargs['xc_lhand'], kwargs['xc_rhand' + ], kwargs['xc_face'] xc_lhand = torch.cat([xc_lhand, pred_lhand], 1) xc_rhand = torch.cat([xc_rhand, pred_rhand], 1) xc_face = torch.cat([xc_face, pred_face, pred_exp], 1) @@ -328,7 +339,8 @@ class Regressor(nn.Module): if cfg.MODEL.PyMAF.OPT_WRIST: xc_lhand = torch.cat( - [xc_lhand, pred_shape_lh, pred_orient_lh, pred_cam_lh], 1) + [xc_lhand, pred_shape_lh, pred_orient_lh, pred_cam_lh], 1 + ) xc_lhand = self.drop3_hand(self.fc3_hand(xc_lhand)) pred_shape_lh = self.decshape_rhand(xc_lhand) + pred_shape_lh @@ -342,7 +354,8 @@ class Regressor(nn.Module): if cfg.MODEL.MESH_MODEL == 'mano' or cfg.MODEL.PyMAF.OPT_WRIST: xc_rhand = torch.cat( - [xc_rhand, pred_shape_rh, pred_orient_rh, pred_cam_rh], 1) + [xc_rhand, pred_shape_rh, pred_orient_rh, pred_cam_rh], 1 + ) xc_rhand = self.drop3_hand(self.fc3_hand(xc_rhand)) pred_shape_rh = self.decshape_rhand(xc_rhand) + pred_shape_rh @@ -351,7 +364,8 @@ class Regressor(nn.Module): if cfg.MODEL.MESH_MODEL == 'mano': pred_cam = torch.cat( - [pred_cam_rh[:, 0:1] * 10., pred_cam_rh[:, 1:] / 10.], dim=1) + [pred_cam_rh[:, 0:1] * 10., pred_cam_rh[:, 1:] / 10.], dim=1 + ) if 'face' in self.part_names: xc_face = self.drop1_face(self.fc1_face(xc_face)) @@ -361,7 +375,8 @@ class Regressor(nn.Module): if cfg.MODEL.MESH_MODEL == 'flame': xc_face = torch.cat( - [xc_face, pred_shape_fa, pred_orient_fa, pred_cam_fa], 1) + [xc_face, pred_shape_fa, pred_orient_fa, pred_cam_fa], 1 + ) xc_face = self.drop3_face(self.fc3_face(xc_face)) pred_shape_fa = self.decshape_face(xc_face) + pred_shape_fa @@ -370,7 +385,8 @@ class Regressor(nn.Module): if cfg.MODEL.MESH_MODEL == 'flame': pred_cam = torch.cat( - [pred_cam_fa[:, 0:1] * 10., pred_cam_fa[:, 1:] / 10.], dim=1) + [pred_cam_fa[:, 0:1] * 10., pred_cam_fa[:, 1:] / 10.], dim=1 + ) if self.full_body_mode or self.body_hand_mode: if cfg.MODEL.PyMAF.PRED_VIS_H: @@ -385,22 +401,26 @@ class Regressor(nn.Module): if cfg.MODEL.PyMAF.OPT_WRIST: - pred_rotmat_body = rot6d_to_rotmat(pred_pose.reshape( - batch_size, -1, 6)) # .view(batch_size, 24, 3, 3) + pred_rotmat_body = rot6d_to_rotmat( + pred_pose.reshape(batch_size, -1, 6) + ) # .view(batch_size, 24, 3, 3) pred_lwrist = pred_rotmat_body[:, 20] pred_rwrist = pred_rotmat_body[:, 21] pred_gl_body, body_joints = self.body_model.get_global_rotation( global_orient=pred_rotmat_body[:, 0:1], - body_pose=pred_rotmat_body[:, 1:]) + body_pose=pred_rotmat_body[:, 1:] + ) pred_gl_lelbow = pred_gl_body[:, 18] pred_gl_relbow = pred_gl_body[:, 19] target_gl_lwrist = rot6d_to_rotmat( - pred_orient_lh.reshape(batch_size, -1, 6)) + pred_orient_lh.reshape(batch_size, -1, 6) + ) target_gl_lwrist *= self.flip_vector.to(target_gl_lwrist.device) target_gl_rwrist = rot6d_to_rotmat( - pred_orient_rh.reshape(batch_size, -1, 6)) + pred_orient_rh.reshape(batch_size, -1, 6) + ) opt_lwrist = torch.bmm(pred_gl_lelbow.transpose(1, 2), target_gl_lwrist) opt_rwrist = torch.bmm(pred_gl_relbow.transpose(1, 2), target_gl_rwrist) @@ -408,34 +428,40 @@ class Regressor(nn.Module): if cfg.MODEL.PyMAF.ADAPT_INTEGR: # if cfg.MODEL.PyMAF.ADAPT_INTEGR and global_iter == (cfg.MODEL.PyMAF.N_ITER - 1): tpose_joints = self.smpl.get_tpose(betas=pred_shape) - lelbow_twist_axis = nn.functional.normalize(tpose_joints[:, 20] - - tpose_joints[:, 18], - dim=1) - relbow_twist_axis = nn.functional.normalize(tpose_joints[:, 21] - - tpose_joints[:, 19], - dim=1) + lelbow_twist_axis = nn.functional.normalize( + tpose_joints[:, 20] - tpose_joints[:, 18], 
dim=1 + ) + relbow_twist_axis = nn.functional.normalize( + tpose_joints[:, 21] - tpose_joints[:, 19], dim=1 + ) lelbow_twist, lelbow_twist_angle = compute_twist_rotation( - opt_lwrist, lelbow_twist_axis) + opt_lwrist, lelbow_twist_axis + ) relbow_twist, relbow_twist_angle = compute_twist_rotation( - opt_rwrist, relbow_twist_axis) + opt_rwrist, relbow_twist_axis + ) min_angle = -0.4 * float(np.pi) max_angle = 0.4 * float(np.pi) - lelbow_twist_angle[lelbow_twist_angle == torch.clamp( - lelbow_twist_angle, min_angle, max_angle)] = 0 - relbow_twist_angle[relbow_twist_angle == torch.clamp( - relbow_twist_angle, min_angle, max_angle)] = 0 + lelbow_twist_angle[lelbow_twist_angle == torch. + clamp(lelbow_twist_angle, min_angle, max_angle) + ] = 0 + relbow_twist_angle[relbow_twist_angle == torch. + clamp(relbow_twist_angle, min_angle, max_angle) + ] = 0 lelbow_twist_angle[lelbow_twist_angle > max_angle] -= max_angle lelbow_twist_angle[lelbow_twist_angle < min_angle] -= min_angle relbow_twist_angle[relbow_twist_angle > max_angle] -= max_angle relbow_twist_angle[relbow_twist_angle < min_angle] -= min_angle - lelbow_twist = batch_rodrigues(lelbow_twist_axis * - lelbow_twist_angle) - relbow_twist = batch_rodrigues(relbow_twist_axis * - relbow_twist_angle) + lelbow_twist = batch_rodrigues( + lelbow_twist_axis * lelbow_twist_angle + ) + relbow_twist = batch_rodrigues( + relbow_twist_axis * relbow_twist_angle + ) opt_lwrist = torch.bmm(lelbow_twist.transpose(1, 2), opt_lwrist) opt_rwrist = torch.bmm(relbow_twist.transpose(1, 2), opt_rwrist) @@ -446,7 +472,8 @@ class Regressor(nn.Module): opt_relbow = torch.bmm(pred_rotmat_body[:, 19], relbow_twist) if cfg.MODEL.PyMAF.PRED_VIS_H and global_iter == ( - cfg.MODEL.PyMAF.N_ITER - 1): + cfg.MODEL.PyMAF.N_ITER - 1 + ): opt_lwrist_filtered = [ opt_lwrist[_i] if pred_vis_lhand[_i] else pred_rotmat_body[_i, 20] @@ -473,16 +500,19 @@ class Regressor(nn.Module): opt_lelbow = torch.stack(opt_lelbow_filtered) opt_relbow = torch.stack(opt_relbow_filtered) - pred_rotmat_body = torch.cat([ - pred_rotmat_body[:, :18], - opt_lelbow.unsqueeze(1), - opt_relbow.unsqueeze(1), - opt_lwrist.unsqueeze(1), - opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:] - ], 1) + pred_rotmat_body = torch.cat( + [ + pred_rotmat_body[:, :18], + opt_lelbow.unsqueeze(1), + opt_relbow.unsqueeze(1), + opt_lwrist.unsqueeze(1), + opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:] + ], 1 + ) else: if cfg.MODEL.PyMAF.PRED_VIS_H and global_iter == ( - cfg.MODEL.PyMAF.N_ITER - 1): + cfg.MODEL.PyMAF.N_ITER - 1 + ): opt_lwrist_filtered = [ opt_lwrist[_i] if pred_vis_lhand[_i] else pred_rotmat_body[_i, 20] @@ -497,32 +527,36 @@ class Regressor(nn.Module): opt_lwrist = torch.stack(opt_lwrist_filtered) opt_rwrist = torch.stack(opt_rwrist_filtered) - pred_rotmat_body = torch.cat([ - pred_rotmat_body[:, :20], - opt_lwrist.unsqueeze(1), - opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:] - ], 1) + pred_rotmat_body = torch.cat( + [ + pred_rotmat_body[:, :20], + opt_lwrist.unsqueeze(1), + opt_rwrist.unsqueeze(1), pred_rotmat_body[:, 22:] + ], 1 + ) if self.hand_only_mode: pred_rotmat_rh = rot6d_to_rotmat( - torch.cat([pred_orient_rh, pred_rhand], - dim=1).reshape(batch_size, -1, 6)) # .view(batch_size, 16, 3, 3) + torch.cat([pred_orient_rh, pred_rhand], dim=1).reshape(batch_size, -1, 6) + ) # .view(batch_size, 16, 3, 3) assert pred_rotmat_rh.shape[1] == 1 + 15 elif self.face_only_mode: pred_rotmat_fa = rot6d_to_rotmat( - torch.cat([pred_orient_fa, pred_face], - dim=1).reshape(batch_size, -1, 6)) # .view(batch_size, 
16, 3, 3) + torch.cat([pred_orient_fa, pred_face], dim=1).reshape(batch_size, -1, 6) + ) # .view(batch_size, 16, 3, 3) assert pred_rotmat_fa.shape[1] == 1 + 3 elif self.full_body_mode or self.body_hand_mode: if cfg.MODEL.PyMAF.OPT_WRIST: pred_rotmat = pred_rotmat_body else: - pred_rotmat = rot6d_to_rotmat(pred_pose.reshape(batch_size, -1, - 6)) # .view(batch_size, 24, 3, 3) + pred_rotmat = rot6d_to_rotmat( + pred_pose.reshape(batch_size, -1, 6) + ) # .view(batch_size, 24, 3, 3) assert pred_rotmat.shape[1] == 24 else: - pred_rotmat = rot6d_to_rotmat(pred_pose.reshape(batch_size, -1, - 6)) # .view(batch_size, 24, 3, 3) + pred_rotmat = rot6d_to_rotmat( + pred_pose.reshape(batch_size, -1, 6) + ) # .view(batch_size, 24, 3, 3) assert pred_rotmat.shape[1] == 24 # if self.full_body_mode: @@ -547,8 +581,8 @@ class Regressor(nn.Module): assert pred_hfrotmat.shape[1] == (15 * 2 + 3) # flip left hand pose - pred_lhand_rotmat = pred_hfrotmat[:, :15] * self.flip_vector.to( - pred_hfrotmat.device).unsqueeze(0) + pred_lhand_rotmat = pred_hfrotmat[:, :15] * self.flip_vector.to(pred_hfrotmat.device + ).unsqueeze(0) pred_rhand_rotmat = pred_hfrotmat[:, 15:30] pred_face_rotmat = pred_hfrotmat[:, 30:] @@ -596,17 +630,20 @@ class Regressor(nn.Module): elif self.face_only_mode: pred_joints_full = pred_output.face_joints elif self.smplx_mode: - pred_joints_full = torch.cat([ - pred_joints, pred_output.lhand_joints, pred_output.rhand_joints, - pred_output.face_joints, pred_output.lfoot_joints, pred_output.rfoot_joints - ], - dim=1) + pred_joints_full = torch.cat( + [ + pred_joints, pred_output.lhand_joints, pred_output.rhand_joints, + pred_output.face_joints, pred_output.lfoot_joints, pred_output.rfoot_joints + ], + dim=1 + ) else: pred_joints_full = pred_joints - pred_keypoints_2d = projection(pred_joints_full, { - **rw_cam, 'cam_sxy': pred_cam - }, - iwp_mode=cfg.MODEL.USE_IWP_CAM) + pred_keypoints_2d = projection( + pred_joints_full, { + **rw_cam, 'cam_sxy': pred_cam + }, iwp_mode=cfg.MODEL.USE_IWP_CAM + ) if cfg.MODEL.USE_IWP_CAM: # Normalize keypoints to [-1,1] pred_keypoints_2d = pred_keypoints_2d / (224. / 2.) 
@@ -624,126 +661,137 @@ class Regressor(nn.Module): else: kp_3d = pred_joints pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72) - output.update({ - 'theta': torch.cat([pred_cam, pred_shape, pose], dim=1), - 'verts': pred_vertices, - 'kp_2d': pred_keypoints_2d[:, :len_b_kp], - 'kp_3d': kp_3d, - 'pred_joints': pred_joints, - 'smpl_kp_3d': pred_output.smpl_joints, - 'rotmat': pred_rotmat, - 'pred_cam': pred_cam, - 'pred_shape': pred_shape, - 'pred_pose': pred_pose, - }) + output.update( + { + 'theta': torch.cat([pred_cam, pred_shape, pose], dim=1), + 'verts': pred_vertices, + 'kp_2d': pred_keypoints_2d[:, :len_b_kp], + 'kp_3d': kp_3d, + 'pred_joints': pred_joints, + 'smpl_kp_3d': pred_output.smpl_joints, + 'rotmat': pred_rotmat, + 'pred_cam': pred_cam, + 'pred_shape': pred_shape, + 'pred_pose': pred_pose, + } + ) # if self.full_body_mode: if self.smplx_mode: # assert pred_keypoints_2d.shape[1] == 144 len_h_kp = len(constants.HAND_NAMES) len_f_kp = len(constants.FACIAL_LANDMARKS) len_feet_kp = 2 * len(constants.FOOT_NAMES) - output.update({ - 'smplx_verts': - pred_output.smplx_vertices if cfg.MODEL.EVAL_MODE else None, - 'pred_lhand': - pred_lhand, - 'pred_rhand': - pred_rhand, - 'pred_face': - pred_face, - 'pred_exp': - pred_exp, - 'verts_lh': - pred_output.lhand_vertices, - 'verts_rh': - pred_output.rhand_vertices, - # 'pred_arm_rotmat': pred_arm_rotmat, - # 'pred_hfrotmat': pred_hfrotmat, - 'pred_lhand_rotmat': - pred_lhand_rotmat, - 'pred_rhand_rotmat': - pred_rhand_rotmat, - 'pred_face_rotmat': - pred_face_rotmat, - 'pred_lhand_kp3d': - pred_output.lhand_joints, - 'pred_rhand_kp3d': - pred_output.rhand_joints, - 'pred_face_kp3d': - pred_output.face_joints, - 'pred_lhand_kp2d': - pred_keypoints_2d[:, len_b_kp:len_b_kp + len_h_kp], - 'pred_rhand_kp2d': - pred_keypoints_2d[:, len_b_kp + len_h_kp:len_b_kp + len_h_kp * 2], - 'pred_face_kp2d': - pred_keypoints_2d[:, len_b_kp + len_h_kp * 2:len_b_kp + len_h_kp * 2 + - len_f_kp], - 'pred_feet_kp2d': - pred_keypoints_2d[:, len_b_kp + len_h_kp * 2 + len_f_kp:len_b_kp + - len_h_kp * 2 + len_f_kp + len_feet_kp], - }) + output.update( + { + 'smplx_verts': + pred_output.smplx_vertices if cfg.MODEL.EVAL_MODE else None, + 'pred_lhand': + pred_lhand, + 'pred_rhand': + pred_rhand, + 'pred_face': + pred_face, + 'pred_exp': + pred_exp, + 'verts_lh': + pred_output.lhand_vertices, + 'verts_rh': + pred_output.rhand_vertices, + # 'pred_arm_rotmat': pred_arm_rotmat, + # 'pred_hfrotmat': pred_hfrotmat, + 'pred_lhand_rotmat': + pred_lhand_rotmat, + 'pred_rhand_rotmat': + pred_rhand_rotmat, + 'pred_face_rotmat': + pred_face_rotmat, + 'pred_lhand_kp3d': + pred_output.lhand_joints, + 'pred_rhand_kp3d': + pred_output.rhand_joints, + 'pred_face_kp3d': + pred_output.face_joints, + 'pred_lhand_kp2d': + pred_keypoints_2d[:, len_b_kp:len_b_kp + len_h_kp], + 'pred_rhand_kp2d': + pred_keypoints_2d[:, len_b_kp + len_h_kp:len_b_kp + len_h_kp * 2], + 'pred_face_kp2d': + pred_keypoints_2d[:, len_b_kp + len_h_kp * 2:len_b_kp + len_h_kp * 2 + + len_f_kp], + 'pred_feet_kp2d': + pred_keypoints_2d[:, len_b_kp + len_h_kp * 2 + len_f_kp:len_b_kp + + len_h_kp * 2 + len_f_kp + len_feet_kp], + } + ) if cfg.MODEL.PyMAF.OPT_WRIST: - output.update({ - 'pred_orient_lh': pred_orient_lh, - 'pred_shape_lh': pred_shape_lh, - 'pred_orient_rh': pred_orient_rh, - 'pred_shape_rh': pred_shape_rh, - 'pred_cam_fa': pred_cam_fa, - 'pred_cam_lh': pred_cam_lh, - 'pred_cam_rh': pred_cam_rh, - }) + output.update( + { + 'pred_orient_lh': pred_orient_lh, + 'pred_shape_lh': 
pred_shape_lh, + 'pred_orient_rh': pred_orient_rh, + 'pred_shape_rh': pred_shape_rh, + 'pred_cam_fa': pred_cam_fa, + 'pred_cam_lh': pred_cam_lh, + 'pred_cam_rh': pred_cam_rh, + } + ) if cfg.MODEL.PyMAF.PRED_VIS_H: output.update({'pred_vis_hands': pred_vis_hands}) elif self.hand_only_mode: # hand mesh out assert pred_keypoints_2d.shape[1] == 21 - output.update({ - 'theta': pred_cam, - 'pred_cam': pred_cam, - 'pred_rhand': pred_rhand, - 'pred_rhand_rotmat': pred_rotmat_rh[:, 1:], - 'pred_orient_rh': pred_orient_rh, - 'pred_orient_rh_rotmat': pred_rotmat_rh[:, 0], - 'verts_rh': pred_output.rhand_vertices, - 'pred_cam_rh': pred_cam_rh, - 'pred_shape_rh': pred_shape_rh, - 'pred_rhand_kp3d': pred_output.rhand_joints, - 'pred_rhand_kp2d': pred_keypoints_2d, - }) + output.update( + { + 'theta': pred_cam, + 'pred_cam': pred_cam, + 'pred_rhand': pred_rhand, + 'pred_rhand_rotmat': pred_rotmat_rh[:, 1:], + 'pred_orient_rh': pred_orient_rh, + 'pred_orient_rh_rotmat': pred_rotmat_rh[:, 0], + 'verts_rh': pred_output.rhand_vertices, + 'pred_cam_rh': pred_cam_rh, + 'pred_shape_rh': pred_shape_rh, + 'pred_rhand_kp3d': pred_output.rhand_joints, + 'pred_rhand_kp2d': pred_keypoints_2d, + } + ) elif self.face_only_mode: # face mesh out assert pred_keypoints_2d.shape[1] == 68 - output.update({ - 'theta': pred_cam, - 'pred_cam': pred_cam, - 'pred_face': pred_face, - 'pred_exp': pred_exp, - 'pred_face_rotmat': pred_rotmat_fa[:, 1:], - 'pred_orient_fa': pred_orient_fa, - 'pred_orient_fa_rotmat': pred_rotmat_fa[:, 0], - 'verts_fa': pred_output.flame_vertices, - 'pred_cam_fa': pred_cam_fa, - 'pred_shape_fa': pred_shape_fa, - 'pred_face_kp3d': pred_output.face_joints, - 'pred_face_kp2d': pred_keypoints_2d, - }) + output.update( + { + 'theta': pred_cam, + 'pred_cam': pred_cam, + 'pred_face': pred_face, + 'pred_exp': pred_exp, + 'pred_face_rotmat': pred_rotmat_fa[:, 1:], + 'pred_orient_fa': pred_orient_fa, + 'pred_orient_fa_rotmat': pred_rotmat_fa[:, 0], + 'verts_fa': pred_output.flame_vertices, + 'pred_cam_fa': pred_cam_fa, + 'pred_shape_fa': pred_shape_fa, + 'pred_face_kp3d': pred_output.face_joints, + 'pred_face_kp2d': pred_keypoints_2d, + } + ) return output -def get_attention_modules(module_keys, - img_feature_dim_list, - hidden_feat_dim, - n_iter, - num_attention_heads=1): +def get_attention_modules( + module_keys, img_feature_dim_list, hidden_feat_dim, n_iter, num_attention_heads=1 +): align_attention = nn.ModuleDict() for k in module_keys: align_attention[k] = nn.ModuleList() for i in range(n_iter): align_attention[k].append( - get_att_block(img_feature_dim=img_feature_dim_list[k][i], - hidden_feat_dim=hidden_feat_dim, - num_attention_heads=num_attention_heads)) + get_att_block( + img_feature_dim=img_feature_dim_list[k][i], + hidden_feat_dim=hidden_feat_dim, + num_attention_heads=num_attention_heads + ) + ) return align_attention @@ -764,11 +812,9 @@ class PyMAF(nn.Module): PyMAF: 3D Human Pose and Shape Regression with Pyramidal Mesh Alignment Feedback Loop, in ICCV, 2021 PyMAF-X: Towards Well-aligned Full-body Model Regression from Monocular Images, arXiv:2207.06400, 2022 """ - - def __init__(self, - smpl_mean_params=SMPL_MEAN_PARAMS, - pretrained=True, - device=torch.device('cuda')): + def __init__( + self, smpl_mean_params=SMPL_MEAN_PARAMS, pretrained=True, device=torch.device('cuda') + ): super().__init__() self.device = device @@ -829,8 +875,9 @@ class PyMAF(nn.Module): self.smpl_family['face'] = SMPL_Family(model_type='flame') self.smpl_family['body'] = SMPL_Family(model_type='smplx') else: - 
self.smpl_family['body'] = SMPL_Family(model_type=cfg.MODEL.MESH_MODEL, - all_gender=cfg.MODEL.ALL_GENDER) + self.smpl_family['body'] = SMPL_Family( + model_type=cfg.MODEL.MESH_MODEL, all_gender=cfg.MODEL.ALL_GENDER + ) self.init_mesh_output = None self.batch_size = 1 @@ -845,14 +892,14 @@ class PyMAF(nn.Module): if 'body' in bhf_names: # if self.smplx_mode or 'hr' in cfg.MODEL.PyMAF.BACKBONE: if cfg.MODEL.PyMAF.BACKBONE == 'res50': - body_encoder = get_resnet_encoder(cfg, - init_weight=(not cfg.MODEL.EVAL_MODE), - global_mode=self.global_mode) + body_encoder = get_resnet_encoder( + cfg, init_weight=(not cfg.MODEL.EVAL_MODE), global_mode=self.global_mode + ) body_sfeat_dim = list(cfg.POSE_RES_MODEL.EXTRA.NUM_DECONV_FILTERS) elif cfg.MODEL.PyMAF.BACKBONE == 'hr48': - body_encoder = get_hrnet_encoder(cfg, - init_weight=(not cfg.MODEL.EVAL_MODE), - global_mode=self.global_mode) + body_encoder = get_hrnet_encoder( + cfg, init_weight=(not cfg.MODEL.EVAL_MODE), global_mode=self.global_mode + ) body_sfeat_dim = list(cfg.HR_MODEL.EXTRA.STAGE4.NUM_CHANNELS) body_sfeat_dim.reverse() body_sfeat_dim = body_sfeat_dim[1:] @@ -885,7 +932,8 @@ class PyMAF(nn.Module): self.encoders[hf] = get_resnet_encoder( cfg, init_weight=(not cfg.MODEL.EVAL_MODE), - global_mode=self.global_mode) + global_mode=self.global_mode + ) self.part_module_names[hf].update({f'encoders.{hf}': self.encoders[hf]}) hf_sfeat_dim = list(cfg.POSE_RES_MODEL.EXTRA.NUM_DECONV_FILTERS) else: @@ -895,15 +943,19 @@ class PyMAF(nn.Module): assert cfg.MODEL.PyMAF.MAF_ON self.dp_head_hf = nn.ModuleDict() if 'hand' in bhf_names: - self.dp_head_hf['hand'] = IUV_predict_layer(feat_dim=hf_sfeat_dim[-1], - mode='pncc') + self.dp_head_hf['hand'] = IUV_predict_layer( + feat_dim=hf_sfeat_dim[-1], mode='pncc' + ) self.part_module_names['hand'].update( - {'dp_head_hf.hand': self.dp_head_hf['hand']}) + {'dp_head_hf.hand': self.dp_head_hf['hand']} + ) if 'face' in bhf_names: - self.dp_head_hf['face'] = IUV_predict_layer(feat_dim=hf_sfeat_dim[-1], - mode='pncc') + self.dp_head_hf['face'] = IUV_predict_layer( + feat_dim=hf_sfeat_dim[-1], mode='pncc' + ) self.part_module_names['face'].update( - {'dp_head_hf.face': self.dp_head_hf['face']}) + {'dp_head_hf.face': self.dp_head_hf['face']} + ) smpl2limb_vert_faces = get_partial_smpl() @@ -914,7 +966,8 @@ class PyMAF(nn.Module): grid_size = 21 xv, yv = torch.meshgrid( [torch.linspace(-1, 1, grid_size), - torch.linspace(-1, 1, grid_size)]) + torch.linspace(-1, 1, grid_size)] + ) grid_points = torch.stack([xv.reshape(-1), yv.reshape(-1)]).unsqueeze(0) self.register_buffer('grid_points', grid_points) grid_feat_dim = grid_size * grid_size * cfg.MODEL.PyMAF.MLP_DIM[-1] @@ -943,7 +996,8 @@ class PyMAF(nn.Module): if 'face' in self.bhf_names: bhf_ma_feat_dim.update( - {'face': len(constants.FACIAL_LANDMARKS) * cfg.MODEL.PyMAF.HF_MLP_DIM[-1]}) + {'face': len(constants.FACIAL_LANDMARKS) * cfg.MODEL.PyMAF.HF_MLP_DIM[-1]} + ) if self.fuse_grid_align: bhf_att_feat_dim.update({'face': 1024}) @@ -959,25 +1013,31 @@ class PyMAF(nn.Module): if 'face' in bhf_names: hfimg_feat_dim_list['face'] = hf_sfeat_dim[-n_iter_att:] - self.align_attention = get_attention_modules(bhf_names, - hfimg_feat_dim_list, - hidden_feat_dim, - n_iter=n_iter_att, - num_attention_heads=num_att_heads) + self.align_attention = get_attention_modules( + bhf_names, + hfimg_feat_dim_list, + hidden_feat_dim, + n_iter=n_iter_att, + num_attention_heads=num_att_heads + ) for part in bhf_names: self.part_module_names[part].update( - {f'align_attention.{part}': 
self.align_attention[part]}) + {f'align_attention.{part}': self.align_attention[part]} + ) if self.fuse_grid_align: - self.att_feat_reduce = get_fusion_modules(bhf_names, - bhf_ma_feat_dim, - grid_feat_dim, - n_iter=n_iter_att, - out_feat_len=bhf_att_feat_dim) + self.att_feat_reduce = get_fusion_modules( + bhf_names, + bhf_ma_feat_dim, + grid_feat_dim, + n_iter=n_iter_att, + out_feat_len=bhf_att_feat_dim + ) for part in bhf_names: self.part_module_names[part].update( - {f'att_feat_reduce.{part}': self.att_feat_reduce[part]}) + {f'att_feat_reduce.{part}': self.att_feat_reduce[part]} + ) # build regressor for parameter prediction self.regressor = nn.ModuleList() @@ -1002,10 +1062,13 @@ class PyMAF(nn.Module): if self.smpl_mode: self.regressor.append( - Regressor(feat_dim=ref_infeat_dim, - smpl_mean_params=smpl_mean_params, - use_cam_feats=cfg.MODEL.PyMAF.USE_CAM_FEAT, - smpl_models=self.smpl_family)) + Regressor( + feat_dim=ref_infeat_dim, + smpl_mean_params=smpl_mean_params, + use_cam_feats=cfg.MODEL.PyMAF.USE_CAM_FEAT, + smpl_models=self.smpl_family + ) + ) else: if cfg.MODEL.PyMAF.MAF_ON: if 'hand' in self.bhf_names or 'face' in self.bhf_names: @@ -1032,28 +1095,35 @@ class PyMAF(nn.Module): feat_dim_face = global_feat_dim self.regressor.append( - Regressor(feat_dim=ref_infeat_dim, - smpl_mean_params=smpl_mean_params, - use_cam_feats=cfg.MODEL.PyMAF.USE_CAM_FEAT, - feat_dim_hand=feat_dim_hand, - feat_dim_face=feat_dim_face, - bhf_names=bhf_names, - smpl_models=self.smpl_family)) + Regressor( + feat_dim=ref_infeat_dim, + smpl_mean_params=smpl_mean_params, + use_cam_feats=cfg.MODEL.PyMAF.USE_CAM_FEAT, + feat_dim_hand=feat_dim_hand, + feat_dim_face=feat_dim_face, + bhf_names=bhf_names, + smpl_models=self.smpl_family + ) + ) # assign sub-regressor to each part for dec_name, dec_module in self.regressor[-1].named_children(): if 'hand' in dec_name: self.part_module_names['hand'].update( - {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}) + {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module} + ) elif 'face' in dec_name or 'head' in dec_name or 'exp' in dec_name: self.part_module_names['face'].update( - {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}) + {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module} + ) elif 'res' in dec_name or 'vis' in dec_name: self.part_module_names['link'].update( - {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}) + {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module} + ) elif 'body' in self.part_module_names: self.part_module_names['body'].update( - {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module}) + {'regressor.{}.{}.'.format(len(self.regressor) - 1, dec_name): dec_module} + ) # mesh-aligned feature extractor self.maf_extractor = nn.ModuleDict() @@ -1070,12 +1140,17 @@ class PyMAF(nn.Module): if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT and i >= self.att_starts: self.maf_extractor[part].append( - MAF_Extractor(filter_channels=filter_channels_default[att_feat_dim_idx:], - iwp_cam_mode=cfg.MODEL.USE_IWP_CAM)) + MAF_Extractor( + filter_channels=filter_channels_default[att_feat_dim_idx:], + iwp_cam_mode=cfg.MODEL.USE_IWP_CAM + ) + ) else: self.maf_extractor[part].append( - MAF_Extractor(filter_channels=filter_channels, - iwp_cam_mode=cfg.MODEL.USE_IWP_CAM)) + MAF_Extractor( + filter_channels=filter_channels, iwp_cam_mode=cfg.MODEL.USE_IWP_CAM + ) + ) 
self.part_module_names[part].update({f'maf_extractor.{part}': self.maf_extractor[part]}) # check all modules have been added to part_module_names @@ -1099,10 +1174,9 @@ class PyMAF(nn.Module): """ initialize the mesh model with default poses and shapes """ if self.init_mesh_output is None or self.batch_size != batch_size: - self.init_mesh_output = self.regressor[0](torch.zeros(batch_size), - J_regressor=J_regressor, - rw_cam=rw_cam, - init_mode=True) + self.init_mesh_output = self.regressor[0]( + torch.zeros(batch_size), J_regressor=J_regressor, rw_cam=rw_cam, init_mode=True + ) self.batch_size = batch_size return self.init_mesh_output @@ -1110,11 +1184,13 @@ class PyMAF(nn.Module): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( - nn.Conv2d(self.inplanes, - planes * block.expansion, - kernel_size=1, - stride=stride, - bias=False), + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), nn.BatchNorm2d(planes * block.expansion), ) @@ -1156,13 +1232,16 @@ class PyMAF(nn.Module): planes = num_filters[i] layers.append( - nn.ConvTranspose2d(in_channels=self.inplanes, - out_channels=planes, - kernel_size=kernel, - stride=2, - padding=padding, - output_padding=output_padding, - bias=self.deconv_with_bias)) + nn.ConvTranspose2d( + in_channels=self.inplanes, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=self.deconv_with_bias + ) + ) layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) layers.append(nn.ReLU(inplace=True)) self.inplanes = planes @@ -1196,7 +1275,7 @@ class PyMAF(nn.Module): vis_feat_list: the list containing features for visualization ''' - # batch keys: ['img_body', 'orig_height', 'orig_width', 'person_id', 'img_lhand', + # batch keys: ['img_body', 'orig_height', 'orig_width', 'person_id', 'img_lhand', # 'lhand_theta_inv', 'img_rhand', 'rhand_theta_inv', 'img_face', 'face_theta_inv'] # extract spatial features or global features @@ -1234,7 +1313,8 @@ class PyMAF(nn.Module): img_rhand = batch['img_rhand'] batch_size = img_rhand.shape[0] limb_feat_dict['rhand'], limb_gfeat_dict['rhand'] = self.encoders['hand']( - img_rhand) + img_rhand + ) if cfg.MODEL.PyMAF.MAF_ON: for k in limb_feat_dict.keys(): @@ -1292,10 +1372,11 @@ class PyMAF(nn.Module): if self.hand_only_mode: pred_cam = mesh_output['pred_cam'].detach() pred_rhand_v = self.mano_sampler(mesh_output['verts_rh']) - pred_rhand_proj = projection(pred_rhand_v, { - **rw_cam, 'cam_sxy': pred_cam - }, - iwp_mode=cfg.MODEL.USE_IWP_CAM) + pred_rhand_proj = projection( + pred_rhand_v, { + **rw_cam, 'cam_sxy': pred_cam + }, iwp_mode=cfg.MODEL.USE_IWP_CAM + ) if cfg.MODEL.USE_IWP_CAM: pred_rhand_proj = pred_rhand_proj / (224. / 2.) else: @@ -1310,10 +1391,11 @@ class PyMAF(nn.Module): elif self.face_only_mode: pred_cam = mesh_output['pred_cam'].detach() pred_face_v = mesh_output['pred_face_kp3d'] - pred_face_proj = projection(pred_face_v, { - **rw_cam, 'cam_sxy': pred_cam - }, - iwp_mode=cfg.MODEL.USE_IWP_CAM) + pred_face_proj = projection( + pred_face_v, { + **rw_cam, 'cam_sxy': pred_cam + }, iwp_mode=cfg.MODEL.USE_IWP_CAM + ) if cfg.MODEL.USE_IWP_CAM: pred_face_proj = pred_face_proj / (224. / 2.) 
else: @@ -1326,10 +1408,11 @@ class PyMAF(nn.Module): pred_lhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2lhand]) pred_rhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2rhand]) pred_hand_v = torch.cat([pred_lhand_v, pred_rhand_v], dim=1) - pred_hand_proj = projection(pred_hand_v, { - **rw_cam, 'cam_sxy': pred_cam - }, - iwp_mode=cfg.MODEL.USE_IWP_CAM) + pred_hand_proj = projection( + pred_hand_v, { + **rw_cam, 'cam_sxy': pred_cam + }, iwp_mode=cfg.MODEL.USE_IWP_CAM + ) if cfg.MODEL.USE_IWP_CAM: pred_hand_proj = pred_hand_proj / (224. / 2.) else: @@ -1343,20 +1426,23 @@ class PyMAF(nn.Module): } proj_hf_pts = { 'lhand': - torch.cat([proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]], - dim=1), + torch.cat( + [proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]], dim=1 + ), 'rhand': - torch.cat([proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]], - dim=1), + torch.cat( + [proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]], dim=1 + ), } elif self.full_body_mode: pred_lhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2lhand]) pred_rhand_v = self.mano_sampler(pred_smpl_verts[:, self.smpl2rhand]) pred_hand_v = torch.cat([pred_lhand_v, pred_rhand_v], dim=1) - pred_hand_proj = projection(pred_hand_v, { - **rw_cam, 'cam_sxy': pred_cam - }, - iwp_mode=cfg.MODEL.USE_IWP_CAM) + pred_hand_proj = projection( + pred_hand_v, { + **rw_cam, 'cam_sxy': pred_cam + }, iwp_mode=cfg.MODEL.USE_IWP_CAM + ) if cfg.MODEL.USE_IWP_CAM: pred_hand_proj = pred_hand_proj / (224. / 2.) else: @@ -1372,11 +1458,13 @@ class PyMAF(nn.Module): } proj_hf_pts = { 'lhand': - torch.cat([proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]], - dim=1), + torch.cat( + [proj_hf_center['lhand'], pred_hand_proj[:, :self.mano_ds_len]], dim=1 + ), 'rhand': - torch.cat([proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]], - dim=1), + torch.cat( + [proj_hf_center['rhand'], pred_hand_proj[:, self.mano_ds_len:]], dim=1 + ), 'face': torch.cat([proj_hf_center['face'], mesh_output['pred_face_kp2d']], dim=1) } @@ -1402,7 +1490,8 @@ class PyMAF(nn.Module): if limb_rf_i == 0 or cfg.MODEL.PyMAF.GRID_FEAT: limb_ref_feat_ctd = self.maf_extractor[hf_key][limb_rf_i].sampling( - grid_points, im_feat=limb_feat_i, reduce_dim=limb_reduce_dim) + grid_points, im_feat=limb_feat_i, reduce_dim=limb_reduce_dim + ) else: if self.hand_only_mode or self.face_only_mode: proj_hf_pts_crop = proj_hf_pts[part_name][:, :, :2] @@ -1422,8 +1511,8 @@ class PyMAF(nn.Module): theta_i_inv = batch[f'{part_name}_theta_inv'] proj_hf_pts_crop = torch.bmm( theta_i_inv, - homo_vector(proj_hf_pts[part_name][:, :, :2]).permute( - 0, 2, 1)).permute(0, 2, 1) + homo_vector(proj_hf_pts[part_name][:, :, :2]).permute(0, 2, 1) + ).permute(0, 2, 1) if part_name == 'lhand': flip_x = torch.tensor([-1, 1])[None, @@ -1445,15 +1534,17 @@ class PyMAF(nn.Module): limb_ref_feat_ctd = self.maf_extractor[hf_key][limb_rf_i].sampling( proj_hf_pts_crop_ctd.detach(), im_feat=limb_feat_i, - reduce_dim=limb_reduce_dim) + reduce_dim=limb_reduce_dim + ) if self.fuse_grid_align and limb_rf_i >= self.att_starts: limb_grid_feature_ctd = self.maf_extractor[hf_key][limb_rf_i].sampling( - grid_points, im_feat=limb_feat_i, reduce_dim=limb_reduce_dim) + grid_points, im_feat=limb_feat_i, reduce_dim=limb_reduce_dim + ) limb_grid_ref_feat_ctd = torch.cat( - [limb_grid_feature_ctd, limb_ref_feat_ctd], - dim=-1).permute(0, 2, 1) + [limb_grid_feature_ctd, limb_ref_feat_ctd], dim=-1 + ).permute(0, 2, 1) if 
cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT: att_ref_feat_ctd = self.align_attention[hf_key][ @@ -1462,7 +1553,8 @@ class PyMAF(nn.Module): att_ref_feat_ctd = limb_grid_ref_feat_ctd att_ref_feat_ctd = self.maf_extractor[hf_key][limb_rf_i].reduce_dim( - att_ref_feat_ctd.permute(0, 2, 1)).view(batch_size, -1) + att_ref_feat_ctd.permute(0, 2, 1) + ).view(batch_size, -1) limb_ref_feat_ctd = self.att_feat_reduce[hf_key][ limb_rf_i - self.att_starts](att_ref_feat_ctd) @@ -1479,11 +1571,13 @@ class PyMAF(nn.Module): reduce_dim = (not self.fuse_grid_align) or (rf_i < self.att_starts) if rf_i == 0 or cfg.MODEL.PyMAF.GRID_FEAT: ref_feature = self.maf_extractor['body'][rf_i].sampling( - grid_points, im_feat=s_feat_i, reduce_dim=reduce_dim) + grid_points, im_feat=s_feat_i, reduce_dim=reduce_dim + ) else: # TODO: use a more sparse SMPL implementation (with 431 vertices) for acceleration pred_smpl_verts_ds = self.mesh_sampler.downsample( - pred_smpl_verts) # [B, 431, 3] + pred_smpl_verts + ) # [B, 431, 3] ref_feature = self.maf_extractor['body'][rf_i]( pred_smpl_verts_ds, im_feat=s_feat_i, @@ -1491,25 +1585,28 @@ class PyMAF(nn.Module): **rw_cam, 'cam_sxy': pred_cam }, add_att=True, - reduce_dim=reduce_dim) # [B, 431 * n_feat] + reduce_dim=reduce_dim + ) # [B, 431 * n_feat] if self.fuse_grid_align and rf_i >= self.att_starts: if rf_i > 0 and not cfg.MODEL.PyMAF.GRID_FEAT: grid_feature = self.maf_extractor['body'][rf_i].sampling( - grid_points, im_feat=s_feat_i, reduce_dim=reduce_dim) + grid_points, im_feat=s_feat_i, reduce_dim=reduce_dim + ) grid_ref_feat = torch.cat([grid_feature, ref_feature], dim=-1) else: grid_ref_feat = ref_feature grid_ref_feat = grid_ref_feat.permute(0, 2, 1) if cfg.MODEL.PyMAF.GRID_ALIGN.USE_ATT: - att_ref_feat = self.align_attention['body'][rf_i - self.att_starts]( - grid_ref_feat)[0] + att_ref_feat = self.align_attention['body'][ + rf_i - self.att_starts](grid_ref_feat)[0] elif cfg.MODEL.PyMAF.GRID_ALIGN.USE_FC: att_ref_feat = grid_ref_feat att_ref_feat = self.maf_extractor['body'][rf_i].reduce_dim( - att_ref_feat.permute(0, 2, 1)) + att_ref_feat.permute(0, 2, 1) + ) att_ref_feat = att_ref_feat.view(batch_size, -1) ref_feature = self.att_feat_reduce['body'][rf_i - @@ -1560,12 +1657,14 @@ class PyMAF(nn.Module): current_states['init_cam_rh'] = mesh_output['pred_cam_rh'].detach() # update mesh parameters - mesh_output = self.regressor[rf_i](ref_feature, - n_iter=1, - J_regressor=J_regressor, - rw_cam=rw_cam, - global_iter=rf_i, - **current_states) + mesh_output = self.regressor[rf_i]( + ref_feature, + n_iter=1, + J_regressor=J_regressor, + rw_cam=rw_cam, + global_iter=rf_i, + **current_states + ) out_dict['mesh_out'].append(mesh_output) diff --git a/lib/pymafx/models/res_module.py b/lib/pymafx/models/res_module.py index 98d7721d8562110472e4de730028a5ff6da6c0e7..94de7ecaa2ba3ead51c5f960e0ae08b806d9cd80 100644 --- a/lib/pymafx/models/res_module.py +++ b/lib/pymafx/models/res_module.py @@ -12,17 +12,24 @@ from collections import OrderedDict from lib.pymafx.core.cfgs import cfg # from .transformers.tokenlearner import TokenLearner - import logging -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) BN_MOMENTUM = 0.1 + def conv3x3(in_planes, out_planes, stride=1, bias=False, groups=1): """3x3 convolution with padding""" - return nn.Conv2d(in_planes * groups, out_planes * groups, kernel_size=3, stride=stride, - padding=1, bias=bias, groups=groups) + return nn.Conv2d( + in_planes * groups, + out_planes * groups, + kernel_size=3, + stride=stride, + padding=1, + bias=bias, 
+ groups=groups + ) class BasicBlock(nn.Module): @@ -62,15 +69,28 @@ class Bottleneck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1): super().__init__() - self.conv1 = nn.Conv2d(inplanes * groups, planes * groups, kernel_size=1, bias=False, groups=groups) + self.conv1 = nn.Conv2d( + inplanes * groups, planes * groups, kernel_size=1, bias=False, groups=groups + ) self.bn1 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM) - self.conv2 = nn.Conv2d(planes * groups, planes * groups, kernel_size=3, stride=stride, - padding=1, bias=False, groups=groups) + self.conv2 = nn.Conv2d( + planes * groups, + planes * groups, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + groups=groups + ) self.bn2 = nn.BatchNorm2d(planes * groups, momentum=BN_MOMENTUM) - self.conv3 = nn.Conv2d(planes * groups, planes * self.expansion * groups, kernel_size=1, - bias=False, groups=groups) - self.bn3 = nn.BatchNorm2d(planes * self.expansion * groups, - momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d( + planes * groups, + planes * self.expansion * groups, + kernel_size=1, + bias=False, + groups=groups + ) + self.bn3 = nn.BatchNorm2d(planes * self.expansion * groups, momentum=BN_MOMENTUM) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride @@ -98,11 +118,13 @@ class Bottleneck(nn.Module): return out -resnet_spec = {18: (BasicBlock, [2, 2, 2, 2]), - 34: (BasicBlock, [3, 4, 6, 3]), - 50: (Bottleneck, [3, 4, 6, 3]), - 101: (Bottleneck, [3, 4, 23, 3]), - 152: (Bottleneck, [3, 8, 36, 3])} +resnet_spec = { + 18: (BasicBlock, [2, 2, 2, 2]), + 34: (BasicBlock, [3, 4, 6, 3]), + 50: (Bottleneck, [3, 4, 6, 3]), + 101: (Bottleneck, [3, 4, 23, 3]), + 152: (Bottleneck, [3, 8, 36, 3]) +} class IUV_predict_layer(nn.Module): @@ -162,12 +184,12 @@ class IUV_predict_layer(nn.Module): ) elif mode in ['pncc']: self.predict_pncc = nn.Conv2d( - in_channels=feat_dim, - out_channels=3, - kernel_size=final_cov_k, - stride=1, - padding=1 if final_cov_k == 3 else 0 - ) + in_channels=feat_dim, + out_channels=3, + kernel_size=final_cov_k, + stride=1, + padding=1 if final_cov_k == 3 else 0 + ) self.inplanes = feat_dim @@ -175,8 +197,13 @@ class IUV_predict_layer(nn.Module): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), nn.BatchNorm2d(planes * block.expansion), ) @@ -197,7 +224,6 @@ class IUV_predict_layer(nn.Module): return_dict['predict_uv_index'] = predict_uv_index return_dict['predict_ann_index'] = predict_ann_index - if self.mode == 'iuv': predict_u = self.predict_u(x) @@ -209,7 +235,7 @@ class IUV_predict_layer(nn.Module): return_dict['predict_v'] = None # return_dict['predict_u'] = torch.zeros(predict_uv_index.shape).to(predict_uv_index.device) # return_dict['predict_v'] = torch.zeros(predict_uv_index.shape).to(predict_uv_index.device) - + if self.mode == 'pncc': predict_pncc = self.predict_pncc(x) return_dict['predict_pncc'] = predict_pncc @@ -252,10 +278,11 @@ class Kps_predict_layer(nn.Module): stride=1, padding=1 if final_cov_k == 3 else 0 ) - self.predict_kps = nn.Sequential(add_module, - # nn.BatchNorm2d(feat_dim, momentum=BN_MOMENTUM), - # conv, - ) + self.predict_kps = nn.Sequential( + add_module, + # nn.BatchNorm2d(feat_dim, momentum=BN_MOMENTUM), + # conv, + ) else: self.predict_kps = 
nn.Conv2d( in_channels=feat_dim, @@ -277,8 +304,16 @@ class Kps_predict_layer(nn.Module): class SmplResNet(nn.Module): - - def __init__(self, resnet_nums, in_channels=3, num_classes=229, last_stride=2, n_extra_feat=0, truncate=0, **kwargs): + def __init__( + self, + resnet_nums, + in_channels=3, + num_classes=229, + last_stride=2, + n_extra_feat=0, + truncate=0, + **kwargs + ): super().__init__() self.inplanes = 64 @@ -287,15 +322,16 @@ class SmplResNet(nn.Module): # self.deconv_with_bias = extra.DECONV_WITH_BIAS block, layers = resnet_spec[resnet_nums] - self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, - bias=False) + self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2) self.layer3 = self._make_layer(block, 256, layers[2], stride=2) if truncate < 2 else None - self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) if truncate < 1 else None + self.layer4 = self._make_layer( + block, 512, layers[3], stride=last_stride + ) if truncate < 1 else None self.avg_pooling = nn.AdaptiveAvgPool2d(1) @@ -306,16 +342,26 @@ class SmplResNet(nn.Module): self.n_extra_feat = n_extra_feat if n_extra_feat > 0: - self.trans_conv = nn.Sequential(nn.Conv2d(n_extra_feat + 512*block.expansion, 512*block.expansion, kernel_size=1, bias=False), - nn.BatchNorm2d(512*block.expansion, momentum=BN_MOMENTUM), - nn.ReLU(True)) + self.trans_conv = nn.Sequential( + nn.Conv2d( + n_extra_feat + 512 * block.expansion, + 512 * block.expansion, + kernel_size=1, + bias=False + ), nn.BatchNorm2d(512 * block.expansion, momentum=BN_MOMENTUM), nn.ReLU(True) + ) def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), ) @@ -378,8 +424,7 @@ class SmplResNet(nn.Module): else: state_dict[key] = state_dict_old[key] else: - raise RuntimeError( - 'No state_dict found in checkpoint file {}'.format(pretrained)) + raise RuntimeError('No state_dict found in checkpoint file {}'.format(pretrained)) self.load_state_dict(state_dict, strict=False) else: logger.error('=> imagenet pretrained model dose not exist') @@ -388,7 +433,6 @@ class SmplResNet(nn.Module): class LimbResLayers(nn.Module): - def __init__(self, resnet_nums, inplanes, outplanes=None, groups=1, **kwargs): super().__init__() @@ -407,8 +451,14 @@ class LimbResLayers(nn.Module): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( - nn.Conv2d(self.inplanes * groups, planes * block.expansion * groups, - kernel_size=1, stride=stride, bias=False, groups=groups), + nn.Conv2d( + self.inplanes * groups, + planes * block.expansion * groups, + kernel_size=1, + stride=stride, + bias=False, + groups=groups + ), nn.BatchNorm2d(planes * block.expansion * groups, momentum=BN_MOMENTUM), ) diff --git a/lib/pymafx/models/smpl.py b/lib/pymafx/models/smpl.py index 
2ac5405473b792eb5e89ebe30328461c619354e3..0a69eaf24e518545542cdd1eb55c819a549e8d55 100644 --- a/lib/pymafx/models/smpl.py +++ b/lib/pymafx/models/smpl.py @@ -19,6 +19,7 @@ from lib.pymafx.core import path_config, constants SMPL_MEAN_PARAMS = path_config.SMPL_MEAN_PARAMS SMPL_MODEL_DIR = path_config.SMPL_MODEL_DIR + @dataclass class ModelOutput(SMPLXOutput): smpl_joints: Optional[torch.Tensor] = None @@ -33,16 +34,31 @@ class ModelOutput(SMPLXOutput): lfoot_joints: Optional[torch.Tensor] = None rfoot_joints: Optional[torch.Tensor] = None + class SMPL(_SMPL): """ Extension of the official SMPL implementation to support more joints """ - def __init__(self, create_betas=False, create_global_orient=False, create_body_pose=False, create_transl=False, *args, **kwargs): - super().__init__(create_betas=create_betas, - create_global_orient=create_global_orient, - create_body_pose=create_body_pose, - create_transl=create_transl, *args, **kwargs) + def __init__( + self, + create_betas=False, + create_global_orient=False, + create_body_pose=False, + create_transl=False, + *args, + **kwargs + ): + super().__init__( + create_betas=create_betas, + create_global_orient=create_global_orient, + create_body_pose=create_body_pose, + create_transl=create_transl, + *args, + **kwargs + ) joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES] J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA) - self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)) + self.register_buffer( + 'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32) + ) self.joint_map = torch.tensor(joints, dtype=torch.long) # self.ModelOutput = namedtuple('ModelOutput_', ModelOutput._fields + ('smpl_joints', 'joints_J19',)) # self.ModelOutput.__new__.__defaults__ = (None,) * len(self.ModelOutput._fields) @@ -58,17 +74,19 @@ class SMPL(_SMPL): vertices = smpl_output.vertices joints = torch.cat([smpl_output.joints, extra_joints], dim=1) smpl_joints = smpl_output.joints[:, :24] - joints = joints[:, self.joint_map, :] # [B, 49, 3] + joints = joints[:, self.joint_map, :] # [B, 49, 3] joints_J24 = joints[:, -24:, :] joints_J19 = joints_J24[:, constants.J24_TO_J19, :] - output = ModelOutput(vertices=vertices, - global_orient=smpl_output.global_orient, - body_pose=smpl_output.body_pose, - joints=joints, - joints_J19=joints_J19, - smpl_joints=smpl_joints, - betas=smpl_output.betas, - full_pose=smpl_output.full_pose) + output = ModelOutput( + vertices=vertices, + global_orient=smpl_output.global_orient, + body_pose=smpl_output.body_pose, + joints=joints, + joints_J19=joints_J19, + smpl_joints=smpl_joints, + betas=smpl_output.betas, + full_pose=smpl_output.full_pose + ) return output def get_global_rotation( @@ -107,18 +125,20 @@ class SMPL(_SMPL): batch_size = max(batch_size, len(var)) if global_orient is None: - global_orient = torch.eye(3, device=device, dtype=dtype).view( - 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + global_orient = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, + -1).contiguous() if body_pose is None: - body_pose = torch.eye(3, device=device, dtype=dtype).view( - 1, 1, 3, 3).expand( - batch_size, self.NUM_BODY_JOINTS, -1, -1).contiguous() + body_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand( + batch_size, self.NUM_BODY_JOINTS, -1, -1 + ).contiguous() # Concatenate all pose vectors full_pose = torch.cat( [global_orient.reshape(-1, 1, 3, 3), body_pose.reshape(-1, 
self.NUM_BODY_JOINTS, 3, 3)], - dim=1) + dim=1 + ) rot_mats = full_pose.view(batch_size, -1, 3, 3) @@ -132,16 +152,15 @@ class SMPL(_SMPL): rel_joints = joints.clone() rel_joints[:, 1:] -= joints[:, self.parents[1:]] - transforms_mat = transform_mat( - rot_mats.reshape(-1, 3, 3), - rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) + transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3), + rel_joints.reshape(-1, 3, + 1)).reshape(-1, joints.shape[1], 4, 4) transform_chain = [transforms_mat[:, 0]] for i in range(1, self.parents.shape[0]): # Subtract the joint location at the rest pose # No need for rotation, since it's identity when at rest - curr_res = torch.matmul(transform_chain[self.parents[i]], - transforms_mat[:, i]) + curr_res = torch.matmul(transform_chain[self.parents[i]], transforms_mat[:, i]) transform_chain.append(curr_res) transforms = torch.stack(transform_chain, dim=1) @@ -230,60 +249,72 @@ class SMPLX(SMPLXLayer): batch_size = max(batch_size, len(var)) if global_orient is None: - global_orient = torch.eye(3, device=device, dtype=dtype).view( - 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + global_orient = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, + -1).contiguous() if body_pose is None: - body_pose = torch.eye(3, device=device, dtype=dtype).view( - 1, 1, 3, 3).expand( - batch_size, self.NUM_BODY_JOINTS, -1, -1).contiguous() + body_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand( + batch_size, self.NUM_BODY_JOINTS, -1, -1 + ).contiguous() if left_hand_pose is None: - left_hand_pose = torch.eye(3, device=device, dtype=dtype).view( - 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + left_hand_pose = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, + -1).contiguous() if right_hand_pose is None: - right_hand_pose = torch.eye(3, device=device, dtype=dtype).view( - 1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + right_hand_pose = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, + 3).expand(batch_size, 15, -1, + -1).contiguous() if jaw_pose is None: - jaw_pose = torch.eye(3, device=device, dtype=dtype).view( - 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + jaw_pose = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, + -1).contiguous() if leye_pose is None: - leye_pose = torch.eye(3, device=device, dtype=dtype).view( - 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + leye_pose = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, + -1).contiguous() if reye_pose is None: - reye_pose = torch.eye(3, device=device, dtype=dtype).view( - 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + reye_pose = torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, + -1).contiguous() # Concatenate all pose vectors full_pose = torch.cat( - [global_orient.reshape(-1, 1, 3, 3), - body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3), - jaw_pose.reshape(-1, 1, 3, 3), - leye_pose.reshape(-1, 1, 3, 3), - reye_pose.reshape(-1, 1, 3, 3), - left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3), - right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)], - dim=1) - + [ + global_orient.reshape(-1, 1, 3, 3), + body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3), + jaw_pose.reshape(-1, 1, 3, 3), + leye_pose.reshape(-1, 1, 3, 3), + reye_pose.reshape(-1, 1, 3, 3), + left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3), + 
right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3) + ], + dim=1 + ) + rot_mats = full_pose.view(batch_size, -1, 3, 3) # Get the joints # NxJx3 array - joints = vertices2joints(self.J_regressor, self.v_template.unsqueeze(0).expand(batch_size, -1, -1)) + joints = vertices2joints( + self.J_regressor, + self.v_template.unsqueeze(0).expand(batch_size, -1, -1) + ) joints = torch.unsqueeze(joints, dim=-1) rel_joints = joints.clone() rel_joints[:, 1:] -= joints[:, self.parents[1:]] - transforms_mat = transform_mat( - rot_mats.reshape(-1, 3, 3), - rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) + transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3), + rel_joints.reshape(-1, 3, + 1)).reshape(-1, joints.shape[1], 4, 4) transform_chain = [transforms_mat[:, 0]] for i in range(1, self.parents.shape[0]): # Subtract the joint location at the rest pose # No need for rotation, since it's identity when at rest - curr_res = torch.matmul(transform_chain[self.parents[i]], - transforms_mat[:, i]) + curr_res = torch.matmul(transform_chain[self.parents[i]], transforms_mat[:, i]) transform_chain.append(curr_res) transforms = torch.stack(transform_chain, dim=1) @@ -298,7 +329,6 @@ class SMPLX(SMPLXLayer): class SMPLX_ALL(nn.Module): """ Extension of the official SMPLX implementation to support more joints """ - def __init__(self, batch_size=1, use_face_contour=True, all_gender=False, **kwargs): super().__init__() numBetas = 10 @@ -309,45 +339,72 @@ class SMPLX_ALL(nn.Module): self.genders = ['neutral'] for gender in self.genders: assert gender in ['male', 'female', 'neutral'] - self.model_dict = nn.ModuleDict({gender: SMPLX(path_config.SMPL_MODEL_DIR, - gender=gender, - ext='npz', - num_betas=numBetas, - use_pca=False, batch_size=batch_size, use_face_contour=use_face_contour, num_pca_comps=45, **kwargs) - for gender in self.genders}) + self.model_dict = nn.ModuleDict( + { + gender: SMPLX( + path_config.SMPL_MODEL_DIR, + gender=gender, + ext='npz', + num_betas=numBetas, + use_pca=False, + batch_size=batch_size, + use_face_contour=use_face_contour, + num_pca_comps=45, + **kwargs + ) + for gender in self.genders + } + ) self.model_neutral = self.model_dict['neutral'] joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES] J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA) - self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)) + self.register_buffer( + 'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32) + ) self.joint_map = torch.tensor(joints, dtype=torch.long) # smplx_to_smpl.pkl, file source: https://smpl-x.is.tue.mpg.de - smplx_to_smpl = pickle.load(open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb')) - self.register_buffer('smplx2smpl', torch.tensor(smplx_to_smpl['matrix'][None], dtype=torch.float32)) + smplx_to_smpl = pickle.load( + open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb') + ) + self.register_buffer( + 'smplx2smpl', torch.tensor(smplx_to_smpl['matrix'][None], dtype=torch.float32) + ) smpl2limb_vert_faces = get_partial_smpl('smpl') self.smpl2lhand = torch.from_numpy(smpl2limb_vert_faces['lhand']['vids']).long() self.smpl2rhand = torch.from_numpy(smpl2limb_vert_faces['rhand']['vids']).long() # left and right hand joint mapping - smplx2lhand_joints = [constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.HAND_NAMES] - smplx2rhand_joints = [constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.HAND_NAMES] 
+ smplx2lhand_joints = [ + constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.HAND_NAMES + ] + smplx2rhand_joints = [ + constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.HAND_NAMES + ] self.smplx2lh_joint_map = torch.tensor(smplx2lhand_joints, dtype=torch.long) self.smplx2rh_joint_map = torch.tensor(smplx2rhand_joints, dtype=torch.long) # left and right foot joint mapping - smplx2lfoot_joints = [constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.FOOT_NAMES] - smplx2rfoot_joints = [constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.FOOT_NAMES] + smplx2lfoot_joints = [ + constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.FOOT_NAMES + ] + smplx2rfoot_joints = [ + constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.FOOT_NAMES + ] self.smplx2lf_joint_map = torch.tensor(smplx2lfoot_joints, dtype=torch.long) self.smplx2rf_joint_map = torch.tensor(smplx2rfoot_joints, dtype=torch.long) for g in self.genders: - J_template = torch.einsum('ji,ik->jk', [self.model_dict[g].J_regressor[:24], self.model_dict[g].v_template]) - J_dirs = torch.einsum('ji,ikl->jkl', [self.model_dict[g].J_regressor[:24], self.model_dict[g].shapedirs]) + J_template = torch.einsum( + 'ji,ik->jk', [self.model_dict[g].J_regressor[:24], self.model_dict[g].v_template] + ) + J_dirs = torch.einsum( + 'ji,ikl->jkl', [self.model_dict[g].J_regressor[:24], self.model_dict[g].shapedirs] + ) self.register_buffer(f'{g}_J_template', J_template) self.register_buffer(f'{g}_J_dirs', J_dirs) - def forward(self, *args, **kwargs): batch_size = kwargs['body_pose'].shape[0] kwargs['get_skin'] = True @@ -357,7 +414,10 @@ class SMPLX_ALL(nn.Module): kwargs['gender'] = 2 * torch.ones(batch_size).to(kwargs['body_pose'].device) # pose for 55 joints: 1, 21, 15, 15, 1, 1, 1 - pose_keys = ['global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'jaw_pose', 'leye_pose', 'reye_pose'] + pose_keys = [ + 'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'jaw_pose', + 'leye_pose', 'reye_pose' + ] param_keys = ['betas'] + pose_keys if kwargs['pose2rot']: for key in pose_keys: @@ -366,7 +426,9 @@ class SMPLX_ALL(nn.Module): # kwargs[key] += self.model_neutral.left_hand_mean # elif key == 'right_hand_pose': # kwargs[key] += self.model_neutral.right_hand_mean - kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view([batch_size, -1, 3, 3]) + kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view( + [batch_size, -1, 3, 3] + ) if kwargs['body_pose'].shape[1] == 23: # remove hand pose in the body_pose kwargs['body_pose'] = kwargs['body_pose'][:, :21] @@ -406,26 +468,27 @@ class SMPLX_ALL(nn.Module): smplx_j45 = smplx_joints[:, constants.SMPLX2SMPL_J45] joints = torch.cat([smplx_j45, extra_joints], dim=1) smpl_joints = smplx_j45[:, :24] - joints = joints[:, self.joint_map, :] # [B, 49, 3] + joints = joints[:, self.joint_map, :] # [B, 49, 3] joints_J24 = joints[:, -24:, :] joints_J19 = joints_J24[:, constants.J24_TO_J19, :] - output = ModelOutput(vertices=smpl_vertices, - smplx_vertices=smplx_vertices, - lhand_vertices=lhand_vertices, - rhand_vertices=rhand_vertices, - # global_orient=smplx_output.global_orient, - # body_pose=smplx_output.body_pose, - joints=joints, - joints_J19=joints_J19, - smpl_joints=smpl_joints, - # betas=smplx_output.betas, - # full_pose=smplx_output.full_pose, - lhand_joints=lhand_joints, - rhand_joints=rhand_joints, - lfoot_joints=lfoot_joints, - 
rfoot_joints=rfoot_joints, - face_joints=face_joints, - ) + output = ModelOutput( + vertices=smpl_vertices, + smplx_vertices=smplx_vertices, + lhand_vertices=lhand_vertices, + rhand_vertices=rhand_vertices, + # global_orient=smplx_output.global_orient, + # body_pose=smplx_output.body_pose, + joints=joints, + joints_J19=joints_J19, + smpl_joints=smpl_joints, + # betas=smplx_output.betas, + # full_pose=smplx_output.full_pose, + lhand_joints=lhand_joints, + rhand_joints=rhand_joints, + lfoot_joints=lfoot_joints, + rfoot_joints=rfoot_joints, + face_joints=face_joints, + ) return output # def make_hand_regressor(self): @@ -467,7 +530,7 @@ class SMPLX_ALL(nn.Module): kwargs['gender'] = 2 * torch.ones(batch_size).to(device) else: kwargs['gender'] = gender - + param_keys = ['betas'] gender_idx_list = [] @@ -480,7 +543,9 @@ class SMPLX_ALL(nn.Module): gender_kwargs = {} gender_kwargs.update({k: kwargs[k][gender_idx] for k in param_keys if k in kwargs}) - J = getattr(self, f'{g}_J_template').unsqueeze(0) + blend_shapes(gender_kwargs['betas'], getattr(self, f'{g}_J_dirs')) + J = getattr(self, f'{g}_J_template').unsqueeze(0) + blend_shapes( + gender_kwargs['betas'], getattr(self, f'{g}_J_dirs') + ) smplx_joints.append(J) @@ -491,9 +556,10 @@ class SMPLX_ALL(nn.Module): return smplx_joints + class MANO(MANOLayer): """ Extension of the official MANO implementation to support more joints """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, *args, **kwargs): @@ -504,7 +570,9 @@ class MANO(MANOLayer): if kwargs['pose2rot']: for key in pose_keys: if key in kwargs: - kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view([batch_size, -1, 3, 3]) + kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view( + [batch_size, -1, 3, 3] + ) kwargs['hand_pose'] = kwargs.pop('right_hand_pose') mano_output = super().forward(*args, **kwargs) th_verts = mano_output.vertices @@ -515,15 +583,18 @@ class MANO(MANOLayer): tips = th_verts[:, [745, 317, 445, 556, 673]] th_jtr = torch.cat([th_jtr, tips], 1) # Reorder joints to match visualization utilities - th_jtr = th_jtr[:, [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]] - output = ModelOutput(rhand_vertices=th_verts, - rhand_joints=th_jtr, - ) + th_jtr = th_jtr[:, + [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]] + output = ModelOutput( + rhand_vertices=th_verts, + rhand_joints=th_jtr, + ) return output + class FLAME(FLAMELayer): """ Extension of the official FLAME implementation to support more joints """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, *args, **kwargs): @@ -534,30 +605,33 @@ class FLAME(FLAMELayer): if kwargs['pose2rot']: for key in pose_keys: if key in kwargs: - kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view([batch_size, -1, 3, 3]) + kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view( + [batch_size, -1, 3, 3] + ) flame_output = super().forward(*args, **kwargs) - output = ModelOutput(flame_vertices=flame_output.vertices, - face_joints=flame_output.joints[:, 5:], - ) + output = ModelOutput( + flame_vertices=flame_output.vertices, + face_joints=flame_output.joints[:, 5:], + ) return output + class SMPL_Family(): def __init__(self, model_type='smpl', *args, **kwargs): if model_type == 'smpl': - self.model = SMPL( - model_path=SMPL_MODEL_DIR, - *args, **kwargs - ) + 
self.model = SMPL(model_path=SMPL_MODEL_DIR, *args, **kwargs) elif model_type == 'smplx': self.model = SMPLX_ALL(*args, **kwargs) elif model_type == 'mano': - self.model = MANO(model_path=SMPL_MODEL_DIR, is_rhand=True, use_pca=False, *args, **kwargs) + self.model = MANO( + model_path=SMPL_MODEL_DIR, is_rhand=True, use_pca=False, *args, **kwargs + ) elif model_type == 'flame': self.model = FLAME(model_path=SMPL_MODEL_DIR, use_face_contour=True, *args, **kwargs) def __call__(self, *args, **kwargs): return self.model(*args, **kwargs) - + def get_tpose(self, *args, **kwargs): return self.model.get_tpose(*args, **kwargs) @@ -570,14 +644,17 @@ class SMPL_Family(): # else: # self.model.cuda(device) + def get_smpl_faces(): smpl = SMPL(model_path=SMPL_MODEL_DIR, batch_size=1) return smpl.faces + def get_smplx_faces(): smplx = SMPLX(SMPL_MODEL_DIR, batch_size=1) return smplx.faces + def get_mano_faces(hand_type='right'): assert hand_type in ['right', 'left'] is_rhand = True if hand_type == 'right' else False @@ -585,11 +662,13 @@ def get_mano_faces(hand_type='right'): return mano.faces + def get_flame_faces(): flame = FLAME(SMPL_MODEL_DIR, batch_size=1) return flame.faces + def get_model_faces(type='smpl'): if type == 'smpl': return get_smpl_faces() @@ -600,6 +679,7 @@ def get_model_faces(type='smpl'): elif type == 'flame': return get_flame_faces() + def get_model_tpose(type='smpl'): if type == 'smpl': return get_smpl_tpose() @@ -610,43 +690,64 @@ def get_model_tpose(type='smpl'): elif type == 'flame': return get_flame_tpose() + def get_smpl_tpose(): - smpl = SMPL(create_betas=True, create_global_orient=True, create_body_pose=True, model_path=SMPL_MODEL_DIR, batch_size=1) + smpl = SMPL( + create_betas=True, + create_global_orient=True, + create_body_pose=True, + model_path=SMPL_MODEL_DIR, + batch_size=1 + ) vertices = smpl().vertices[0] return vertices.detach() + def get_smpl_tpose_joint(): - smpl = SMPL(create_betas=True, create_global_orient=True, create_body_pose=True, model_path=SMPL_MODEL_DIR, batch_size=1) + smpl = SMPL( + create_betas=True, + create_global_orient=True, + create_body_pose=True, + model_path=SMPL_MODEL_DIR, + batch_size=1 + ) tpose_joint = smpl().smpl_joints[0] return tpose_joint.detach() + def get_smplx_tpose(): smplx = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1) vertices = smplx().vertices[0] return vertices + def get_smplx_tpose_joint(): smplx = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1) tpose_joint = smplx().joints[0] return tpose_joint + def get_mano_tpose(): mano = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=True) - vertices = mano(global_orient=torch.zeros(1, 3), - right_hand_pose=torch.zeros(1, 15*3)).rhand_vertices[0] + vertices = mano(global_orient=torch.zeros(1, 3), + right_hand_pose=torch.zeros(1, 15 * 3)).rhand_vertices[0] return vertices + def get_flame_tpose(): flame = FLAME(SMPL_MODEL_DIR, batch_size=1) vertices = flame(global_orient=torch.zeros(1, 3)).flame_vertices[0] return vertices + def get_part_joints(smpl_joints): batch_size = smpl_joints.shape[0] # part_joints = torch.zeros().to(smpl_joints.device) - one_seg_pairs = [(0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), (12, 15), (13, 16), (14, 17)] + one_seg_pairs = [ + (0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), (12, 15), (13, 16), (14, 17) + ] two_seg_pairs = [(1, 4), (2, 5), (4, 7), (5, 8), (16, 18), (17, 19), (18, 20), (19, 21)] one_seg_pairs.extend(two_seg_pairs) @@ -660,12 +761,13 @@ def get_part_joints(smpl_joints): part_joints.append(new_joint) for j_p in single_joints: - 
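The get_part_joints hunk above keeps the joint-pair lists but the reduction of each pair to a single part joint is elided by the diff context; assuming the usual midpoint reduction, a compact equivalent would look like the sketch below (the function name and the mean reduction are assumptions).

import torch

def part_joints_from_pairs(smpl_joints: torch.Tensor, seg_pairs, single_joints):
    # smpl_joints: (B, J, 3); each (a, b) pair is assumed to contribute the midpoint of joints a and b
    parts = [smpl_joints[:, [a, b], :].mean(dim=1, keepdim=True) for a, b in seg_pairs]
    parts += [smpl_joints[:, j:j + 1] for j in single_joints]
    return torch.cat(parts, dim=1)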
part_joints.append(smpl_joints[:, j_p:j_p+1]) + part_joints.append(smpl_joints[:, j_p:j_p + 1]) part_joints = torch.cat(part_joints, dim=1) return part_joints + def get_partial_smpl(body_model='smpl', device=torch.device('cuda')): body_model_faces = get_model_faces(body_model) @@ -680,9 +782,13 @@ def get_partial_smpl(body_model='smpl', device=torch.device('cuda')): part_vert_faces[part] = {'vids': part_vids['vids'], 'faces': part_vids['faces']} else: if part in ['lhand', 'rhand']: - with open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/MANO_SMPLX_vertex_ids.pkl'), 'rb') as json_file: + with open( + os.path.join(SMPL_MODEL_DIR, 'model_transfer/MANO_SMPLX_vertex_ids.pkl'), 'rb' + ) as json_file: smplx_mano_id = pickle.load(json_file) - with open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb') as json_file: + with open( + os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb' + ) as json_file: smplx_smpl_id = pickle.load(json_file) smplx_tpose = get_smplx_tpose() @@ -701,13 +807,17 @@ def get_partial_smpl(body_model='smpl', device=torch.device('cuda')): smpl2mano_id.append(int(v_closest)) smpl2mano_vids = np.array(smpl2mano_id).astype(np.long) - mano_faces = get_mano_faces(hand_type='right' if part == 'rhand' else 'left').astype(np.long) + mano_faces = get_mano_faces(hand_type='right' if part == 'rhand' else 'left' + ).astype(np.long) np.savez(part_vid_fname, vids=smpl2mano_vids, faces=mano_faces) part_vert_faces[part] = {'vids': smpl2mano_vids, 'faces': mano_faces} elif part in ['face', 'arm', 'forearm', 'larm', 'rarm']: - with open(os.path.join(SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)), 'rb') as json_file: + with open( + os.path.join(SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)), + 'rb' + ) as json_file: smplx_part_id = json.load(json_file) # main_body_part = list(smplx_part_id.keys()) @@ -716,12 +826,30 @@ def get_partial_smpl(body_model='smpl', device=torch.device('cuda')): if part == 'face': selected_body_part = ['head'] elif part == 'arm': - selected_body_part = ['rightHand', 'leftArm', 'leftShoulder', 'rightShoulder', 'rightArm', 'leftHandIndex1', 'rightHandIndex1', 'leftForeArm', 'rightForeArm', 'leftHand',] + selected_body_part = [ + 'rightHand', + 'leftArm', + 'leftShoulder', + 'rightShoulder', + 'rightArm', + 'leftHandIndex1', + 'rightHandIndex1', + 'leftForeArm', + 'rightForeArm', + 'leftHand', + ] # selected_body_part = ['rightHand', 'leftArm', 'rightArm', 'leftHandIndex1', 'rightHandIndex1', 'leftForeArm', 'rightForeArm', 'leftHand',] elif part == 'forearm': - selected_body_part = ['rightHand', 'leftHandIndex1', 'rightHandIndex1', 'leftForeArm', 'rightForeArm', 'leftHand',] + selected_body_part = [ + 'rightHand', + 'leftHandIndex1', + 'rightHandIndex1', + 'leftForeArm', + 'rightForeArm', + 'leftHand', + ] elif part == 'arm_eval': - selected_body_part = ['leftArm', 'rightArm', 'leftForeArm', 'rightForeArm'] + selected_body_part = ['leftArm', 'rightArm', 'leftForeArm', 'rightForeArm'] elif part == 'larm': # selected_body_part = ['leftArm', 'leftForeArm'] selected_body_part = ['leftForeArm'] @@ -749,7 +877,7 @@ def get_partial_smpl(body_model='smpl', device=torch.device('cuda')): np.savez(part_vid_fname, vids=smpl2head_vids, faces=head_faces) part_vert_faces[part] = {'vids': smpl2head_vids, 'faces': head_faces} - + elif part in ['lwrist', 'rwrist']: if body_model == 'smplx': @@ -765,11 +893,11 @@ def get_partial_smpl(body_model='smpl', device=torch.device('cuda')): wrist_vids = [] for vid, vt in 
enumerate(body_model_verts): - v_j_dist = torch.sum((vt - wrist_joint) ** 2) + v_j_dist = torch.sum((vt - wrist_joint)**2) if v_j_dist < dist: wrist_vids.append(vid) - + wrist_vids = np.array(wrist_vids) part_body_fid = [] diff --git a/lib/pymafx/models/transformers/bert/__init__.py b/lib/pymafx/models/transformers/bert/__init__.py index 4e18c1b84248ce3d4425cad34e5bb0f08fe32fec..0432a1e92856c438e5fd2f550dc5029a78fa354c 100644 --- a/lib/pymafx/models/transformers/bert/__init__.py +++ b/lib/pymafx/models/transformers/bert/__init__.py @@ -1,8 +1,9 @@ __version__ = "1.0.0" -from .modeling_bert import (BertConfig, BertModel, - load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, - BERT_PRETRAINED_CONFIG_ARCHIVE_MAP) +from .modeling_bert import ( + BertConfig, BertModel, load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, + BERT_PRETRAINED_CONFIG_ARCHIVE_MAP +) from .modeling_graphormer import Graphormer @@ -10,7 +11,9 @@ from .modeling_graphormer import Graphormer # from .e2e_hand_network import Graphormer_Hand_Network -from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, - PretrainedConfig, PreTrainedModel, prune_layer, Conv1D) +from .modeling_utils import ( + WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_layer, + Conv1D +) from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path) diff --git a/lib/pymafx/models/transformers/bert/e2e_body_network.py b/lib/pymafx/models/transformers/bert/e2e_body_network.py index 22faf2f7c3a3a58047d5553c179d332827f6daa6..9d1c75e276aa18fa1e8f2d865cbef7a275f71b8c 100644 --- a/lib/pymafx/models/transformers/bert/e2e_body_network.py +++ b/lib/pymafx/models/transformers/bert/e2e_body_network.py @@ -7,6 +7,7 @@ Licensed under the MIT license. import torch import src.modeling.data.config as cfg + class Graphormer_Body_Network(torch.nn.Module): ''' End-to-end Graphormer network for human pose and mesh reconstruction from a single image. 
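The lwrist/rwrist branch above selects wrist vertices with a per-vertex squared-distance test against the wrist joint; a vectorized sketch of the same selection, with shapes and the threshold name assumed:

import torch

def wrist_vertex_ids(body_model_verts: torch.Tensor, wrist_joint: torch.Tensor, dist: float) -> torch.Tensor:
    # body_model_verts: (V, 3), wrist_joint: (3,) -> 1-D tensor of vertex ids within the threshold
    sq_dist = ((body_model_verts - wrist_joint) ** 2).sum(dim=-1)
    return torch.nonzero(sq_dist < dist, as_tuple=False).squeeze(-1)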
@@ -24,25 +25,27 @@ class Graphormer_Body_Network(torch.nn.Module): self.cam_param_fc3 = torch.nn.Linear(250, 3) self.grid_feat_dim = torch.nn.Linear(1024, 2051) - def forward(self, images, smpl, mesh_sampler, meta_masks=None, is_train=False): batch_size = images.size(0) # Generate T-pose template mesh - template_pose = torch.zeros((1,72)) - template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord + template_pose = torch.zeros((1, 72)) + template_pose[:, 0] = 3.1416 # Rectify "upside down" reference mesh in global coord template_pose = template_pose.cuda(self.config.device) - template_betas = torch.zeros((1,10)).cuda(self.config.device) + template_betas = torch.zeros((1, 10)).cuda(self.config.device) template_vertices = smpl(template_pose, template_betas) # template mesh simplification template_vertices_sub = mesh_sampler.downsample(template_vertices) template_vertices_sub2 = mesh_sampler.downsample(template_vertices_sub, n1=1, n2=2) - print('template_vertices', template_vertices.shape, template_vertices_sub.shape, template_vertices_sub2.shape) + print( + 'template_vertices', template_vertices.shape, template_vertices_sub.shape, + template_vertices_sub2.shape + ) - # template mesh-to-joint regression + # template mesh-to-joint regression template_3d_joints = smpl.get_h36m_joints(template_vertices) - template_pelvis = template_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:] - template_3d_joints = template_3d_joints[:,cfg.H36M_J17_TO_J14,:] + template_pelvis = template_3d_joints[:, cfg.H36M_J17_NAME.index('Pelvis'), :] + template_3d_joints = template_3d_joints[:, cfg.H36M_J17_TO_J14, :] num_joints = template_3d_joints.shape[1] # normalize @@ -50,7 +53,7 @@ class Graphormer_Body_Network(torch.nn.Module): template_vertices_sub2 = template_vertices_sub2 - template_pelvis[:, None, :] # concatinate template joints and template vertices, and then duplicate to batch size - ref_vertices = torch.cat([template_3d_joints, template_vertices_sub2],dim=1) + ref_vertices = torch.cat([template_3d_joints, template_vertices_sub2], dim=1) ref_vertices = ref_vertices.expand(batch_size, -1, -1) print('ref_vertices', ref_vertices.shape) @@ -62,7 +65,7 @@ class Graphormer_Body_Network(torch.nn.Module): print('image_feat', image_feat.shape) # process grid features grid_feat = torch.flatten(grid_feat, start_dim=2) - grid_feat = grid_feat.transpose(1,2) + grid_feat = grid_feat.transpose(1, 2) print('grid_feat bf', grid_feat.shape) grid_feat = self.grid_feat_dim(grid_feat) print('grid_feat', grid_feat.shape) @@ -70,42 +73,43 @@ class Graphormer_Body_Network(torch.nn.Module): features = torch.cat([ref_vertices, image_feat], dim=2) print('features', features.shape, ref_vertices.shape, image_feat.shape) # prepare input tokens including joint/vertex queries and grid features - features = torch.cat([features, grid_feat],dim=1) + features = torch.cat([features, grid_feat], dim=1) print('features', features.shape) - if is_train==True: + if is_train == True: # apply mask vertex/joint modeling # meta_masks is a tensor of all the masks, randomly generated in dataloader # we pre-define a [MASK] token, which is a floating-value vector with 0.01s - special_token = torch.ones_like(features[:,:-49,:]).cuda()*0.01 + special_token = torch.ones_like(features[:, :-49, :]).cuda() * 0.01 print('special_token', special_token.shape, meta_masks.shape) print('meta_masks', torch.unique(meta_masks)) - features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks) + features[:, :-49, : + ] = 
features[:, :-49, :] * meta_masks + special_token * (1 - meta_masks) # forward pass - if self.config.output_attentions==True: + if self.config.output_attentions == True: features, hidden_states, att = self.trans_encoder(features) else: features = self.trans_encoder(features) - pred_3d_joints = features[:,:num_joints,:] - pred_vertices_sub2 = features[:,num_joints:-49,:] + pred_3d_joints = features[:, :num_joints, :] + pred_vertices_sub2 = features[:, num_joints:-49, :] # learn camera parameters x = self.cam_param_fc(pred_vertices_sub2) - x = x.transpose(1,2) + x = x.transpose(1, 2) x = self.cam_param_fc2(x) x = self.cam_param_fc3(x) - cam_param = x.transpose(1,2) + cam_param = x.transpose(1, 2) cam_param = cam_param.squeeze() - temp_transpose = pred_vertices_sub2.transpose(1,2) + temp_transpose = pred_vertices_sub2.transpose(1, 2) pred_vertices_sub = self.upsampling(temp_transpose) pred_vertices_full = self.upsampling2(pred_vertices_sub) - pred_vertices_sub = pred_vertices_sub.transpose(1,2) - pred_vertices_full = pred_vertices_full.transpose(1,2) + pred_vertices_sub = pred_vertices_sub.transpose(1, 2) + pred_vertices_full = pred_vertices_full.transpose(1, 2) - if self.config.output_attentions==True: + if self.config.output_attentions == True: return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full, hidden_states, att else: - return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full \ No newline at end of file + return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full diff --git a/lib/pymafx/models/transformers/bert/e2e_hand_network.py b/lib/pymafx/models/transformers/bert/e2e_hand_network.py index 7030d0f6f1ec7e6741d4c5f67b2e78eaad708a5f..410968c4abc63e1ae8281b2e0297c8eef4e7bbcf 100644 --- a/lib/pymafx/models/transformers/bert/e2e_hand_network.py +++ b/lib/pymafx/models/transformers/bert/e2e_hand_network.py @@ -7,6 +7,7 @@ Licensed under the MIT license. import torch import src.modeling.data.config as cfg + class Graphormer_Hand_Network(torch.nn.Module): ''' End-to-end Graphormer network for hand pose and mesh reconstruction from a single image. 
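A self-contained sketch of the masked vertex/joint modeling step in the Graphormer body forward pass above, assuming meta_masks broadcasts over the feature dimension: masked query tokens are replaced by a constant 0.01 [MASK] vector, while the trailing 49 grid-feature tokens are never masked.

import torch

def apply_mvm_mask(features: torch.Tensor, meta_masks: torch.Tensor) -> torch.Tensor:
    # features: (B, N + 49, C) joint/vertex queries followed by 49 grid tokens
    special_token = torch.full_like(features[:, :-49, :], 0.01)
    masked = features.clone()
    masked[:, :-49, :] = features[:, :-49, :] * meta_masks + special_token * (1 - meta_masks)
    return masked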
@@ -18,31 +19,31 @@ class Graphormer_Hand_Network(torch.nn.Module): self.trans_encoder = trans_encoder self.upsampling = torch.nn.Linear(195, 778) self.cam_param_fc = torch.nn.Linear(3, 1) - self.cam_param_fc2 = torch.nn.Linear(195+21, 150) + self.cam_param_fc2 = torch.nn.Linear(195 + 21, 150) self.cam_param_fc3 = torch.nn.Linear(150, 3) self.grid_feat_dim = torch.nn.Linear(1024, 2051) def forward(self, images, mesh_model, mesh_sampler, meta_masks=None, is_train=False): batch_size = images.size(0) # Generate T-pose template mesh - template_pose = torch.zeros((1,48)) + template_pose = torch.zeros((1, 48)) template_pose = template_pose.cuda() - template_betas = torch.zeros((1,10)).cuda() + template_betas = torch.zeros((1, 10)).cuda() template_vertices, template_3d_joints = mesh_model.layer(template_pose, template_betas) - template_vertices = template_vertices/1000.0 - template_3d_joints = template_3d_joints/1000.0 + template_vertices = template_vertices / 1000.0 + template_3d_joints = template_3d_joints / 1000.0 template_vertices_sub = mesh_sampler.downsample(template_vertices) # normalize - template_root = template_3d_joints[:,cfg.J_NAME.index('Wrist'),:] + template_root = template_3d_joints[:, cfg.J_NAME.index('Wrist'), :] template_3d_joints = template_3d_joints - template_root[:, None, :] template_vertices = template_vertices - template_root[:, None, :] template_vertices_sub = template_vertices_sub - template_root[:, None, :] num_joints = template_3d_joints.shape[1] # concatinate template joints and template vertices, and then duplicate to batch size - ref_vertices = torch.cat([template_3d_joints, template_vertices_sub],dim=1) + ref_vertices = torch.cat([template_3d_joints, template_vertices_sub], dim=1) ref_vertices = ref_vertices.expand(batch_size, -1, -1) # extract grid features and global image features using a CNN backbone @@ -51,42 +52,43 @@ class Graphormer_Hand_Network(torch.nn.Module): image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1) # process grid features grid_feat = torch.flatten(grid_feat, start_dim=2) - grid_feat = grid_feat.transpose(1,2) + grid_feat = grid_feat.transpose(1, 2) grid_feat = self.grid_feat_dim(grid_feat) # concatinate image feat and template mesh to form the joint/vertex queries features = torch.cat([ref_vertices, image_feat], dim=2) # prepare input tokens including joint/vertex queries and grid features - features = torch.cat([features, grid_feat],dim=1) + features = torch.cat([features, grid_feat], dim=1) - if is_train==True: + if is_train == True: # apply mask vertex/joint modeling # meta_masks is a tensor of all the masks, randomly generated in dataloader - # we pre-define a [MASK] token, which is a floating-value vector with 0.01s - special_token = torch.ones_like(features[:,:-49,:]).cuda()*0.01 - features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks) + # we pre-define a [MASK] token, which is a floating-value vector with 0.01s + special_token = torch.ones_like(features[:, :-49, :]).cuda() * 0.01 + features[:, :-49, : + ] = features[:, :-49, :] * meta_masks + special_token * (1 - meta_masks) # forward pass - if self.config.output_attentions==True: + if self.config.output_attentions == True: features, hidden_states, att = self.trans_encoder(features) else: features = self.trans_encoder(features) - pred_3d_joints = features[:,:num_joints,:] - pred_vertices_sub = features[:,num_joints:-49,:] + pred_3d_joints = features[:, :num_joints, :] + pred_vertices_sub = features[:, num_joints:-49, :] # 
learn camera parameters - x = self.cam_param_fc(features[:,:-49,:]) - x = x.transpose(1,2) + x = self.cam_param_fc(features[:, :-49, :]) + x = x.transpose(1, 2) x = self.cam_param_fc2(x) x = self.cam_param_fc3(x) - cam_param = x.transpose(1,2) + cam_param = x.transpose(1, 2) cam_param = cam_param.squeeze() - temp_transpose = pred_vertices_sub.transpose(1,2) + temp_transpose = pred_vertices_sub.transpose(1, 2) pred_vertices = self.upsampling(temp_transpose) - pred_vertices = pred_vertices.transpose(1,2) + pred_vertices = pred_vertices.transpose(1, 2) - if self.config.output_attentions==True: + if self.config.output_attentions == True: return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att else: - return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices \ No newline at end of file + return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices diff --git a/lib/pymafx/models/transformers/bert/file_utils.py b/lib/pymafx/models/transformers/bert/file_utils.py index fd655cec0ed897d5abaea8289a6395aaa672d767..ee58bed427f90be254caee9a0733d81ae92c8711 100644 --- a/lib/pymafx/models/transformers/bert/file_utils.py +++ b/lib/pymafx/models/transformers/bert/file_utils.py @@ -26,8 +26,8 @@ try: torch_cache_home = _get_torch_home() except ImportError: torch_cache_home = os.path.expanduser( - os.getenv('TORCH_HOME', os.path.join( - os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) + os.getenv('TORCH_HOME', os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')) + ) default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers') try: @@ -38,12 +38,12 @@ except ImportError: try: from pathlib import Path PYTORCH_PRETRAINED_BERT_CACHE = Path( - os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)) + os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path) + ) except (AttributeError, ImportError): - PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', - default_cache_path) + PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path) -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = logging.getLogger(__name__) # pylint: disable=invalid-name def url_to_filename(url, etag=None): @@ -138,7 +138,6 @@ def s3_request(func): Wrapper function for s3 requests in order to create more helpful error messages. 
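A sketch of the camera-parameter head used by both Graphormer networks above (layer sizes from the hand branch: Linear(3, 1), Linear(195 + 21, 150), Linear(150, 3)): each query token's xyz prediction is collapsed to a scalar, then three camera parameters are regressed over the token dimension. The function name and the weak-perspective reading of the three outputs are assumptions.

import torch
import torch.nn as nn

def regress_camera(tokens_xyz: torch.Tensor, fc: nn.Linear, fc2: nn.Linear, fc3: nn.Linear) -> torch.Tensor:
    # tokens_xyz: (B, N, 3) predicted joint/vertex coordinates (grid tokens excluded)
    x = fc(tokens_xyz)                    # (B, N, 1)
    x = x.transpose(1, 2)                 # (B, 1, N)
    x = fc3(fc2(x))                       # (B, 1, 3)
    return x.transpose(1, 2).squeeze()    # (B, 3) camera parameters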
""" - @wraps(func) def wrapper(url, *args, **kwargs): try: @@ -175,7 +174,7 @@ def http_get(url, temp_file): total = int(content_length) if content_length is not None else None progress = tqdm(unit="B", total=total) for chunk in req.iter_content(chunk_size=1024): - if chunk: # filter out keep-alive new chunks + if chunk: # filter out keep-alive new chunks progress.update(len(chunk)) temp_file.write(chunk) progress.close() @@ -251,7 +250,7 @@ def get_from_cache(url, cache_dir=None): with open(meta_path, 'w') as meta_file: output_string = json.dumps(meta) if sys.version_info[0] == 2 and isinstance(output_string, str): - output_string = unicode(output_string, 'utf-8') # The beauty of python 2 + output_string = unicode(output_string, 'utf-8') # The beauty of python 2 meta_file.write(output_string) logger.info("removing temp file %s", temp_file.name) diff --git a/lib/pymafx/models/transformers/bert/modeling_bert.py b/lib/pymafx/models/transformers/bert/modeling_bert.py index 738ebe6ec5f0697d0ec526fab4973489d01afd8e..c4a7f27f1bc0e69d87ac3747b8d8acfafb03b4b8 100644 --- a/lib/pymafx/models/transformers/bert/modeling_bert.py +++ b/lib/pymafx/models/transformers/bert/modeling_bert.py @@ -28,41 +28,69 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss, MSELoss -from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, - prune_linear_layer, add_start_docstrings) +from .modeling_utils import ( + WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer, + add_start_docstrings +) logger = logging.getLogger(__name__) BERT_PRETRAINED_MODEL_ARCHIVE_MAP = { - 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin", - 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin", - 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin", - 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin", - 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin", - 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin", - 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin", - 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin", - 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin", - 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin", - 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin", - 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin", - 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin", + 'bert-base-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin", + 
'bert-large-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin", + 'bert-base-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin", + 'bert-large-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin", + 'bert-base-multilingual-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin", + 'bert-base-multilingual-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin", + 'bert-base-chinese': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin", + 'bert-base-german-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin", + 'bert-large-uncased-whole-word-masking': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin", + 'bert-large-cased-whole-word-masking': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin", + 'bert-large-uncased-whole-word-masking-finetuned-squad': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin", + 'bert-large-cased-whole-word-masking-finetuned-squad': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin", + 'bert-base-cased-finetuned-mrpc': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin", } BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { - 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", - 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", - 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", - 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", - 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", - 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", - 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", - 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", - 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", - 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", - 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", - 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", - 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", + 'bert-base-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", + 
'bert-large-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", + 'bert-base-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", + 'bert-large-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", + 'bert-base-multilingual-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", + 'bert-base-multilingual-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", + 'bert-base-chinese': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", + 'bert-base-german-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", + 'bert-large-uncased-whole-word-masking': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", + 'bert-large-cased-whole-word-masking': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", + 'bert-large-uncased-whole-word-masking-finetuned-squad': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", + 'bert-large-cased-whole-word-masking-finetuned-squad': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", + 'bert-base-cased-finetuned-mrpc': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", } @@ -74,8 +102,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path): import numpy as np import tensorflow as tf except ImportError: - logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions.") + logger.error( + "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) raise tf_path = os.path.abspath(tf_checkpoint_path) logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) @@ -180,23 +210,26 @@ class BertConfig(PretrainedConfig): """ pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP - def __init__(self, - vocab_size_or_config_json_file=30522, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - layer_norm_eps=1e-12, - **kwargs): + def __init__( + self, + vocab_size_or_config_json_file=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + **kwargs + ): super(BertConfig, self).__init__(**kwargs) - if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 - and isinstance(vocab_size_or_config_json_file, unicode)): + if isinstance( + vocab_size_or_config_json_file, str + ) or (sys.version_info[0] == 2 and isinstance(vocab_size_or_config_json_file, unicode)): with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: json_config = json.loads(reader.read()) for key, value in json_config.items(): @@ -215,9 +248,10 @@ class BertConfig(PretrainedConfig): self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps else: - raise ValueError("First argument must be either a vocabulary size (int)" - "or the path to a pretrained model config file (str)") - + raise ValueError( + "First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)" + ) # try: @@ -240,6 +274,7 @@ class BertLayerNorm(nn.Module): x = (x - u) / torch.sqrt(s + self.variance_epsilon) return self.weight * x + self.bias + class BertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings. 
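A runnable sketch of the layer normalization whose forward pass appears in the BertLayerNorm hunk above (TensorFlow-style layer norm with the epsilon inside the square root); u and s are the per-feature mean and variance, and the class name here is illustrative.

import torch
import torch.nn as nn

class TFStyleLayerNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias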
""" @@ -278,7 +313,8 @@ class BertSelfAttention(nn.Module): if config.hidden_size % config.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) self.output_attentions = config.output_attentions self.num_attention_heads = config.num_attention_heads @@ -325,10 +361,10 @@ class BertSelfAttention(nn.Module): context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, ) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) + outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer, ) return outputs @@ -372,7 +408,7 @@ class BertAttention(nn.Module): def forward(self, input_tensor, attention_mask, head_mask=None): self_outputs = self.self(input_tensor, attention_mask, head_mask) attention_output = self.output(self_outputs[0], input_tensor) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them return outputs @@ -380,7 +416,8 @@ class BertIntermediate(nn.Module): def __init__(self, config): super(BertIntermediate, self).__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): + if isinstance(config.hidden_act, str + ) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act @@ -417,7 +454,7 @@ class BertLayer(nn.Module): attention_output = attention_outputs[0] intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + outputs = (layer_output, ) + attention_outputs[1:] # add attentions if we output them return outputs @@ -433,24 +470,24 @@ class BertEncoder(nn.Module): all_attentions = () for i, layer_module in enumerate(self.layer): if self.output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states, ) layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i]) hidden_states = layer_outputs[0] if self.output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_attentions = all_attentions + (layer_outputs[1], ) # Add last layer if self.output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states, ) - outputs = (hidden_states,) + outputs = (hidden_states, ) if self.output_hidden_states: - outputs = outputs + (all_hidden_states,) + outputs = outputs + (all_hidden_states, ) if self.output_attentions: - outputs = outputs + (all_attentions,) - return outputs # outputs, (hidden states), (attentions) + outputs = outputs + (all_attentions, ) + return outputs # outputs, (hidden states), (attentions) class BertPooler(nn.Module): @@ -472,7 +509,8 @@ class 
BertPredictionHeadTransform(nn.Module): def __init__(self, config): super(BertPredictionHeadTransform, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): + if isinstance(config.hidden_act, str + ) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): self.transform_act_fn = ACT2FN[config.hidden_act] else: self.transform_act_fn = config.hidden_act @@ -492,9 +530,7 @@ class BertLMPredictionHead(nn.Module): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, - config.vocab_size, - bias=False) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) @@ -620,8 +656,11 @@ BERT_INPUTS_DOCSTRING = r""" ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. """ -@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.", - BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) + +@add_start_docstrings( + "The bare Bert Model transformer outputing raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING +) class BertModel(BertPreTrainedModel): r""" Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: @@ -675,7 +714,14 @@ class BertModel(BertPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None): + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None + ): if attention_mask is None: attention_mask = torch.ones_like(input_ids) if token_type_ids is None: @@ -693,7 +739,9 @@ class BertModel(BertPreTrainedModel): # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. 
- extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype + ) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 # Prepare head mask if needed @@ -706,25 +754,36 @@ class BertModel(BertPreTrainedModel): head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) elif head_mask.dim() == 2: - head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer - head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze( + -1 + ) # We can specify head_mask for each layer + head_mask = head_mask.to( + dtype=next(self.parameters()).dtype + ) # switch to fload if need + fp16 compatibility else: head_mask = [None] * self.config.num_hidden_layers - embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids) - encoder_outputs = self.encoder(embedding_output, - extended_attention_mask, - head_mask=head_mask) + embedding_output = self.embeddings( + input_ids, position_ids=position_ids, token_type_ids=token_type_ids + ) + encoder_outputs = self.encoder( + embedding_output, extended_attention_mask, head_mask=head_mask + ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) - outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here - return outputs # sequence_output, pooled_output, (hidden_states), (attentions) + outputs = ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] # add hidden_states and attentions if they are here + return outputs # sequence_output, pooled_output, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training: +@add_start_docstrings( + """Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and a `next sentence prediction (classification)` head. """, - BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING +) class BertForPreTraining(BertPreTrainedModel): r""" **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: @@ -777,31 +836,54 @@ class BertForPreTraining(BertPreTrainedModel): """ Make sure we are sharing the input and output embeddings. Export to TorchScript can't handle parameter sharing so we are cloning them instead. 
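A minimal sketch of the additive attention mask built in BertModel.forward above: a (batch, seq_len) mask of ones and zeros is broadcast to (batch, 1, 1, seq_len), cast to the model dtype for fp16 compatibility, and turned into 0 / -10000.0 values that are added to the raw attention scores before the softmax.

import torch

def extend_attention_mask(attention_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # attention_mask: (B, L) with 1 = attend, 0 = masked
    extended = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype)   # (B, 1, 1, L)
    return (1.0 - extended) * -10000.0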
""" - self._tie_or_clone_weights(self.cls.predictions.decoder, - self.bert.embeddings.word_embeddings) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, - next_sentence_label=None, position_ids=None, head_mask=None): - outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask, head_mask=head_mask) + self._tie_or_clone_weights( + self.cls.predictions.decoder, self.bert.embeddings.word_embeddings + ) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + masked_lm_labels=None, + next_sentence_label=None, + position_ids=None, + head_mask=None + ): + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) sequence_output, pooled_output = outputs[:2] prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) - outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here + outputs = ( + prediction_scores, + seq_relationship_score, + ) + outputs[2:] # add hidden states and attention if they are here if masked_lm_labels is not None and next_sentence_label is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1) + ) + next_sentence_loss = loss_fct( + seq_relationship_score.view(-1, 2), next_sentence_label.view(-1) + ) total_loss = masked_lm_loss + next_sentence_loss - outputs = (total_loss,) + outputs + outputs = (total_loss, ) + outputs - return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions) + return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, - BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) +@add_start_docstrings( + """Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING, + BERT_INPUTS_DOCSTRING +) class BertForMaskedLM(BertPreTrainedModel): r""" **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: @@ -847,28 +929,46 @@ class BertForMaskedLM(BertPreTrainedModel): """ Make sure we are sharing the input and output embeddings. Export to TorchScript can't handle parameter sharing so we are cloning them instead. 
""" - self._tie_or_clone_weights(self.cls.predictions.decoder, - self.bert.embeddings.word_embeddings) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, - position_ids=None, head_mask=None): - outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask, head_mask=head_mask) + self._tie_or_clone_weights( + self.cls.predictions.decoder, self.bert.embeddings.word_embeddings + ) + + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + masked_lm_labels=None, + position_ids=None, + head_mask=None + ): + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) sequence_output = outputs[0] prediction_scores = self.cls(sequence_output) - outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention is they are here + outputs = (prediction_scores, + ) + outputs[2:] # Add hidden states and attention is they are here if masked_lm_labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) - outputs = (masked_lm_loss,) + outputs + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1) + ) + outputs = (masked_lm_loss, ) + outputs - return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions) + return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """, - BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top. 
""", + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING +) class BertForNextSentencePrediction(BertPreTrainedModel): r""" **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: @@ -909,26 +1009,42 @@ class BertForNextSentencePrediction(BertPreTrainedModel): self.apply(self.init_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, - position_ids=None, head_mask=None): - outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask, head_mask=head_mask) + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + next_sentence_label=None, + position_ids=None, + head_mask=None + ): + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) pooled_output = outputs[1] seq_relationship_score = self.cls(pooled_output) - outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here + outputs = (seq_relationship_score, + ) + outputs[2:] # add hidden states and attention if they are here if next_sentence_label is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - outputs = (next_sentence_loss,) + outputs + next_sentence_loss = loss_fct( + seq_relationship_score.view(-1, 2), next_sentence_label.view(-1) + ) + outputs = (next_sentence_loss, ) + outputs - return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions) + return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of - the pooled output) e.g. for GLUE tasks. """, - BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) +@add_start_docstrings( + """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks. 
""", BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING +) class BertForSequenceClassification(BertPreTrainedModel): r""" **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: @@ -972,16 +1088,28 @@ class BertForSequenceClassification(BertPreTrainedModel): self.apply(self.init_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, - position_ids=None, head_mask=None): - outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask, head_mask=head_mask) + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + position_ids=None, + head_mask=None + ): + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) - outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here + outputs = (logits, ) + outputs[2:] # add hidden states and attention if they are here if labels is not None: if self.num_labels == 1: @@ -991,14 +1119,15 @@ class BertForSequenceClassification(BertPreTrainedModel): else: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - outputs = (loss,) + outputs + outputs = (loss, ) + outputs - return outputs # (loss), logits, (hidden_states), (attentions) + return outputs # (loss), logits, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of - the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, - BERT_START_DOCSTRING) +@add_start_docstrings( + """Bert Model with a multiple choice classification head on top (a linear layer on top of + the pooled output and a softmax) e.g. for RocStories/SWAG tasks. 
""", BERT_START_DOCSTRING +) class BertForMultipleChoice(BertPreTrainedModel): r""" Inputs: @@ -1078,35 +1207,56 @@ class BertForMultipleChoice(BertPreTrainedModel): self.apply(self.init_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, - position_ids=None, head_mask=None): + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + position_ids=None, + head_mask=None + ): num_choices = input_ids.shape[1] flat_input_ids = input_ids.view(-1, input_ids.size(-1)) - flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None - flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None - flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - outputs = self.bert(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids, - attention_mask=flat_attention_mask, head_mask=head_mask) + flat_position_ids = position_ids.view( + -1, position_ids.size(-1) + ) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view( + -1, token_type_ids.size(-1) + ) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view( + -1, attention_mask.size(-1) + ) if attention_mask is not None else None + outputs = self.bert( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + head_mask=head_mask + ) pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) reshaped_logits = logits.view(-1, num_choices) - outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here + outputs = (reshaped_logits, + ) + outputs[2:] # add hidden states and attention if they are here if labels is not None: loss_fct = CrossEntropyLoss() loss = loss_fct(reshaped_logits, labels) - outputs = (loss,) + outputs + outputs = (loss, ) + outputs - return outputs # (loss), reshaped_logits, (hidden_states), (attentions) + return outputs # (loss), reshaped_logits, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of +@add_start_docstrings( + """Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. 
""", - BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING +) class BertForTokenClassification(BertPreTrainedModel): r""" **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: @@ -1148,16 +1298,28 @@ class BertForTokenClassification(BertPreTrainedModel): self.apply(self.init_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, - position_ids=None, head_mask=None): - outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask, head_mask=head_mask) + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + position_ids=None, + head_mask=None + ): + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) - outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here + outputs = (logits, ) + outputs[2:] # add hidden states and attention if they are here if labels is not None: loss_fct = CrossEntropyLoss() # Only keep active parts of the loss @@ -1168,14 +1330,16 @@ class BertForTokenClassification(BertPreTrainedModel): loss = loss_fct(active_logits, active_labels) else: loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - outputs = (loss,) + outputs + outputs = (loss, ) + outputs - return outputs # (loss), scores, (hidden_states), (attentions) + return outputs # (loss), scores, (hidden_states), (attentions) -@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of +@add_start_docstrings( + """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
""", - BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) + BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING +) class BertForQuestionAnswering(BertPreTrainedModel): r""" **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: @@ -1224,10 +1388,23 @@ class BertForQuestionAnswering(BertPreTrainedModel): self.apply(self.init_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, - end_positions=None, position_ids=None, head_mask=None): - outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask, head_mask=head_mask) + def forward( + self, + input_ids, + token_type_ids=None, + attention_mask=None, + start_positions=None, + end_positions=None, + position_ids=None, + head_mask=None + ): + outputs = self.bert( + input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) sequence_output = outputs[0] logits = self.qa_outputs(sequence_output) @@ -1235,7 +1412,10 @@ class BertForQuestionAnswering(BertPreTrainedModel): start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) - outputs = (start_logits, end_logits,) + outputs[2:] + outputs = ( + start_logits, + end_logits, + ) + outputs[2:] if start_positions is not None and end_positions is not None: # If we are on multi-GPU, split add a dimension if len(start_positions.size()) > 1: @@ -1251,6 +1431,6 @@ class BertForQuestionAnswering(BertPreTrainedModel): start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 - outputs = (total_loss,) + outputs + outputs = (total_loss, ) + outputs - return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) + return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) diff --git a/lib/pymafx/models/transformers/bert/modeling_graphormer.py b/lib/pymafx/models/transformers/bert/modeling_graphormer.py index 91f2b869511ea6228bef40bb0d30b45c3194ce95..e318af8a45d34148e0db68f42181f692afbf8754 100644 --- a/lib/pymafx/models/transformers/bert/modeling_graphormer.py +++ b/lib/pymafx/models/transformers/bert/modeling_graphormer.py @@ -16,6 +16,7 @@ from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, Bert # import src.modeling.data.config as cfg # from src.modeling._gcnn import GraphConvolution, GraphResBlock from .modeling_utils import prune_linear_layer + LayerNormClass = torch.nn.LayerNorm BertLayerNorm = torch.nn.LayerNorm @@ -26,7 +27,8 @@ class BertSelfAttention(nn.Module): if config.hidden_size % config.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) self.output_attentions = config.output_attentions self.num_attention_heads = config.num_attention_heads @@ -44,8 +46,7 @@ class BertSelfAttention(nn.Module): x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) - def forward(self, hidden_states, attention_mask, head_mask=None, - history_state=None): + def forward(self, hidden_states, attention_mask, head_mask=None, history_state=None): if history_state is not None: raise x_states = torch.cat([history_state, hidden_states], dim=1) @@ -57,7 +58,10 @@ class BertSelfAttention(nn.Module): mixed_key_layer = self.key(hidden_states) mixed_value_layer = 
self.value(hidden_states) - print('mixed_query_layer', mixed_query_layer.shape, mixed_key_layer.shape, mixed_value_layer.shape) + print( + 'mixed_query_layer', mixed_query_layer.shape, mixed_key_layer.shape, + mixed_value_layer.shape + ) query_layer = self.transpose_for_scores(mixed_query_layer) key_layer = self.transpose_for_scores(mixed_key_layer) value_layer = self.transpose_for_scores(mixed_value_layer) @@ -84,12 +88,13 @@ class BertSelfAttention(nn.Module): context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, ) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) + outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer, ) return outputs + class BertAttention(nn.Module): def __init__(self, config): super(BertAttention, self).__init__() @@ -113,12 +118,10 @@ class BertAttention(nn.Module): self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - def forward(self, input_tensor, attention_mask, head_mask=None, - history_state=None): - self_outputs = self.self(input_tensor, attention_mask, head_mask, - history_state) + def forward(self, input_tensor, attention_mask, head_mask=None, history_state=None): + self_outputs = self.self(input_tensor, attention_mask, head_mask, history_state) attention_output = self.output(self_outputs[0], input_tensor) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them return outputs @@ -130,45 +133,46 @@ class GraphormerLayer(nn.Module): self.mesh_type = config.mesh_type if self.has_graph_conv == True: - if self.mesh_type=='hand': - self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type) - elif self.mesh_type=='body': - self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type) - + if self.mesh_type == 'hand': + self.graph_conv = GraphResBlock( + config.hidden_size, config.hidden_size, mesh_type=self.mesh_type + ) + elif self.mesh_type == 'body': + self.graph_conv = GraphResBlock( + config.hidden_size, config.hidden_size, mesh_type=self.mesh_type + ) + self.intermediate = BertIntermediate(config) self.output = BertOutput(config) - def MHA_GCN(self, hidden_states, attention_mask, head_mask=None, - history_state=None): - attention_outputs = self.attention(hidden_states, attention_mask, - head_mask, history_state) + def MHA_GCN(self, hidden_states, attention_mask, head_mask=None, history_state=None): + attention_outputs = self.attention(hidden_states, attention_mask, head_mask, history_state) attention_output = attention_outputs[0] - if self.has_graph_conv==True: + if self.has_graph_conv == True: if self.mesh_type == 'body': - joints = attention_output[:,0:14,:] - vertices = attention_output[:,14:-49,:] - img_tokens = attention_output[:,-49:,:] + joints = attention_output[:, 0:14, :] + vertices = attention_output[:, 14:-49, :] + img_tokens = attention_output[:, -49:, :] elif self.mesh_type == 'hand': - joints = attention_output[:,0:21,:] - vertices = attention_output[:,21:-49,:] - img_tokens = 
attention_output[:,-49:,:] + joints = attention_output[:, 0:21, :] + vertices = attention_output[:, 21:-49, :] + img_tokens = attention_output[:, -49:, :] vertices = self.graph_conv(vertices) - joints_vertices = torch.cat([joints,vertices,img_tokens],dim=1) + joints_vertices = torch.cat([joints, vertices, img_tokens], dim=1) else: joints_vertices = attention_output intermediate_output = self.intermediate(joints_vertices) layer_output = self.output(intermediate_output, joints_vertices) print('layer_output', layer_output.shape) - outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + outputs = (layer_output, ) + attention_outputs[1:] # add attentions if we output them return outputs - def forward(self, hidden_states, attention_mask, head_mask=None, - history_state=None): - return self.MHA_GCN(hidden_states, attention_mask, head_mask,history_state) + def forward(self, hidden_states, attention_mask, head_mask=None, history_state=None): + return self.MHA_GCN(hidden_states, attention_mask, head_mask, history_state) class GraphormerEncoder(nn.Module): @@ -176,36 +180,36 @@ class GraphormerEncoder(nn.Module): super(GraphormerEncoder, self).__init__() self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states - self.layer = nn.ModuleList([GraphormerLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList( + [GraphormerLayer(config) for _ in range(config.num_hidden_layers)] + ) - def forward(self, hidden_states, attention_mask, head_mask=None, - encoder_history_states=None): + def forward(self, hidden_states, attention_mask, head_mask=None, encoder_history_states=None): all_hidden_states = () all_attentions = () for i, layer_module in enumerate(self.layer): if self.output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states, ) history_state = None if encoder_history_states is None else encoder_history_states[i] - layer_outputs = layer_module( - hidden_states, attention_mask, head_mask[i], - history_state) + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], history_state) hidden_states = layer_outputs[0] if self.output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_attentions = all_attentions + (layer_outputs[1], ) # Add last layer if self.output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states, ) - outputs = (hidden_states,) + outputs = (hidden_states, ) if self.output_hidden_states: - outputs = outputs + (all_hidden_states,) + outputs = outputs + (all_hidden_states, ) if self.output_attentions: - outputs = outputs + (all_attentions,) + outputs = outputs + (all_attentions, ) + + return outputs # outputs, (hidden states), (attentions) - return outputs # outputs, (hidden states), (attentions) class EncoderBlock(BertPreTrainedModel): def __init__(self, config): @@ -215,7 +219,7 @@ class EncoderBlock(BertPreTrainedModel): self.encoder = GraphormerEncoder(config) # self.pooler = BertPooler(config) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - self.img_dim = config.img_feature_dim + self.img_dim = config.img_feature_dim try: self.use_img_layernorm = config.use_img_layernorm @@ -237,12 +241,19 @@ class EncoderBlock(BertPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - def 
forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, - position_ids=None, head_mask=None): + def forward( + self, + img_feats, + input_ids=None, + token_type_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None + ): batch_size = len(img_feats) seq_length = len(img_feats[0]) - input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).cuda() + input_ids = torch.zeros([batch_size, seq_length], dtype=torch.long).cuda() if position_ids is None: position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) @@ -251,7 +262,10 @@ class EncoderBlock(BertPreTrainedModel): print('position_ids', seq_length, position_ids.shape) position_embeddings = self.position_embeddings(position_ids) - print('position_embeddings', position_embeddings.shape, self.config.max_position_embeddings, self.config.hidden_size) + print( + 'position_embeddings', position_embeddings.shape, self.config.max_position_embeddings, + self.config.hidden_size + ) if attention_mask is None: attention_mask = torch.ones_like(input_ids) @@ -270,7 +284,9 @@ class EncoderBlock(BertPreTrainedModel): else: raise NotImplementedError - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype + ) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 if head_mask is not None: @@ -279,8 +295,12 @@ class EncoderBlock(BertPreTrainedModel): head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) elif head_mask.dim() == 2: - head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer - head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze( + -1 + ) # We can specify head_mask for each layer + head_mask = head_mask.to( + dtype=next(self.parameters()).dtype + ) # switch to fload if need + fp16 compatibility else: head_mask = [None] * self.config.num_hidden_layers @@ -297,20 +317,20 @@ class EncoderBlock(BertPreTrainedModel): embeddings = self.dropout(embeddings) print('extended_attention_mask', extended_attention_mask.shape) - encoder_outputs = self.encoder(embeddings, - extended_attention_mask, head_mask=head_mask) + encoder_outputs = self.encoder(embeddings, extended_attention_mask, head_mask=head_mask) sequence_output = encoder_outputs[0] - outputs = (sequence_output,) + outputs = (sequence_output, ) if self.config.output_hidden_states: all_hidden_states = encoder_outputs[1] - outputs = outputs + (all_hidden_states,) + outputs = outputs + (all_hidden_states, ) if self.config.output_attentions: all_attentions = encoder_outputs[-1] - outputs = outputs + (all_attentions,) + outputs = outputs + (all_attentions, ) return outputs + class Graphormer(BertPreTrainedModel): ''' The archtecture of a transformer encoder block we used in Graphormer @@ -323,16 +343,31 @@ class Graphormer(BertPreTrainedModel): self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim) self.apply(self.init_weights) - def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None, - next_sentence_label=None, position_ids=None, head_mask=None): + def forward( + self, + img_feats, + input_ids=None, + 
token_type_ids=None, + attention_mask=None, + masked_lm_labels=None, + next_sentence_label=None, + position_ids=None, + head_mask=None + ): ''' # self.bert has three outputs # predictions[0]: output tokens # predictions[1]: all_hidden_states, if enable "self.config.output_hidden_states" # predictions[2]: attentions, if enable "self.config.output_attentions" ''' - predictions = self.bert(img_feats=img_feats, input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, - attention_mask=attention_mask, head_mask=head_mask) + predictions = self.bert( + img_feats=img_feats, + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + head_mask=head_mask + ) # We use "self.cls_head" to perform dimensionality reduction. We don't use it for classification. pred_score = self.cls_head(predictions[0]) @@ -344,5 +379,3 @@ class Graphormer(BertPreTrainedModel): return pred_score, predictions[1], predictions[-1] else: return pred_score - - \ No newline at end of file diff --git a/lib/pymafx/models/transformers/bert/modeling_utils.py b/lib/pymafx/models/transformers/bert/modeling_utils.py index 458852810218b7b85f25cb564da8c96886500b7c..40a0915822c8e736de8ac2466c075e6cc5ef7e83 100644 --- a/lib/pymafx/models/transformers/bert/modeling_utils.py +++ b/lib/pymafx/models/transformers/bert/modeling_utils.py @@ -15,8 +15,7 @@ # limitations under the License. """PyTorch BERT model.""" -from __future__ import (absolute_import, division, print_function, - unicode_literals) +from __future__ import (absolute_import, division, print_function, unicode_literals) import copy import json @@ -38,7 +37,6 @@ CONFIG_NAME = "config.json" WEIGHTS_NAME = "pytorch_model.bin" TF_WEIGHTS_NAME = 'model.ckpt' - try: from torch.nn import Identity except ImportError: @@ -54,16 +52,19 @@ except ImportError: if not six.PY2: + def add_start_docstrings(*docstr): def docstring_decorator(fn): fn.__doc__ = ''.join(docstr) + fn.__doc__ return fn + return docstring_decorator else: # Not possible to update class docstrings on python2 def add_start_docstrings(*docstr): def docstring_decorator(fn): return fn + return docstring_decorator @@ -84,7 +85,9 @@ class PretrainedConfig(object): """ Save a configuration object to a directory, so that it can be re-loaded using the `from_pretrained(save_directory)` class method. """ - assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" + assert os.path.isdir( + save_directory + ), "Saving path should be a directory where the model and configuration can be saved" # If we save using the predefined names, we can load using `from_pretrained` output_config_file = os.path.join(save_directory, CONFIG_NAME) @@ -145,23 +148,28 @@ class PretrainedConfig(object): except EnvironmentError: if pretrained_model_name_or_path in cls.pretrained_config_archive_map: logger.error( - "Couldn't reach server at '{}' to download pretrained model configuration file.".format( - config_file)) + "Couldn't reach server at '{}' to download pretrained model configuration file." + .format(config_file) + ) else: logger.error( "Model name '{}' was not found in model name list ({}). 
" "We assumed '{}' was a path or url but couldn't find any file " "associated to this path or url.".format( pretrained_model_name_or_path, - ', '.join(cls.pretrained_config_archive_map.keys()), - config_file)) + ', '.join(cls.pretrained_config_archive_map.keys()), config_file + ) + ) return None if resolved_config_file == config_file: pass # logger.info("loading configuration file {}".format(config_file)) else: - logger.info("loading configuration file {} from cache at {}".format( - config_file, resolved_config_file)) + logger.info( + "loading configuration file {} from cache at {}".format( + config_file, resolved_config_file + ) + ) # Load config config = cls.from_json_file(resolved_config_file) @@ -235,7 +243,8 @@ class PreTrainedModel(nn.Module): "To create a model from a pretrained model use " "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( self.__class__.__name__, self.__class__.__name__ - )) + ) + ) # Save config in model self.config = config @@ -269,7 +278,8 @@ class PreTrainedModel(nn.Module): # Copy word embeddings from the previous weights num_tokens_to_copy = min(old_num_tokens, new_num_tokens) - new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :] + new_embeddings.weight.data[:num_tokens_to_copy, : + ] = old_embeddings.weight.data[:num_tokens_to_copy, :] return new_embeddings @@ -295,7 +305,7 @@ class PreTrainedModel(nn.Module): Return: ``torch.nn.Embeddings`` Pointer to the input tokens Embedding Module of the model """ - base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed + base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed model_embeds = base_model._resize_token_embeddings(new_num_tokens) if new_num_tokens is None: return model_embeds @@ -315,14 +325,16 @@ class PreTrainedModel(nn.Module): Args: heads_to_prune: dict of {layer_num (int): list of heads to prune in this layer (list of int)} """ - base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed + base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed base_model._prune_heads(heads_to_prune) def save_pretrained(self, save_directory): """ Save a model with its configuration file to a directory, so that it can be re-loaded using the `from_pretrained(save_directory)` class method. 
""" - assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" + assert os.path.isdir( + save_directory + ), "Saving path should be a directory where the model and configuration can be saved" # Only save the model it-self if we are using distributed training model_to_save = self.module if hasattr(self, 'module') else self @@ -402,8 +414,10 @@ class PreTrainedModel(nn.Module): # Load config if config is None: config, model_kwargs = cls.config_class.from_pretrained( - pretrained_model_name_or_path, *model_args, - cache_dir=cache_dir, return_unused_kwargs=True, + pretrained_model_name_or_path, + *model_args, + cache_dir=cache_dir, + return_unused_kwargs=True, **kwargs ) else: @@ -415,7 +429,9 @@ class PreTrainedModel(nn.Module): elif os.path.isdir(pretrained_model_name_or_path): if from_tf: # Directly load from a TensorFlow checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index") + archive_file = os.path.join( + pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index" + ) else: archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) else: @@ -430,22 +446,27 @@ class PreTrainedModel(nn.Module): except EnvironmentError: if pretrained_model_name_or_path in cls.pretrained_model_archive_map: logger.error( - "Couldn't reach server at '{}' to download pretrained weights.".format( - archive_file)) + "Couldn't reach server at '{}' to download pretrained weights.". + format(archive_file) + ) else: logger.error( "Model name '{}' was not found in model name list ({}). " "We assumed '{}' was a path or url but couldn't find any file " "associated to this path or url.".format( pretrained_model_name_or_path, - ', '.join(cls.pretrained_model_archive_map.keys()), - archive_file)) + ', '.join(cls.pretrained_model_archive_map.keys()), archive_file + ) + ) return None if resolved_archive_file == archive_file: logger.info("loading weights file {}".format(archive_file)) else: - logger.info("loading weights file {} from cache at {}".format( - archive_file, resolved_archive_file)) + logger.info( + "loading weights file {} from cache at {}".format( + archive_file, resolved_archive_file + ) + ) # Instantiate model. 
model = cls(config, *model_args, **model_kwargs) @@ -454,7 +475,9 @@ class PreTrainedModel(nn.Module): state_dict = torch.load(resolved_archive_file, map_location='cpu') if from_tf: # Directly load from a TensorFlow checkpoint - return cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' + return cls.load_tf_weights( + model, config, resolved_archive_file[:-6] + ) # Remove the '.index' # Convert old format to new format if needed from a PyTorch state_dict old_keys = [] @@ -484,7 +507,8 @@ class PreTrainedModel(nn.Module): def load(module, prefix=''): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) module._load_from_state_dict( - state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs + ) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + '.') @@ -492,30 +516,46 @@ class PreTrainedModel(nn.Module): # Make sure we are able to load base models as well as derived models (with heads) start_prefix = '' model_to_load = model - if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): + if not hasattr(model, cls.base_model_prefix) and any( + s.startswith(cls.base_model_prefix) for s in state_dict.keys() + ): start_prefix = cls.base_model_prefix + '.' - if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): + if hasattr(model, cls.base_model_prefix + ) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()): model_to_load = getattr(model, cls.base_model_prefix) load(model_to_load, prefix=start_prefix) if len(missing_keys) > 0: - logger.info("Weights of {} not initialized from pretrained model: {}".format( - model.__class__.__name__, missing_keys)) + logger.info( + "Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys + ) + ) if len(unexpected_keys) > 0: - logger.info("Weights from pretrained model not used in {}: {}".format( - model.__class__.__name__, unexpected_keys)) + logger.info( + "Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, unexpected_keys + ) + ) if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) + raise RuntimeError( + 'Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs) + ) + ) if hasattr(model, 'tie_weights'): - model.tie_weights() # make sure word embedding weights are still tied + model.tie_weights() # make sure word embedding weights are still tied # Set model in evaluation mode to desactivate DropOut modules by default model.eval() if output_loading_info: - loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs} + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "error_msgs": error_msgs + } return model, loading_info return model @@ -534,7 +574,7 @@ class Conv1D(nn.Module): self.bias = nn.Parameter(torch.zeros(nf)) def forward(self, x): - size_out = x.size()[:-1] + (self.nf,) + size_out = x.size()[:-1] + (self.nf, ) x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) x = x.view(*size_out) return x @@ -586,9 +626,10 @@ class PoolerEndLogits(nn.Module): assert start_states is not 
None or start_positions is not None, "One of start_states, start_positions should be not None" if start_positions is not None: slen, hsz = hidden_states.shape[-2:] - start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) - start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) - start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) + start_positions = start_positions[:, None, + None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) + start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) x = self.activation(x) @@ -629,14 +670,16 @@ class PoolerAnswerClass(nn.Module): hsz = hidden_states.shape[-1] assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" if start_positions is not None: - start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) - start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) + start_positions = start_positions[:, None, + None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, + start_positions).squeeze(-2) # shape (bsz, hsz) if cls_index is not None: - cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) - cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) + cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) else: - cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) + cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) x = self.activation(x) @@ -694,8 +737,15 @@ class SQuADHead(nn.Module): self.end_logits = PoolerEndLogits(config) self.answer_class = PoolerAnswerClass(config) - def forward(self, hidden_states, start_positions=None, end_positions=None, - cls_index=None, is_impossible=None, p_mask=None): + def forward( + self, + hidden_states, + start_positions=None, + end_positions=None, + cls_index=None, + is_impossible=None, + p_mask=None + ): outputs = () start_logits = self.start_logits(hidden_states, p_mask=p_mask) @@ -707,7 +757,9 @@ class SQuADHead(nn.Module): x.squeeze_(-1) # during training, compute the end logits based on the ground truth of the start position - end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) + end_logits = self.end_logits( + hidden_states, start_positions=start_positions, p_mask=p_mask + ) loss_fct = CrossEntropyLoss() start_loss = loss_fct(start_logits, start_positions) @@ -716,38 +768,58 @@ class SQuADHead(nn.Module): if cls_index is not None and is_impossible is not None: # Predict answerability from the representation of CLS and START - cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) + cls_logits = self.answer_class( + hidden_states, start_positions=start_positions, cls_index=cls_index + ) loss_fct_cls = nn.BCEWithLogitsLoss() cls_loss = loss_fct_cls(cls_logits, is_impossible) # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss total_loss += cls_loss * 0.5 - outputs = (total_loss,) + outputs + outputs = (total_loss, ) + outputs 
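# Illustrative sketch (toy tensors, not the module's real inputs) of the
# beam-search-style span selection in the inference branch below: keep the
# top start positions, then score end positions conditioned on each start.
import torch
import torch.nn.functional as F

bsz, slen, n_top = 2, 8, 3
start_logits = torch.randn(bsz, slen)
start_probs = F.softmax(start_logits, dim=-1)                         # (bsz, slen)
start_top_probs, start_top_index = torch.topk(start_probs, n_top, dim=-1)
end_logits = torch.randn(bsz, slen, n_top)                            # one end score per candidate start
end_probs = F.softmax(end_logits, dim=1)                              # softmax over sequence length
end_top_probs, end_top_index = torch.topk(end_probs, n_top, dim=1)
print(start_top_index.shape, end_top_index.shape)                     # (2, 3) and (2, 3, 3)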
else: # during inference, compute the end logits based on beam search bsz, slen, hsz = hidden_states.size() - start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) - - start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top) - start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) - start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) - start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) - - hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz) + start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen) + + start_top_log_probs, start_top_index = torch.topk( + start_log_probs, self.start_n_top, dim=-1 + ) # shape (bsz, start_n_top) + start_top_index_exp = start_top_index.unsqueeze(-1).expand( + -1, -1, hsz + ) # shape (bsz, start_n_top, hsz) + start_states = torch.gather( + hidden_states, -2, start_top_index_exp + ) # shape (bsz, start_n_top, hsz) + start_states = start_states.unsqueeze(1).expand( + -1, slen, -1, -1 + ) # shape (bsz, slen, start_n_top, hsz) + + hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( + start_states + ) # shape (bsz, slen, start_n_top, hsz) p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None - end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) - end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) + end_logits = self.end_logits( + hidden_states_expanded, start_states=start_states, p_mask=p_mask + ) + end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) - end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top) + end_top_log_probs, end_top_index = torch.topk( + end_log_probs, self.end_n_top, dim=1 + ) # shape (bsz, end_n_top, start_n_top) end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) - cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) + cls_logits = self.answer_class( + hidden_states, start_states=start_states, cls_index=cls_index + ) - outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs + outputs = ( + start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits + ) + outputs # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits # or (if labels are provided) (total_loss,) @@ -781,7 +853,9 @@ class SequenceSummary(nn.Module): self.summary = Identity() if hasattr(config, 'summary_use_proj') and config.summary_use_proj: - if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0: + if hasattr( + config, 'summary_proj_to_labels' + ) and config.summary_proj_to_labels and config.num_labels > 0: num_classes = config.num_labels else: num_classes = config.hidden_size @@ -814,12 +888,17 @@ class SequenceSummary(nn.Module): output = hidden_states.mean(dim=1) elif self.summary_type == 'token_ids': if token_ids is None: - token_ids = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long) + 
token_ids = torch.full_like( + hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long + ) else: token_ids = token_ids.unsqueeze(-1).unsqueeze(-1) - token_ids = token_ids.expand((-1,) * (token_ids.dim()-1) + (hidden_states.size(-1),)) + token_ids = token_ids.expand( + (-1, ) * (token_ids.dim() - 1) + (hidden_states.size(-1), ) + ) # shape of token_ids: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states - output = hidden_states.gather(-2, token_ids).squeeze(-2) # shape (bsz, XX, hidden_size) + output = hidden_states.gather(-2, + token_ids).squeeze(-2) # shape (bsz, XX, hidden_size) elif self.summary_type == 'attn': raise NotImplementedError @@ -845,7 +924,8 @@ def prune_linear_layer(layer, index, dim=0): b = layer.bias[index].clone().detach() new_size = list(layer.weight.size()) new_size[dim] = len(index) - new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) + new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias + is not None).to(layer.weight.device) new_layer.weight.requires_grad = False new_layer.weight.copy_(W.contiguous()) new_layer.weight.requires_grad = True diff --git a/lib/pymafx/models/transformers/net_utils.py b/lib/pymafx/models/transformers/net_utils.py index 3e29bb1f1f0b9428a0eeb0eeb4eb432db190a5e8..52782911e276705ec0dd908ce9676430c0a58d72 100644 --- a/lib/pymafx/models/transformers/net_utils.py +++ b/lib/pymafx/models/transformers/net_utils.py @@ -7,22 +7,24 @@ import torch.nn.functional as F class single_conv(nn.Module): def __init__(self, in_ch, out_ch): super(single_conv, self).__init__() - self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), - nn.BatchNorm2d(out_ch), - nn.ReLU(inplace=True),) + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + ) def forward(self, x): return self.conv(x) + class double_conv(nn.Module): def __init__(self, in_ch, out_ch): super(double_conv, self).__init__() - self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), - nn.BatchNorm2d(out_ch), - nn.ReLU(inplace=True), - nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), - nn.BatchNorm2d(out_ch), - nn.ReLU(inplace=True)) + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), + nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) + ) def forward(self, x): return self.conv(x) @@ -31,12 +33,11 @@ class double_conv(nn.Module): class double_conv_down(nn.Module): def __init__(self, in_ch, out_ch): super(double_conv_down, self).__init__() - self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1), - nn.BatchNorm2d(out_ch), - nn.ReLU(inplace=True), - nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), - nn.BatchNorm2d(out_ch), - nn.ReLU(inplace=True)) + self.conv = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1), nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), + nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) + ) def forward(self, x): return self.conv(x) @@ -45,13 +46,12 @@ class double_conv_down(nn.Module): class double_conv_up(nn.Module): def __init__(self, in_ch, out_ch): super(double_conv_up, self).__init__() - self.conv = nn.Sequential(nn.UpsamplingNearest2d(scale_factor=2), - nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), - nn.BatchNorm2d(out_ch), - nn.ReLU(inplace=True), 
- nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), - nn.BatchNorm2d(out_ch), - nn.ReLU(inplace=True)) + self.conv = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor=2), + nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), + nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) + ) def forward(self, x): return self.conv(x) @@ -87,31 +87,35 @@ class PosEnSine(nn.Module): x_embed = x_embed / (torch.max(x_embed) + eps) * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_z = z_embed[:, :, :, None] / dim_t - pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) - pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) - pos_z = torch.stack((pos_z[:, :, :, 0::2].sin(), pos_z[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_z = torch.stack((pos_z[:, :, :, 0::2].sin(), pos_z[:, :, :, 1::2].cos()), + dim=4).flatten(3) pos = torch.cat((pos_x, pos_y, pos_z), dim=3).permute(0, 3, 1, 2) # if pt_coord is None: pos = pos.repeat(b, 1, 1, 1) return pos + def softmax_attention(q, k, v): # b x n x d x h x w h, w = q.shape[-2], q.shape[-1] - q = q.flatten(-2).transpose(-2, -1) # b x n x hw x d - k = k.flatten(-2) # b x n x d x hw + q = q.flatten(-2).transpose(-2, -1) # b x n x hw x d + k = k.flatten(-2) # b x n x d x hw v = v.flatten(-2).transpose(-2, -1) print('softmax', q.shape, k.shape, v.shape) - N = k.shape[-1] # ?????? maybe change to k.shape[-2]???? - attn = torch.matmul(q / N ** 0.5, k) + N = k.shape[-1] # ?????? maybe change to k.shape[-2]???? 
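# Illustrative reference (not from the pymafx sources): conventional scaled
# dot-product attention divides the scores by sqrt(d_k), the channel
# dimension -- after the flatten above that is k.shape[-2], which is exactly
# what the "??????" comment is asking about; N = k.shape[-1] is instead the
# number of key positions (h * w).
import torch
import torch.nn.functional as F

def ref_scaled_dot_product_attention(q, k, v):
    # q: (..., L_q, d), k: (..., d, L_k), v: (..., L_k, d_v)
    d_k = q.shape[-1]
    attn = F.softmax(torch.matmul(q, k) / d_k ** 0.5, dim=-1)
    return torch.matmul(attn, v), attn

q = torch.randn(2, 4, 16, 32)   # b x n x hw x d
k = torch.randn(2, 4, 32, 16)   # b x n x d x hw
v = torch.randn(2, 4, 16, 32)   # b x n x hw x d
out, attn = ref_scaled_dot_product_attention(q, k, v)
print(out.shape, attn.shape)    # (2, 4, 16, 32) and (2, 4, 16, 16)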
+ attn = torch.matmul(q / N**0.5, k) attn = F.softmax(attn, dim=-1) output = torch.matmul(attn, v) @@ -125,8 +129,8 @@ def dotproduct_attention(q, k, v): # b x n x d x h x w h, w = q.shape[-2], q.shape[-1] - q = q.flatten(-2).transpose(-2, -1) # b x n x hw x d - k = k.flatten(-2) # b x n x d x hw + q = q.flatten(-2).transpose(-2, -1) # b x n x hw x d + k = k.flatten(-2) # b x n x d x hw v = v.flatten(-2).transpose(-2, -1) N = k.shape[-1] @@ -140,7 +144,7 @@ def dotproduct_attention(q, k, v): return output, attn -def long_range_attention(q, k, v, P_h, P_w): # fixed patch size +def long_range_attention(q, k, v, P_h, P_w): # fixed patch size B, N, C, qH, qW = q.size() _, _, _, kH, kW = k.size() @@ -151,17 +155,17 @@ def long_range_attention(q, k, v, P_h, P_w): # fixed patch size k = k.reshape(B, N, C, kQ_h, P_h, kQ_w, P_w) v = v.reshape(B, N, -1, kQ_h, P_h, kQ_w, P_w) - q = q.permute(0, 1, 4, 6, 2, 3, 5) # [b, n, Ph, Pw, d, Qh, Qw] + q = q.permute(0, 1, 4, 6, 2, 3, 5) # [b, n, Ph, Pw, d, Qh, Qw] k = k.permute(0, 1, 4, 6, 2, 3, 5) v = v.permute(0, 1, 4, 6, 2, 3, 5) - output, attn = softmax_attention(q, k, v) # attn: [b, n, Ph, Pw, qQh*qQw, kQ_h*kQ_w] + output, attn = softmax_attention(q, k, v) # attn: [b, n, Ph, Pw, qQh*qQw, kQ_h*kQ_w] output = output.permute(0, 1, 4, 5, 2, 6, 3) output = output.reshape(B, N, -1, qH, qW) return output, attn -def short_range_attention(q, k, v, Q_h, Q_w): # fixed patch number +def short_range_attention(q, k, v, Q_h, Q_w): # fixed patch number B, N, C, qH, qW = q.size() _, _, _, kH, kW = k.size() @@ -172,11 +176,11 @@ def short_range_attention(q, k, v, Q_h, Q_w): # fixed patch number k = k.reshape(B, N, C, Q_h, kP_h, Q_w, kP_w) v = v.reshape(B, N, -1, Q_h, kP_h, Q_w, kP_w) - q = q.permute(0, 1, 3, 5, 2, 4, 6) # [b, n, Qh, Qw, d, Ph, Pw] + q = q.permute(0, 1, 3, 5, 2, 4, 6) # [b, n, Qh, Qw, d, Ph, Pw] k = k.permute(0, 1, 3, 5, 2, 4, 6) v = v.permute(0, 1, 3, 5, 2, 4, 6) - output, attn = softmax_attention(q, k, v) # attn: [b, n, Qh, Qw, qPh*qPw, kPh*kPw] + output, attn = softmax_attention(q, k, v) # attn: [b, n, Qh, Qw, qPh*qPw, kPh*kPw] output = output.permute(0, 1, 4, 2, 5, 3, 6) output = output.reshape(B, N, -1, qH, qW) return output, attn @@ -188,7 +192,7 @@ def space_to_depth(x, block_size): if len(x.shape) >= 5: x = x.view(-1, c, h, w) unfolded_x = torch.nn.functional.unfold(x, block_size, stride=block_size) - return unfolded_x.view(*x_shape[0:-3], c * block_size ** 2, h // block_size, w // block_size) + return unfolded_x.view(*x_shape[0:-3], c * block_size**2, h // block_size, w // block_size) def depth_to_space(x, block_size): @@ -196,17 +200,17 @@ def depth_to_space(x, block_size): c, h, w = x_shape[-3:] x = x.view(-1, c, h, w) y = torch.nn.functional.pixel_shuffle(x, block_size) - return y.view(*x_shape[0:-3], -1, h*block_size, w*block_size) + return y.view(*x_shape[0:-3], -1, h * block_size, w * block_size) def patch_attention(q, k, v, P): # q: [b, nhead, c, h, w] - q_patch = space_to_depth(q, P) # [b, nhead, cP^2, h/P, w/P] + q_patch = space_to_depth(q, P) # [b, nhead, cP^2, h/P, w/P] k_patch = space_to_depth(k, P) v_patch = space_to_depth(v, P) # output: [b, nhead, cP^2, h/P, w/P] # attn: [b, nhead, h/P*w/P, h/P*w/P] - output, attn = softmax_attention(q_patch, k_patch, v_patch) - output = depth_to_space(output, P) # output: [b, nhead, c, h, w] + output, attn = softmax_attention(q_patch, k_patch, v_patch) + output = depth_to_space(output, P) # output: [b, nhead, c, h, w] return output, attn diff --git a/lib/pymafx/models/transformers/texformer.py 
b/lib/pymafx/models/transformers/texformer.py index 4554ce9629f6804e6df82c097f0774905fe332bf..4266b24ed6839f91ce5ca819cc3750143387d48f 100644 --- a/lib/pymafx/models/transformers/texformer.py +++ b/lib/pymafx/models/transformers/texformer.py @@ -2,6 +2,7 @@ import torch.nn as nn from .net_utils import single_conv, double_conv, double_conv_down, double_conv_up, PosEnSine from .transformer_basics import OurMultiheadAttention + class TransformerDecoderUnit(nn.Module): def __init__(self, feat_dim, n_head=8, pos_en_flag=True, attn_type='softmax', P=None): super(TransformerDecoderUnit, self).__init__() @@ -11,8 +12,8 @@ class TransformerDecoderUnit(nn.Module): self.P = P self.pos_en = PosEnSine(self.feat_dim // 2) - self.attn = OurMultiheadAttention(feat_dim, n_head) # cross-attention - + self.attn = OurMultiheadAttention(feat_dim, n_head) # cross-attention + self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) self.activation = nn.ReLU(inplace=True) @@ -28,7 +29,9 @@ class TransformerDecoderUnit(nn.Module): k_pos_embed = 0 # cross-multi-head attention - out = self.attn(q=q+q_pos_embed, k=k+k_pos_embed, v=v, attn_type=self.attn_type, P=self.P)[0] + out = self.attn( + q=q + q_pos_embed, k=k + k_pos_embed, v=v, attn_type=self.attn_type, P=self.P + )[0] # feed forward out2 = self.linear2(self.activation(self.linear1(out))) @@ -52,17 +55,17 @@ class Unet(nn.Module): def forward(self, x): feat0 = self.conv_in(x) # H - feat1 = self.conv1(feat0) # H/2 + feat1 = self.conv1(feat0) # H/2 feat2 = self.conv2(feat1) # H/4 feat3 = self.conv3(feat2) # H/4 - feat3 = feat3 + feat2 # H/4 + feat3 = feat3 + feat2 # H/4 feat4 = self.conv4(feat3) # H/2 feat4 = feat4 + feat1 # H/2 - feat5 = self.conv5(feat4) # H - feat5 = feat5 + feat0 # H + feat5 = self.conv5(feat4) # H + feat5 = feat5 + feat0 # H feat6 = self.conv6(feat5) - return feat0, feat1, feat2, feat3, feat4, feat6 + return feat0, feat1, feat2, feat3, feat4, feat6 class Texformer(nn.Module): @@ -77,18 +80,20 @@ class Texformer(nn.Module): if not self.mask_fusion: v_ch = out_ch else: - v_ch = 2+3 + v_ch = 2 + 3 self.unet_q = Unet(tgt_ch, self.feat_dim, self.feat_dim) self.unet_k = Unet(src_ch, self.feat_dim, self.feat_dim) self.unet_v = Unet(v_ch, self.feat_dim, self.feat_dim) - self.trans_dec = nn.ModuleList([None, - None, - None, - TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'softmax'), - TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct'), - TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct')]) + self.trans_dec = nn.ModuleList( + [ + None, None, None, + TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'softmax'), + TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct'), + TransformerDecoderUnit(self.feat_dim, opts.nhead, True, 'dotproduct') + ] + ) self.conv0 = double_conv(self.feat_dim, self.feat_dim) self.conv1 = double_conv_down(self.feat_dim, self.feat_dim) @@ -96,13 +101,17 @@ class Texformer(nn.Module): self.conv3 = double_conv(self.feat_dim, self.feat_dim) self.conv4 = double_conv_up(self.feat_dim, self.feat_dim) self.conv5 = double_conv_up(self.feat_dim, self.feat_dim) - + if not self.mask_fusion: - self.conv6 = nn.Sequential(single_conv(self.feat_dim, self.feat_dim), - nn.Conv2d(self.feat_dim, out_ch, 3, 1, 1)) + self.conv6 = nn.Sequential( + single_conv(self.feat_dim, self.feat_dim), + nn.Conv2d(self.feat_dim, out_ch, 3, 1, 1) + ) else: - self.conv6 = nn.Sequential(single_conv(self.feat_dim, self.feat_dim), - 
nn.Conv2d(self.feat_dim, 2+3+1, 3, 1, 1)) # mask*flow-sampling + (1-mask)*rgb + self.conv6 = nn.Sequential( + single_conv(self.feat_dim, self.feat_dim), + nn.Conv2d(self.feat_dim, 2 + 3 + 1, 3, 1, 1) + ) # mask*flow-sampling + (1-mask)*rgb self.sigmoid = nn.Sigmoid() self.tanh = nn.Tanh() @@ -120,16 +129,16 @@ class Texformer(nn.Module): outputs.append(self.trans_dec[i](q_feat[i], k_feat[i], v_feat[i])) print('outputs', outputs[-1].shape) - f0 = self.conv0(outputs[2]) # H - f1 = self.conv1(f0) # H/2 + f0 = self.conv0(outputs[2]) # H + f1 = self.conv1(f0) # H/2 f1 = f1 + outputs[1] - f2 = self.conv2(f1) # H/4 + f2 = self.conv2(f1) # H/4 f2 = f2 + outputs[0] - f3 = self.conv3(f2) # H/4 - f3 = f3 + outputs[0] + f2 - f4 = self.conv4(f3) # H/2 + f3 = self.conv3(f2) # H/4 + f3 = f3 + outputs[0] + f2 + f4 = self.conv4(f3) # H/2 f4 = f4 + outputs[1] + f1 - f5 = self.conv5(f4) # H + f5 = self.conv5(f4) # H f5 = f5 + outputs[2] + f0 if not self.mask_fusion: out = self.tanh(self.conv6(f5)) @@ -137,4 +146,3 @@ class Texformer(nn.Module): out_ = self.conv6(f5) out = [self.tanh(out_[:, :2]), self.tanh(out_[:, 2:5]), self.sigmoid(out_[:, 5:])] return out - diff --git a/lib/pymafx/models/transformers/tokenlearner.py b/lib/pymafx/models/transformers/tokenlearner.py index 5127fa57e7350daac11ed4e0fde34748eddbbd1f..441b361a721f685f481e764c19b624b593124c1b 100644 --- a/lib/pymafx/models/transformers/tokenlearner.py +++ b/lib/pymafx/models/transformers/tokenlearner.py @@ -2,44 +2,45 @@ import torch import torch.nn as nn import torch.nn.functional as F + class SpatialAttention(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Sequential( - nn.Conv2d(2, 1, kernel_size=(1,1), stride=1), - nn.BatchNorm2d(1), - nn.ReLU() + nn.Conv2d(2, 1, kernel_size=(1, 1), stride=1), nn.BatchNorm2d(1), nn.ReLU() ) - + self.sgap = nn.AvgPool2d(2) def forward(self, x): B, H, W, C = x.shape x = x.reshape(B, C, H, W) - + mx = torch.max(x, 1)[0].unsqueeze(1) avg = torch.mean(x, 1).unsqueeze(1) combined = torch.cat([mx, avg], dim=1) fmap = self.conv(combined) weight_map = torch.sigmoid(fmap) out = (x * weight_map).mean(dim=(-2, -1)) - + return out, x * weight_map + class TokenLearner(nn.Module): def __init__(self, S) -> None: super().__init__() self.S = S self.tokenizers = nn.ModuleList([SpatialAttention() for _ in range(S)]) - + def forward(self, x): B, _, _, C = x.shape Z = torch.Tensor(B, self.S, C).to(x) for i in range(self.S): - Ai, _ = self.tokenizers[i](x) # [B, C] + Ai, _ = self.tokenizers[i](x) # [B, C] Z[:, i, :] = Ai return Z + class TokenFuser(nn.Module): def __init__(self, H, W, C, S) -> None: super().__init__() @@ -47,18 +48,18 @@ class TokenFuser(nn.Module): self.Bi = nn.Linear(C, S) self.spatial_attn = SpatialAttention() self.S = S - + def forward(self, y, x): B, S, C = y.shape B, H, W, C = x.shape - + Y = self.projection(y.reshape(B, C, S)).reshape(B, S, C) - Bw = torch.sigmoid(self.Bi(x)).reshape(B, H*W, S) # [B, HW, S] + Bw = torch.sigmoid(self.Bi(x)).reshape(B, H * W, S) # [B, HW, S] BwY = torch.matmul(Bw, Y) - + _, xj = self.spatial_attn(x) - xj = xj.reshape(B, H*W, C) + xj = xj.reshape(B, H * W, C) out = (BwY + xj).reshape(B, H, W, C) - - return out \ No newline at end of file + + return out diff --git a/lib/pymafx/models/transformers/transformer_basics.py b/lib/pymafx/models/transformers/transformer_basics.py index 05c26c1639b5e5a0e4f68d782a08041024d302ff..144ccd76b7e2f73189634ab551691c4262781b9d 100644 --- a/lib/pymafx/models/transformers/transformer_basics.py +++ 
b/lib/pymafx/models/transformers/transformer_basics.py @@ -35,7 +35,7 @@ class OurMultiheadAttention(nn.Module): # -------------- Attention ----------------- if attn_type == 'softmax': - q, attn = softmax_attention(q, k, v) # b x n x dk x h x w --> b x n x dv x h x w + q, attn = softmax_attention(q, k, v) # b x n x dk x h x w --> b x n x dv x h x w elif attn_type == 'dotproduct': q, attn = dotproduct_attention(q, k, v) elif attn_type == 'patch': @@ -50,7 +50,7 @@ class OurMultiheadAttention(nn.Module): # Concatenate all the heads together: b x (n*dv) x h x w q = q.reshape(q.shape[0], -1, q.shape[3], q.shape[4]) - q = self.fc(q) # b x d x h x w + q = self.fc(q) # b x d x h x w return q, attn @@ -65,22 +65,24 @@ class TransformerEncoderUnit(nn.Module): self.pos_en = PosEnSine(self.feat_dim // 2) self.attn = OurMultiheadAttention(feat_dim, n_head) - + self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) self.activation = nn.ReLU(inplace=True) self.norm1 = nn.BatchNorm2d(self.feat_dim) - self.norm2 = nn.BatchNorm2d(self.feat_dim) + self.norm2 = nn.BatchNorm2d(self.feat_dim) def forward(self, src): if self.pos_en_flag: pos_embed = self.pos_en(src) else: pos_embed = 0 - + # multi-head attention - src2 = self.attn(q=src+pos_embed, k=src+pos_embed, v=src, attn_type=self.attn_type, P=self.P)[0] + src2 = self.attn( + q=src + pos_embed, k=src + pos_embed, v=src, attn_type=self.attn_type, P=self.P + )[0] src = src + src2 src = self.norm1(src) @@ -102,26 +104,40 @@ class TransformerEncoderUnitSparse(nn.Module): self.pos_en = PosEnSine(self.feat_dim // 2) self.attn1 = OurMultiheadAttention(feat_dim, n_head) # long range self.attn2 = OurMultiheadAttention(feat_dim, n_head) # short range - + self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) self.activation = nn.ReLU(inplace=True) self.norm1 = nn.BatchNorm2d(self.feat_dim) - self.norm2 = nn.BatchNorm2d(self.feat_dim) + self.norm2 = nn.BatchNorm2d(self.feat_dim) def forward(self, src): if self.pos_en_flag: pos_embed = self.pos_en(src) else: pos_embed = 0 - + # multi-head long-range attention - src2 = self.attn1(q=src+pos_embed, k=src+pos_embed, v=src, attn_type='sparse_long', ah=self.ahw[0], aw=self.ahw[1])[0] + src2 = self.attn1( + q=src + pos_embed, + k=src + pos_embed, + v=src, + attn_type='sparse_long', + ah=self.ahw[0], + aw=self.ahw[1] + )[0] src = src + src2 # ? 
this might be ok to remove - + # multi-head short-range attention - src2 = self.attn2(q=src+pos_embed, k=src+pos_embed, v=src, attn_type='sparse_short', ah=self.ahw[2], aw=self.ahw[3])[0] + src2 = self.attn2( + q=src + pos_embed, + k=src + pos_embed, + v=src, + attn_type='sparse_short', + ah=self.ahw[2], + aw=self.ahw[3] + )[0] src = src + src2 src = self.norm1(src) @@ -142,16 +158,16 @@ class TransformerDecoderUnit(nn.Module): self.P = P self.pos_en = PosEnSine(self.feat_dim // 2) - self.attn1 = OurMultiheadAttention(feat_dim, n_head) # self-attention - self.attn2 = OurMultiheadAttention(feat_dim, n_head) # cross-attention - + self.attn1 = OurMultiheadAttention(feat_dim, n_head) # self-attention + self.attn2 = OurMultiheadAttention(feat_dim, n_head) # cross-attention + self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) self.activation = nn.ReLU(inplace=True) self.norm1 = nn.BatchNorm2d(self.feat_dim) - self.norm2 = nn.BatchNorm2d(self.feat_dim) - self.norm3 = nn.BatchNorm2d(self.feat_dim) + self.norm2 = nn.BatchNorm2d(self.feat_dim) + self.norm3 = nn.BatchNorm2d(self.feat_dim) def forward(self, tgt, src): if self.pos_en_flag: @@ -160,14 +176,18 @@ class TransformerDecoderUnit(nn.Module): else: src_pos_embed = 0 tgt_pos_embed = 0 - + # self-multi-head attention - tgt2 = self.attn1(q=tgt+tgt_pos_embed, k=tgt+tgt_pos_embed, v=tgt, attn_type=self.attn_type, P=self.P)[0] + tgt2 = self.attn1( + q=tgt + tgt_pos_embed, k=tgt + tgt_pos_embed, v=tgt, attn_type=self.attn_type, P=self.P + )[0] tgt = tgt + tgt2 tgt = self.norm1(tgt) # cross-multi-head attention - tgt2 = self.attn2(q=tgt+tgt_pos_embed, k=src+src_pos_embed, v=src, attn_type=self.attn_type, P=self.P)[0] + tgt2 = self.attn2( + q=tgt + tgt_pos_embed, k=src + src_pos_embed, v=src, attn_type=self.attn_type, P=self.P + )[0] tgt = tgt + tgt2 tgt = self.norm2(tgt) @@ -183,23 +203,25 @@ class TransformerDecoderUnitSparse(nn.Module): def __init__(self, feat_dim, n_head=8, pos_en_flag=True, ahw=None): super(TransformerDecoderUnitSparse, self).__init__() self.feat_dim = feat_dim - self.ahw = ahw # [Ph_tgt, Pw_tgt, Qh_tgt, Qw_tgt, Ph_src, Pw_src, Qh_tgt, Qw_tgt] + self.ahw = ahw # [Ph_tgt, Pw_tgt, Qh_tgt, Qw_tgt, Ph_src, Pw_src, Qh_tgt, Qw_tgt] self.pos_en_flag = pos_en_flag self.pos_en = PosEnSine(self.feat_dim // 2) - self.attn1_1 = OurMultiheadAttention(feat_dim, n_head) # self-attention: long - self.attn1_2 = OurMultiheadAttention(feat_dim, n_head) # self-attention: short + self.attn1_1 = OurMultiheadAttention(feat_dim, n_head) # self-attention: long + self.attn1_2 = OurMultiheadAttention(feat_dim, n_head) # self-attention: short - self.attn2_1 = OurMultiheadAttention(feat_dim, n_head) # cross-attention: self-attention-long + cross-attention-short + self.attn2_1 = OurMultiheadAttention( + feat_dim, n_head + ) # cross-attention: self-attention-long + cross-attention-short self.attn2_2 = OurMultiheadAttention(feat_dim, n_head) - + self.linear1 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) self.linear2 = nn.Conv2d(self.feat_dim, self.feat_dim, 1) self.activation = nn.ReLU(inplace=True) self.norm1 = nn.BatchNorm2d(self.feat_dim) - self.norm2 = nn.BatchNorm2d(self.feat_dim) - self.norm3 = nn.BatchNorm2d(self.feat_dim) + self.norm2 = nn.BatchNorm2d(self.feat_dim) + self.norm3 = nn.BatchNorm2d(self.feat_dim) def forward(self, tgt, src): if self.pos_en_flag: @@ -208,20 +230,48 @@ class TransformerDecoderUnitSparse(nn.Module): else: src_pos_embed = 0 tgt_pos_embed = 0 - + # self-multi-head 
attention: sparse long - tgt2 = self.attn1_1(q=tgt+tgt_pos_embed, k=tgt+tgt_pos_embed, v=tgt, attn_type='sparse_long', ah=self.ahw[0], aw=self.ahw[1])[0] + tgt2 = self.attn1_1( + q=tgt + tgt_pos_embed, + k=tgt + tgt_pos_embed, + v=tgt, + attn_type='sparse_long', + ah=self.ahw[0], + aw=self.ahw[1] + )[0] tgt = tgt + tgt2 # self-multi-head attention: sparse short - tgt2 = self.attn1_2(q=tgt+tgt_pos_embed, k=tgt+tgt_pos_embed, v=tgt, attn_type='sparse_short', ah=self.ahw[2], aw=self.ahw[3])[0] + tgt2 = self.attn1_2( + q=tgt + tgt_pos_embed, + k=tgt + tgt_pos_embed, + v=tgt, + attn_type='sparse_short', + ah=self.ahw[2], + aw=self.ahw[3] + )[0] tgt = tgt + tgt2 tgt = self.norm1(tgt) # self-multi-head attention: sparse long - src2 = self.attn2_1(q=src+src_pos_embed, k=src+src_pos_embed, v=src, attn_type='sparse_long', ah=self.ahw[4], aw=self.ahw[5])[0] + src2 = self.attn2_1( + q=src + src_pos_embed, + k=src + src_pos_embed, + v=src, + attn_type='sparse_long', + ah=self.ahw[4], + aw=self.ahw[5] + )[0] src = src + src2 # cross-multi-head attention: sparse short - tgt2 = self.attn2_2(q=tgt+tgt_pos_embed, k=src+src_pos_embed, v=src, attn_type='sparse_short', ah=self.ahw[6], aw=self.ahw[7])[0] + tgt2 = self.attn2_2( + q=tgt + tgt_pos_embed, + k=src + src_pos_embed, + v=src, + attn_type='sparse_short', + ah=self.ahw[6], + aw=self.ahw[7] + )[0] tgt = tgt + tgt2 tgt = self.norm2(tgt) @@ -231,4 +281,3 @@ class TransformerDecoderUnitSparse(nn.Module): tgt = self.norm3(tgt) return tgt - diff --git a/lib/pymafx/utils/binvox_rw.py b/lib/pymafx/utils/binvox_rw.py index c9c11d6992827ca2132a87599f2042867f77a455..947c3258691da908954f765bde07e0978cfb9f97 100644 --- a/lib/pymafx/utils/binvox_rw.py +++ b/lib/pymafx/utils/binvox_rw.py @@ -16,7 +16,6 @@ # # Modified by Christopher B. Choy # for python 3 support - """ Binvox to Numpy and back. @@ -65,6 +64,7 @@ True import numpy as np + class Voxels(object): """ Holds a binvox model. data is either a three-dimensional numpy boolean array (dense representation) @@ -86,7 +86,6 @@ class Voxels(object): z = scale*z_n + translate[2] """ - def __init__(self, data, dims, translate, scale, axis_order): self.data = data self.dims = dims @@ -104,6 +103,7 @@ class Voxels(object): def write(self, fp): write(self, fp) + def read_header(fp): """ Read binvox header. Mostly meant for internal use. """ @@ -116,6 +116,7 @@ def read_header(fp): line = fp.readline() return dims, translate, scale + def read_as_3d_array(fp, fix_coords=True): """ Read binary binvox format as array. @@ -189,8 +190,8 @@ def read_as_coord_array(fp, fix_coords=True): # according to docs, # index = x * wxh + z * width + y; // wxh = width * height = d * d - x = nz_voxels / (dims[0]*dims[1]) - zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y + x = nz_voxels / (dims[0] * dims[1]) + zwpy = nz_voxels % (dims[0] * dims[1]) # z*w + y z = zwpy / dims[0] y = zwpy % dims[0] if fix_coords: @@ -203,34 +204,38 @@ def read_as_coord_array(fp, fix_coords=True): #return Voxels(data, dims, translate, scale, axis_order) return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order) + def dense_to_sparse(voxel_data, dtype=np.int): """ From dense representation to sparse (coordinate) representation. No coordinate reordering. 
""" - if voxel_data.ndim!=3: + if voxel_data.ndim != 3: raise ValueError('voxel_data is wrong shape; should be 3D array.') return np.asarray(np.nonzero(voxel_data), dtype) + def sparse_to_dense(voxel_data, dims, dtype=np.bool): - if voxel_data.ndim!=2 or voxel_data.shape[0]!=3: + if voxel_data.ndim != 2 or voxel_data.shape[0] != 3: raise ValueError('voxel_data is wrong shape; should be 3xN array.') if np.isscalar(dims): - dims = [dims]*3 + dims = [dims] * 3 dims = np.atleast_2d(dims).T # truncate to integers xyz = voxel_data.astype(np.int) # discard voxels that fall outside dims valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0) - xyz = xyz[:,valid_ix] + xyz = xyz[:, valid_ix] out = np.zeros(dims.flatten(), dtype=dtype) out[tuple(xyz)] = True return out + #def get_linear_index(x, y, z, dims): - #""" Assuming xzy order. (y increasing fastest. - #TODO ensure this is right when dims are not all same - #""" - #return x*(dims[1]*dims[2]) + z*dims[1] + y +#""" Assuming xzy order. (y increasing fastest. +#TODO ensure this is right when dims are not all same +#""" +#return x*(dims[1]*dims[2]) + z*dims[1] + y + def write(voxel_model, fp): """ Write binary binvox format. @@ -241,33 +246,33 @@ def write(voxel_model, fp): Doesn't check if the model is 'sane'. """ - if voxel_model.data.ndim==2: + if voxel_model.data.ndim == 2: # TODO avoid conversion to dense dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims) else: dense_voxel_data = voxel_model.data fp.write('#binvox 1\n') - fp.write('dim '+' '.join(map(str, voxel_model.dims))+'\n') - fp.write('translate '+' '.join(map(str, voxel_model.translate))+'\n') - fp.write('scale '+str(voxel_model.scale)+'\n') + fp.write('dim ' + ' '.join(map(str, voxel_model.dims)) + '\n') + fp.write('translate ' + ' '.join(map(str, voxel_model.translate)) + '\n') + fp.write('scale ' + str(voxel_model.scale) + '\n') fp.write('data\n') if not voxel_model.axis_order in ('xzy', 'xyz'): raise ValueError('Unsupported voxel model axis order') - if voxel_model.axis_order=='xzy': + if voxel_model.axis_order == 'xzy': voxels_flat = dense_voxel_data.flatten() - elif voxel_model.axis_order=='xyz': + elif voxel_model.axis_order == 'xyz': voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten() # keep a sort of state machine for writing run length encoding state = voxels_flat[0] ctr = 0 for c in voxels_flat: - if c==state: + if c == state: ctr += 1 # if ctr hits max, dump - if ctr==255: + if ctr == 255: fp.write(chr(state)) fp.write(chr(ctr)) ctr = 0 @@ -282,6 +287,7 @@ def write(voxel_model, fp): fp.write(chr(state)) fp.write(chr(ctr)) + if __name__ == '__main__': import doctest doctest.testmod() diff --git a/lib/pymafx/utils/blob.py b/lib/pymafx/utils/blob.py index 0d989f13c139abe4905280579c083b93b92e68d8..00123338e18a3fa74a6c3cb730cac9fb41b59ac5 100644 --- a/lib/pymafx/utils/blob.py +++ b/lib/pymafx/utils/blob.py @@ -45,9 +45,7 @@ def get_image_blob(im, target_scale, target_max_size): im_scale (float): image scale (target size) / (original size) im_info (ndarray) """ - processed_im, im_scale = prep_im_for_blob( - im, cfg.PIXEL_MEANS, [target_scale], target_max_size - ) + processed_im, im_scale = prep_im_for_blob(im, cfg.PIXEL_MEANS, [target_scale], target_max_size) blob = im_list_to_blob(processed_im) # NOTE: this height and width may be larger than actual scaled input image # due to the FPN.COARSEST_STRIDE related padding in im_list_to_blob. 
We are @@ -76,8 +74,7 @@ def im_list_to_blob(ims): max_shape = get_max_shape([im.shape[:2] for im in ims]) num_images = len(ims) - blob = np.zeros( - (num_images, max_shape[0], max_shape[1], 3), dtype=np.float32) + blob = np.zeros((num_images, max_shape[0], max_shape[1], 3), dtype=np.float32) for i in range(num_images): im = ims[i] blob[i, 0:im.shape[0], 0:im.shape[1], :] = im @@ -119,8 +116,9 @@ def prep_im_for_blob(im, pixel_means, target_sizes, max_size): im_scales = [] for target_size in target_sizes: im_scale = get_target_scale(im_size_min, im_size_max, target_size, max_size) - im_resized = cv2.resize(im, None, None, fx=im_scale, fy=im_scale, - interpolation=cv2.INTER_LINEAR) + im_resized = cv2.resize( + im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR + ) ims.append(im_resized) im_scales.append(im_scale) return ims, im_scales diff --git a/lib/pymafx/utils/cam_params.py b/lib/pymafx/utils/cam_params.py index 7cf877bbf0de1f952d1be3efe200439712b289b5..1f6c1a8d89b2c80d72c90c841d02425df77aa4a5 100644 --- a/lib/pymafx/utils/cam_params.py +++ b/lib/pymafx/utils/cam_params.py @@ -22,6 +22,7 @@ import joblib from .geometry import batch_euler2matrix + def f_pix2vfov(f_pix, img_h): if torch.is_tensor(f_pix): @@ -31,6 +32,7 @@ def f_pix2vfov(f_pix, img_h): return fov + def vfov2f_pix(fov, img_h): if torch.is_tensor(fov): @@ -40,6 +42,7 @@ def vfov2f_pix(fov, img_h): return f_pix + def read_cam_params(cam_params, orig_shape=None): # These are predicted camera parameters # cam_param_folder = CAM_PARAM_FOLDERS[dataset_name][cam_param_type] @@ -69,6 +72,7 @@ def read_cam_params(cam_params, orig_shape=None): return cam_rotmat, cam_int, cam_vfov, cam_pitch, cam_roll, cam_focal_length + def homo_vector(vector): """ vector: B x N x C diff --git a/lib/pymafx/utils/collections.py b/lib/pymafx/utils/collections.py index 465c9df196d762430d2318fc30e85ccd107b8b84..edd20a8c89d5d2221dc9d35948eda12c6304ba29 100644 --- a/lib/pymafx/utils/collections.py +++ b/lib/pymafx/utils/collections.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. ############################################################################## - """A simple attribute dictionary used for representing configuration options.""" from __future__ import absolute_import @@ -45,8 +44,7 @@ class AttrDict(dict): self[name] = value else: raise AttributeError( - 'Attempted to set "{}" to "{}", but AttrDict is immutable'. - format(name, value) + 'Attempted to set "{}" to "{}", but AttrDict is immutable'.format(name, value) ) def immutable(self, is_immutable): diff --git a/lib/pymafx/utils/colormap.py b/lib/pymafx/utils/colormap.py index bc6869f289a9c47519ca69bdddba3dd4fa82ea27..44ef28c050021a6f03d088e9437de0c4adeb5ee5 100644 --- a/lib/pymafx/utils/colormap.py +++ b/lib/pymafx/utils/colormap.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
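# Illustrative sketch (not part of the patch): the focal-length <-> vertical-FOV conversion
# exposed by f_pix2vfov / vfov2f_pix in cam_params.py above. The function bodies are not
# visible in this hunk; the standard pinhole relation vfov = 2 * atan(img_h / (2 * f_pix))
# is assumed here.
import math

img_h, f_pix = 512, 500.0
vfov = 2.0 * math.atan(img_h / (2.0 * f_pix))    # ~0.95 rad (~54 deg)
f_back = img_h / (2.0 * math.tan(vfov / 2.0))    # recovers 500.0
assert abs(f_back - f_pix) < 1e-6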
############################################################################## - """An awesome colormap for really neat visualizations.""" from __future__ import absolute_import @@ -26,85 +25,26 @@ import numpy as np def colormap(rgb=False): color_list = np.array( [ - 0.000, 0.447, 0.741, - 0.850, 0.325, 0.098, - 0.929, 0.694, 0.125, - 0.494, 0.184, 0.556, - 0.466, 0.674, 0.188, - 0.301, 0.745, 0.933, - 0.635, 0.078, 0.184, - 0.300, 0.300, 0.300, - 0.600, 0.600, 0.600, - 1.000, 0.000, 0.000, - 1.000, 0.500, 0.000, - 0.749, 0.749, 0.000, - 0.000, 1.000, 0.000, - 0.000, 0.000, 1.000, - 0.667, 0.000, 1.000, - 0.333, 0.333, 0.000, - 0.333, 0.667, 0.000, - 0.333, 1.000, 0.000, - 0.667, 0.333, 0.000, - 0.667, 0.667, 0.000, - 0.667, 1.000, 0.000, - 1.000, 0.333, 0.000, - 1.000, 0.667, 0.000, - 1.000, 1.000, 0.000, - 0.000, 0.333, 0.500, - 0.000, 0.667, 0.500, - 0.000, 1.000, 0.500, - 0.333, 0.000, 0.500, - 0.333, 0.333, 0.500, - 0.333, 0.667, 0.500, - 0.333, 1.000, 0.500, - 0.667, 0.000, 0.500, - 0.667, 0.333, 0.500, - 0.667, 0.667, 0.500, - 0.667, 1.000, 0.500, - 1.000, 0.000, 0.500, - 1.000, 0.333, 0.500, - 1.000, 0.667, 0.500, - 1.000, 1.000, 0.500, - 0.000, 0.333, 1.000, - 0.000, 0.667, 1.000, - 0.000, 1.000, 1.000, - 0.333, 0.000, 1.000, - 0.333, 0.333, 1.000, - 0.333, 0.667, 1.000, - 0.333, 1.000, 1.000, - 0.667, 0.000, 1.000, - 0.667, 0.333, 1.000, - 0.667, 0.667, 1.000, - 0.667, 1.000, 1.000, - 1.000, 0.000, 1.000, - 1.000, 0.333, 1.000, - 1.000, 0.667, 1.000, - 0.167, 0.000, 0.000, - 0.333, 0.000, 0.000, - 0.500, 0.000, 0.000, - 0.667, 0.000, 0.000, - 0.833, 0.000, 0.000, - 1.000, 0.000, 0.000, - 0.000, 0.167, 0.000, - 0.000, 0.333, 0.000, - 0.000, 0.500, 0.000, - 0.000, 0.667, 0.000, - 0.000, 0.833, 0.000, - 0.000, 1.000, 0.000, - 0.000, 0.000, 0.167, - 0.000, 0.000, 0.333, - 0.000, 0.000, 0.500, - 0.000, 0.000, 0.667, - 0.000, 0.000, 0.833, - 0.000, 0.000, 1.000, - 0.000, 0.000, 0.000, - 0.143, 0.143, 0.143, - 0.286, 0.286, 0.286, - 0.429, 0.429, 0.429, - 0.571, 0.571, 0.571, - 0.714, 0.714, 0.714, - 0.857, 0.857, 0.857, - 1.000, 1.000, 1.000 + 0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078, 0.184, 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, 1.000, 0.000, 0.000, 1.000, 0.500, 0.000, 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, 0.000, 0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, 0.333, 1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, 1.000, 0.333, 0.000, 1.000, 0.667, 0.000, 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, 0.000, 0.667, 0.500, 0.000, 1.000, 0.500, 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, 0.333, 0.667, 0.500, 0.333, 1.000, 0.500, 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, 0.667, 0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000, 1.000, 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, 0.667, 1.000, 1.000, 1.000, 0.000, 1.000, 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, 0.167, 0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000, 0.333, 0.000, 0.000, 0.500, + 0.000, 
0.000, 0.667, 0.000, 0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, + 0.143, 0.143, 0.143, 0.286, 0.286, 0.286, 0.429, 0.429, 0.429, 0.571, 0.571, 0.571, + 0.714, 0.714, 0.714, 0.857, 0.857, 0.857, 1.000, 1.000, 1.000 ] ).astype(np.float32) color_list = color_list.reshape((-1, 3)) * 255 diff --git a/lib/pymafx/utils/common.py b/lib/pymafx/utils/common.py index ee29cc1c3ed765261c265e15b3959ebecb439857..f3330ea18c4783ccacb21657808b8b8ce2301f86 100755 --- a/lib/pymafx/utils/common.py +++ b/lib/pymafx/utils/common.py @@ -4,7 +4,6 @@ import logging from copy import deepcopy from .utils.libkdtree import KDTree - logger_py = logging.getLogger(__name__) @@ -37,6 +36,7 @@ def compute_iou(occ1, occ2): return iou + def rgb2gray(rgb): ''' rgb of size B x h x w x 3 ''' @@ -46,8 +46,9 @@ def rgb2gray(rgb): return gray -def sample_patch_points(batch_size, n_points, patch_size=1, - image_resolution=(128, 128), continuous=True): +def sample_patch_points( + batch_size, n_points, patch_size=1, image_resolution=(128, 128), continuous=True +): ''' Returns sampled points in the range [-1, 1]. Args: @@ -60,21 +61,21 @@ def sample_patch_points(batch_size, n_points, patch_size=1, continuous (bool): whether to sample continuously or only on pixel locations ''' - assert(patch_size > 0) + assert (patch_size > 0) # Calculate step size for [-1, 1] that is equivalent to a pixel in # original resolution h_step = 1. / image_resolution[0] w_step = 1. / image_resolution[1] # Get number of patches - patch_size_squared = patch_size ** 2 + patch_size_squared = patch_size**2 n_patches = int(n_points / patch_size_squared) if continuous: - p = torch.rand(batch_size, n_patches, 2) # [0, 1] + p = torch.rand(batch_size, n_patches, 2) # [0, 1] else: - px = torch.randint(0, image_resolution[1], size=( - batch_size, n_patches, 1)).float() / (image_resolution[1] - 1) - py = torch.randint(0, image_resolution[0], size=( - batch_size, n_patches, 1)).float() / (image_resolution[0] - 1) + px = torch.randint(0, image_resolution[1], + size=(batch_size, n_patches, 1)).float() / (image_resolution[1] - 1) + py = torch.randint(0, image_resolution[0], + size=(batch_size, n_patches, 1)).float() / (image_resolution[0] - 1) p = torch.cat([px, py], dim=-1) # Scale p to [0, (1 - (patch_size - 1) * step) ] p[:, :, 0] *= 1 - (patch_size - 1) * w_step @@ -83,9 +84,8 @@ def sample_patch_points(batch_size, n_points, patch_size=1, # Add points patch_arange = torch.arange(patch_size) x_offset, y_offset = torch.meshgrid(patch_arange, patch_arange) - patch_offsets = torch.stack( - [x_offset.reshape(-1), y_offset.reshape(-1)], - dim=1).view(1, 1, -1, 2).repeat(batch_size, n_patches, 1, 1).float() + patch_offsets = torch.stack([x_offset.reshape(-1), y_offset.reshape(-1)], + dim=1).view(1, 1, -1, 2).repeat(batch_size, n_patches, 1, 1).float() patch_offsets[:, :, :, 0] *= w_step patch_offsets[:, :, :, 1] *= h_step @@ -99,13 +99,12 @@ def sample_patch_points(batch_size, n_points, patch_size=1, p = p.view(batch_size, -1, 2) amax, amin = p.max(), p.min() - assert(amax <= 1. and amin >= -1.) + assert (amax <= 1. and amin >= -1.) return p -def get_proposal_points_in_unit_cube(ray0, ray_direction, padding=0.1, - eps=1e-6, n_steps=40): +def get_proposal_points_in_unit_cube(ray0, ray_direction, padding=0.1, eps=1e-6, n_steps=40): ''' Returns n_steps equally spaced points inside the unit cube on the rays cast from ray0 with direction ray_direction. 
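# Illustrative sketch (not part of the patch): the patch bookkeeping in sample_patch_points
# above. n_points is split into n_patches square patches, and each patch contributes
# patch_size**2 offsets scaled by one pixel step of the [-1, 1] image range.
import torch

n_points, patch_size, resolution = 64, 2, (128, 128)
n_patches = n_points // patch_size ** 2               # 16 patches of 2x2 samples
h_step, w_step = 1. / resolution[0], 1. / resolution[1]

patch_arange = torch.arange(patch_size)
x_off, y_off = torch.meshgrid(patch_arange, patch_arange, indexing='ij')
offsets = torch.stack([x_off.reshape(-1), y_off.reshape(-1)], dim=1).float()
offsets[:, 0] *= w_step                               # one pixel step in x
offsets[:, 1] *= h_step                               # one pixel step in y
# offsets now holds the four corner offsets of a single 2x2 patch; indexing='ij' matches
# the default used by the call in the hunk while avoiding the newer-torch warning.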
@@ -138,8 +137,7 @@ def get_proposal_points_in_unit_cube(ray0, ray_direction, padding=0.1, return d_proposal, mask_inside_cube -def check_ray_intersection_with_unit_cube(ray0, ray_direction, padding=0.1, - eps=1e-6, scale=2.0): +def check_ray_intersection_with_unit_cube(ray0, ray_direction, padding=0.1, eps=1e-6, scale=2.0): ''' Checks if rays ray0 + d * ray_direction intersect with unit cube with padding padding. @@ -160,7 +158,7 @@ def check_ray_intersection_with_unit_cube(ray0, ray_direction, padding=0.1, # d = - / # Get points on plane p_e - p_distance = (scale * 0.5) + padding/2 + p_distance = (scale * 0.5) + padding / 2 p_e = torch.ones(batch_size, n_pts, 6).to(device) * p_distance p_e[:, :, 3:] *= -1. @@ -185,35 +183,32 @@ def check_ray_intersection_with_unit_cube(ray0, ray_direction, padding=0.1, mask_inside_cube = p_mask_inside_cube.sum(-1) == 2 # Get interval values for p's which are valid - p_intervals = p_intersect[mask_inside_cube][p_mask_inside_cube[ - mask_inside_cube]].view(-1, 2, 3) + p_intervals = p_intersect[mask_inside_cube][p_mask_inside_cube[mask_inside_cube]].view(-1, 2, 3) p_intervals_batch = torch.zeros(batch_size, n_pts, 2, 3).to(device) p_intervals_batch[mask_inside_cube] = p_intervals # Calculate ray lengths for the interval points d_intervals_batch = torch.zeros(batch_size, n_pts, 2).to(device) norm_ray = torch.norm(ray_direction[mask_inside_cube], dim=-1) - d_intervals_batch[mask_inside_cube] = torch.stack([ - torch.norm(p_intervals[:, 0] - - ray0[mask_inside_cube], dim=-1) / norm_ray, - torch.norm(p_intervals[:, 1] - - ray0[mask_inside_cube], dim=-1) / norm_ray, - ], dim=-1) + d_intervals_batch[mask_inside_cube] = torch.stack( + [ + torch.norm(p_intervals[:, 0] - ray0[mask_inside_cube], dim=-1) / norm_ray, + torch.norm(p_intervals[:, 1] - ray0[mask_inside_cube], dim=-1) / norm_ray, + ], + dim=-1 + ) # Sort the ray lengths d_intervals_batch, indices_sort = d_intervals_batch.sort() - p_intervals_batch = p_intervals_batch[ - torch.arange(batch_size).view(-1, 1, 1), - torch.arange(n_pts).view(1, -1, 1), - indices_sort - ] + p_intervals_batch = p_intervals_batch[torch.arange(batch_size).view(-1, 1, 1), + torch.arange(n_pts).view(1, -1, 1), indices_sort] return p_intervals_batch, d_intervals_batch, mask_inside_cube def intersect_camera_rays_with_unit_cube( - pixels, camera_mat, world_mat, scale_mat, padding=0.1, eps=1e-6, - use_ray_length_as_depth=True): + pixels, camera_mat, world_mat, scale_mat, padding=0.1, eps=1e-6, use_ray_length_as_depth=True +): ''' Returns the intersection points of ray cast from camera origin to pixel points p on the image plane. 
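# Illustrative sketch (not part of the patch): the interval that
# check_ray_intersection_with_unit_cube above returns for one ray, written with the
# equivalent slab method; the padded cube spans +/- (0.5 * scale + padding / 2) per axis.
import numpy as np

def ray_cube_interval(ray0, ray_dir, padding=0.1, scale=2.0, eps=1e-6):
    half = scale * 0.5 + padding / 2.0
    t1 = (-half - ray0) / (ray_dir + eps)
    t2 = (half - ray0) / (ray_dir + eps)
    t_near, t_far = np.minimum(t1, t2).max(), np.maximum(t1, t2).min()
    return t_near <= t_far, (t_near, t_far)

hit, (d0, d1) = ray_cube_interval(np.array([0., 0., -3.]), np.array([0., 0., 1.]))
# hit is True, (d0, d1) ~ (1.95, 4.05): the ray enters and leaves the padded cube there.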
@@ -231,24 +226,22 @@ def intersect_camera_rays_with_unit_cube( ''' batch_size, n_points, _ = pixels.shape - pixel_world = image_points_to_world( - pixels, camera_mat, world_mat, scale_mat) - camera_world = origin_to_world( - n_points, camera_mat, world_mat, scale_mat) + pixel_world = image_points_to_world(pixels, camera_mat, world_mat, scale_mat) + camera_world = origin_to_world(n_points, camera_mat, world_mat, scale_mat) ray_vector = (pixel_world - camera_world) p_cube, d_cube, mask_cube = check_ray_intersection_with_unit_cube( - camera_world, ray_vector, padding=padding, eps=eps) + camera_world, ray_vector, padding=padding, eps=eps + ) if not use_ray_length_as_depth: - p_cam = transform_to_camera_space(p_cube.view( - batch_size, -1, 3), camera_mat, world_mat, scale_mat).view( - batch_size, n_points, -1, 3) + p_cam = transform_to_camera_space( + p_cube.view(batch_size, -1, 3), camera_mat, world_mat, scale_mat + ).view(batch_size, n_points, -1, 3) d_cube = p_cam[:, :, :, -1] return p_cube, d_cube, mask_cube -def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.), - subsample_to=None): +def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.), subsample_to=None): ''' Arranges pixels for given resolution in range image_range. The function returns the unscaled pixel locations as integers and the @@ -266,9 +259,8 @@ def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.), # Arrange pixel location in scale resolution pixel_locations = torch.meshgrid(torch.arange(0, w), torch.arange(0, h)) - pixel_locations = torch.stack( - [pixel_locations[0], pixel_locations[1]], - dim=-1).long().view(1, -1, 2).repeat(batch_size, 1, 1) + pixel_locations = torch.stack([pixel_locations[0], pixel_locations[1]], + dim=-1).long().view(1, -1, 2).repeat(batch_size, 1, 1) pixel_scaled = pixel_locations.clone().float() # Shift and scale points to match image_range @@ -278,10 +270,8 @@ def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.), pixel_scaled[:, :, 1] = scale * pixel_scaled[:, :, 1] / (h - 1) - loc # Subsample points if subsample_to is not None and > 0 - if (subsample_to is not None and subsample_to > 0 and - subsample_to < n_points): - idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to,), - replace=False) + if (subsample_to is not None and subsample_to > 0 and subsample_to < n_points): + idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to, ), replace=False) pixel_scaled = pixel_scaled[:, idx] pixel_locations = pixel_locations[:, idx] @@ -342,15 +332,13 @@ def transform_pointcloud(pointcloud, transform): transform (tensor): transformation of size 4 x 4 ''' - assert(transform.shape == (4, 4) and pointcloud.shape[-1] == 3) + assert (transform.shape == (4, 4) and pointcloud.shape[-1] == 3) pcl, is_numpy = to_pytorch(pointcloud, True) transform = to_pytorch(transform) # Transform point cloud to homogen coordinate system - pcl_hom = torch.cat([ - pcl, torch.ones(pcl.shape[0], 1) - ], dim=-1).transpose(1, 0) + pcl_hom = torch.cat([pcl, torch.ones(pcl.shape[0], 1)], dim=-1).transpose(1, 0) # Apply transformation to point cloud pcl_hom_transformed = transform @ pcl_hom @@ -371,13 +359,11 @@ def transform_points_batch(p, transform): transform (tensor): transformation of size B x 4 x 4 ''' device = p.device - assert(transform.shape[1:] == (4, 4) and p.shape[-1] - == 3 and p.shape[0] == transform.shape[0]) + assert (transform.shape[1:] == (4, 4) and p.shape[-1] == 3 and p.shape[0] == transform.shape[0]) # 
Transform points to homogen coordinates - pcl_hom = torch.cat([ - p, torch.ones(p.shape[0], p.shape[1], 1).to(device) - ], dim=-1).transpose(2, 1) + pcl_hom = torch.cat([p, torch.ones(p.shape[0], p.shape[1], 1).to(device)], + dim=-1).transpose(2, 1) # Apply transformation pcl_hom_transformed = transform @ pcl_hom @@ -387,8 +373,9 @@ def transform_points_batch(p, transform): return pcl_out -def get_tensor_values(tensor, p, grid_sample=True, mode='nearest', - with_mask=False, squeeze_channel_dim=False): +def get_tensor_values( + tensor, p, grid_sample=True, mode='nearest', with_mask=False, squeeze_channel_dim=False +): ''' Returns values from tensor at given location p. @@ -415,8 +402,7 @@ def get_tensor_values(tensor, p, grid_sample=True, mode='nearest', p[:, :, 0] = (p[:, :, 0] + 1) * (w) / 2 p[:, :, 1] = (p[:, :, 1] + 1) * (h) / 2 p = p.long() - values = tensor[torch.arange(batch_size).unsqueeze(-1), :, p[:, :, 1], - p[:, :, 0]] + values = tensor[torch.arange(batch_size).unsqueeze(-1), :, p[:, :, 1], p[:, :, 0]] if with_mask: mask = get_mask(values) @@ -436,8 +422,7 @@ def get_tensor_values(tensor, p, grid_sample=True, mode='nearest', return values -def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat, - invert=True): +def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat, invert=True): ''' Transforms pixel positions p with given depth value d to world coordinates. Args: @@ -448,7 +433,7 @@ def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat, scale_mat (tensor): scale matrix invert (bool): whether to invert matrices (default: true) ''' - assert(pixels.shape[-1] == 2) + assert (pixels.shape[-1] == 2) # Convert to pytorch pixels, is_numpy = to_pytorch(pixels, True) @@ -493,8 +478,8 @@ def transform_to_camera_space(p_world, camera_mat, world_mat, scale_mat): device = p_world.device # Transform world points to homogen coordinates - p_world = torch.cat([p_world, torch.ones( - batch_size, n_p, 1).to(device)], dim=-1).permute(0, 2, 1) + p_world = torch.cat([p_world, torch.ones(batch_size, n_p, 1).to(device)], + dim=-1).permute(0, 2, 1) # Apply matrices to transform p_world to camera space p_cam = camera_mat @ world_mat @ scale_mat @ p_world @@ -536,8 +521,7 @@ def origin_to_world(n_points, camera_mat, world_mat, scale_mat, invert=True): return p_world -def image_points_to_world(image_points, camera_mat, world_mat, scale_mat, - invert=True): +def image_points_to_world(image_points, camera_mat, world_mat, scale_mat, invert=True): ''' Transforms points on image plane to world coordinates. 
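# Illustrative sketch (not part of the patch): the homogeneous-coordinate pattern used by
# transform_pointcloud / transform_points_batch above - append a 1, apply a 4x4 transform,
# drop the homogeneous row.
import torch

points = torch.rand(8, 3)                                   # N x 3
transform = torch.eye(4)
transform[:3, 3] = torch.tensor([1.0, 0.0, 0.0])            # translate +1 along x

pcl_hom = torch.cat([points, torch.ones(8, 1)], dim=-1).T   # 4 x N
moved = (transform @ pcl_hom)[:3].T                         # back to N x 3
assert torch.allclose(moved, points + torch.tensor([1.0, 0.0, 0.0]))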
In contrast to transform_to_world, no depth value is needed as points on @@ -551,12 +535,13 @@ def image_points_to_world(image_points, camera_mat, world_mat, scale_mat, invert (bool): whether to invert matrices (default: true) ''' batch_size, n_pts, dim = image_points.shape - assert(dim == 2) + assert (dim == 2) device = image_points.device d_image = torch.ones(batch_size, n_pts, 1).to(device) - return transform_to_world(image_points, d_image, camera_mat, world_mat, - scale_mat, invert=invert) + return transform_to_world( + image_points, d_image, camera_mat, world_mat, scale_mat, invert=invert + ) def check_weights(params): @@ -602,7 +587,7 @@ def get_logits_from_prob(probs, eps=1e-4): probs (tensor): probability tensor eps (float): epsilon value for numerical stability ''' - probs = np.clip(probs, a_min=eps, a_max=1-eps) + probs = np.clip(probs, a_min=eps, a_max=1 - eps) logits = np.log(probs / (1 - probs)) return logits @@ -629,7 +614,7 @@ def chamfer_distance_naive(points1, points2): points1 (numpy array): first point set points2 (numpy array): second point set ''' - assert(points1.size() == points2.size()) + assert (points1.size() == points2.size()) batch_size, T, _ = points1.size() points1 = points1.view(batch_size, T, 1, 3) @@ -748,10 +733,16 @@ def make_3d_grid(bb_min, bb_max, shape): return p -def get_occupancy_loss_points(pixels, camera_mat, world_mat, scale_mat, - depth_image=None, use_cube_intersection=True, - occupancy_random_normal=False, - depth_range=[0, 2.4]): +def get_occupancy_loss_points( + pixels, + camera_mat, + world_mat, + scale_mat, + depth_image=None, + use_cube_intersection=True, + occupancy_random_normal=False, + depth_range=[0, 2.4] +): ''' Returns 3D points for occupancy loss. Args: @@ -794,16 +785,19 @@ def get_occupancy_loss_points(pixels, camera_mat, world_mat, scale_mat, if depth_image is not None: depth_gt, mask_gt_depth = get_tensor_values( - depth_image, pixels, squeeze_channel_dim=True, with_mask=True) + depth_image, pixels, squeeze_channel_dim=True, with_mask=True + ) d_occupancy[mask_gt_depth] = depth_gt[mask_gt_depth] - p_occupancy = transform_to_world(pixels, d_occupancy.unsqueeze(-1), - camera_mat, world_mat, scale_mat) + p_occupancy = transform_to_world( + pixels, d_occupancy.unsqueeze(-1), camera_mat, world_mat, scale_mat + ) return p_occupancy -def get_freespace_loss_points(pixels, camera_mat, world_mat, scale_mat, - use_cube_intersection=True, depth_range=[0, 2.4]): +def get_freespace_loss_points( + pixels, camera_mat, world_mat, scale_mat, use_cube_intersection=True, depth_range=[0, 2.4] +): ''' Returns 3D points for freespace loss. 
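# Illustrative sketch (not part of the patch): the clipped logit transform applied by
# get_logits_from_prob above, together with its inverse (sigmoid).
import numpy as np

eps = 1e-4
probs = np.array([0.0, 0.25, 0.5, 1.0])
clipped = np.clip(probs, eps, 1 - eps)
logits = np.log(clipped / (1 - clipped))      # ~[-9.21, -1.10, 0.00, 9.21]
recovered = 1.0 / (1.0 + np.exp(-logits))     # equals the clipped probabilities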
Args: @@ -832,7 +826,8 @@ def get_freespace_loss_points(pixels, camera_mat, world_mat, scale_mat, device) * (d_cube[:, 1] - d_cube[:, 0]) p_freespace = transform_to_world( - pixels, d_freespace.unsqueeze(-1), camera_mat, world_mat, scale_mat) + pixels, d_freespace.unsqueeze(-1), camera_mat, world_mat, scale_mat + ) return p_freespace @@ -844,7 +839,6 @@ def normalize_tensor(tensor, min_norm=1e-5, feat_dim=-1): min_norm (float): minimum norm for numerical stability feat_dim (int): feature dimension in tensor (default: -1) ''' - norm_tensor = torch.clamp(torch.norm(tensor, dim=feat_dim, keepdim=True), - min=min_norm) + norm_tensor = torch.clamp(torch.norm(tensor, dim=feat_dim, keepdim=True), min=min_norm) normed_tensor = tensor / norm_tensor return normed_tensor diff --git a/lib/pymafx/utils/data_loader.py b/lib/pymafx/utils/data_loader.py index 2c34d300c43f15d7de24460f19f1a6da7d483d60..cc92ad223836e9de322bc80bbab887bb9ec3f17b 100644 --- a/lib/pymafx/utils/data_loader.py +++ b/lib/pymafx/utils/data_loader.py @@ -3,47 +3,57 @@ import torch from torch.utils.data import DataLoader from torch.utils.data.sampler import Sampler -class RandomSampler(Sampler): +class RandomSampler(Sampler): def __init__(self, data_source, checkpoint): self.data_source = data_source if checkpoint is not None and checkpoint['dataset_perm'] is not None: self.dataset_perm = checkpoint['dataset_perm'] - self.perm = self.dataset_perm[checkpoint['batch_size']*checkpoint['batch_idx']:] + self.perm = self.dataset_perm[checkpoint['batch_size'] * checkpoint['batch_idx']:] else: self.dataset_perm = torch.randperm(len(self.data_source)).tolist() - self.perm = torch.randperm(len(self.data_source)).tolist() + self.perm = torch.randperm(len(self.data_source)).tolist() def __iter__(self): return iter(self.perm) - + def __len__(self): return len(self.perm) -class SequentialSampler(Sampler): +class SequentialSampler(Sampler): def __init__(self, data_source, checkpoint): self.data_source = data_source if checkpoint is not None and checkpoint['dataset_perm'] is not None: self.dataset_perm = checkpoint['dataset_perm'] - self.perm = self.dataset_perm[checkpoint['batch_size']*checkpoint['batch_idx']:] + self.perm = self.dataset_perm[checkpoint['batch_size'] * checkpoint['batch_idx']:] else: self.dataset_perm = list(range(len(self.data_source))) self.perm = self.dataset_perm def __iter__(self): return iter(self.perm) - + def __len__(self): return len(self.perm) + class CheckpointDataLoader(DataLoader): """ Extends torch.utils.data.DataLoader to handle resuming training from an arbitrary point within an epoch. 
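# Illustrative sketch (not part of the patch): how the samplers above resume mid-epoch -
# the stored permutation is sliced at batch_size * batch_idx so samples consumed before the
# checkpoint are skipped. The dict keys match those read in data_loader.py.
checkpoint = {'dataset_perm': list(range(10)), 'batch_size': 2, 'batch_idx': 3}

resume_from = checkpoint['batch_size'] * checkpoint['batch_idx']   # 6 samples already used
remaining = checkpoint['dataset_perm'][resume_from:]               # [6, 7, 8, 9]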
""" - def __init__(self, dataset, checkpoint=None, batch_size=1, - shuffle=False, num_workers=0, pin_memory=False, drop_last=True, - timeout=0, worker_init_fn=None): + def __init__( + self, + dataset, + checkpoint=None, + batch_size=1, + shuffle=False, + num_workers=0, + pin_memory=False, + drop_last=True, + timeout=0, + worker_init_fn=None + ): if shuffle: sampler = RandomSampler(dataset, checkpoint) @@ -54,5 +64,14 @@ class CheckpointDataLoader(DataLoader): else: self.checkpoint_batch_idx = 0 - super(CheckpointDataLoader, self).__init__(dataset, sampler=sampler, shuffle=False, batch_size=batch_size, num_workers=num_workers, - drop_last=drop_last, pin_memory=pin_memory, timeout=timeout, worker_init_fn=None) + super(CheckpointDataLoader, self).__init__( + dataset, + sampler=sampler, + shuffle=False, + batch_size=batch_size, + num_workers=num_workers, + drop_last=drop_last, + pin_memory=pin_memory, + timeout=timeout, + worker_init_fn=None + ) diff --git a/lib/pymafx/utils/demo_utils.py b/lib/pymafx/utils/demo_utils.py index 40ec3d576e9f93d81ac789f74c5b54feb711e0ae..b1ad8da91c7a7f6f67d4770c9866a02a78aa5275 100644 --- a/lib/pymafx/utils/demo_utils.py +++ b/lib/pymafx/utils/demo_utils.py @@ -46,8 +46,8 @@ def preprocess_video(video, joints2d, bboxes, frames, scale=1.0, crop_size=224): if joints2d is not None: bboxes, time_pt1, time_pt2 = get_all_bbox_params(joints2d, vis_thresh=0.3) - bboxes[:,2:] = 150. / bboxes[:,2:] - bboxes = np.stack([bboxes[:,0], bboxes[:,1], bboxes[:,2], bboxes[:,2]]).T + bboxes[:, 2:] = 150. / bboxes[:, 2:] + bboxes = np.stack([bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 2]]).T video = video[time_pt1:time_pt2] joints2d = joints2d[time_pt1:time_pt2] @@ -66,11 +66,8 @@ def preprocess_video(video, joints2d, bboxes, frames, scale=1.0, crop_size=224): j2d = joints2d[idx] if joints2d is not None else None norm_img, raw_img, kp_2d = get_single_image_crop_demo( - img, - bbox, - kp_2d=j2d, - scale=scale, - crop_size=crop_size) + img, bbox, kp_2d=j2d, scale=scale, crop_size=crop_size + ) if joints2d is not None: joints2d[idx] = kp_2d @@ -88,16 +85,16 @@ def download_youtube_clip(url, download_folder): def smplify_runner( - pred_rotmat, - pred_betas, - pred_cam, - j2d, - device, - batch_size, - lr=1.0, - opt_steps=1, - use_lbfgs=True, - pose2aa=True + pred_rotmat, + pred_betas, + pred_cam, + j2d, + device, + batch_size, + lr=1.0, + opt_steps=1, + use_lbfgs=True, + pose2aa=True ): smplify = TemporalSMPLify( step_size=lr, @@ -106,7 +103,7 @@ def smplify_runner( focal_length=5000., use_lbfgs=use_lbfgs, device=device, - # max_iter=10, + # max_iter=10, ) # Convert predicted rotation matrices to axis-angle if pose2aa: @@ -115,18 +112,16 @@ def smplify_runner( pred_pose = pred_rotmat # Calculate camera parameters for smplify - pred_cam_t = torch.stack([ - pred_cam[:, 1], pred_cam[:, 2], - 2 * 5000 / (224 * pred_cam[:, 0] + 1e-9) - ], dim=-1) + pred_cam_t = torch.stack( + [pred_cam[:, 1], pred_cam[:, 2], 2 * 5000 / (224 * pred_cam[:, 0] + 1e-9)], dim=-1 + ) gt_keypoints_2d_orig = j2d # Before running compute reprojection error of the network opt_joint_loss = smplify.get_fitting_loss( - pred_pose.detach(), pred_betas.detach(), - pred_cam_t.detach(), - 0.5 * 224 * torch.ones(batch_size, 2, device=device), - gt_keypoints_2d_orig).mean(dim=-1) + pred_pose.detach(), pred_betas.detach(), pred_cam_t.detach(), + 0.5 * 224 * torch.ones(batch_size, 2, device=device), gt_keypoints_2d_orig + ).mean(dim=-1) best_prediction_id = torch.argmin(opt_joint_loss).item() pred_betas = 
pred_betas[best_prediction_id].unsqueeze(0) @@ -140,7 +135,8 @@ def smplify_runner( # new_opt_pose, new_opt_betas, \ # new_opt_cam_t, \ output, new_opt_joint_loss = smplify( - pred_pose.detach(), pred_betas.detach(), + pred_pose.detach(), + pred_betas.detach(), pred_cam_t.detach(), 0.5 * 224 * torch.ones(batch_size, 2, device=device), gt_keypoints_2d_orig, @@ -152,29 +148,34 @@ def smplify_runner( update = (new_opt_joint_loss < opt_joint_loss) new_opt_vertices = output['verts'] - new_opt_cam_t = output['theta'][:,:3] - new_opt_pose = output['theta'][:,3:75] - new_opt_betas = output['theta'][:,75:] + new_opt_cam_t = output['theta'][:, :3] + new_opt_pose = output['theta'][:, 3:75] + new_opt_betas = output['theta'][:, 75:] new_opt_joints3d = output['kp_3d'] return_val = [ - update, new_opt_vertices.cpu(), new_opt_cam_t.cpu(), - new_opt_pose.cpu(), new_opt_betas.cpu(), new_opt_joints3d.cpu(), - new_opt_joint_loss, opt_joint_loss, + update, + new_opt_vertices.cpu(), + new_opt_cam_t.cpu(), + new_opt_pose.cpu(), + new_opt_betas.cpu(), + new_opt_joints3d.cpu(), + new_opt_joint_loss, + opt_joint_loss, ] return return_val def trim_videos(filename, start_time, end_time, output_filename): - command = ['ffmpeg', - '-i', '"%s"' % filename, - '-ss', str(start_time), - '-t', str(end_time - start_time), - '-c:v', 'libx264', '-c:a', 'copy', - '-threads', '1', - '-loglevel', 'panic', - '"%s"' % output_filename] + command = [ + 'ffmpeg', '-i', + '"%s"' % filename, '-ss', + str(start_time), '-t', + str(end_time - start_time), '-c:v', 'libx264', '-c:a', 'copy', '-threads', '1', '-loglevel', + 'panic', + '"%s"' % output_filename + ] # command = ' '.join(command) subprocess.call(command) @@ -187,11 +188,7 @@ def video_to_images(vid_file, img_folder=None, return_info=False): print(img_folder) os.makedirs(img_folder, exist_ok=True) - command = ['ffmpeg', - '-i', vid_file, - '-f', 'image2', - '-v', 'error', - f'{img_folder}/%06d.png'] + command = ['ffmpeg', '-i', vid_file, '-f', 'image2', '-v', 'error', f'{img_folder}/%06d.png'] print(f'Running \"{" ".join(command)}\"') try: @@ -236,8 +233,24 @@ def images_to_video(img_folder, output_vid_file): os.makedirs(img_folder, exist_ok=True) command = [ - 'ffmpeg', '-y', '-threads', '16', '-i', f'{img_folder}/%06d.png', '-profile:v', 'baseline', - '-level', '3.0', '-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-an', '-v', 'error', output_vid_file, + 'ffmpeg', + '-y', + '-threads', + '16', + '-i', + f'{img_folder}/%06d.png', + '-profile:v', + 'baseline', + '-level', + '3.0', + '-c:v', + 'libx264', + '-pix_fmt', + 'yuv420p', + '-an', + '-v', + 'error', + output_vid_file, ] print(f'Running \"{" ".join(command)}\"') @@ -257,12 +270,12 @@ def convert_crop_cam_to_orig_img(cam, bbox, img_width, img_height): :param img_height (int): original image height :return: ''' - cx, cy, h = bbox[:,0], bbox[:,1], bbox[:,2] + cx, cy, h = bbox[:, 0], bbox[:, 1], bbox[:, 2] hw, hh = img_width / 2., img_height / 2. - sx = cam[:,0] * (1. / (img_width / h)) - sy = cam[:,0] * (1. / (img_height / h)) - tx = ((cx - hw) / hw / sx) + cam[:,1] - ty = ((cy - hh) / hh / sy) + cam[:,2] + sx = cam[:, 0] * (1. / (img_width / h)) + sy = cam[:, 0] * (1. 
/ (img_height / h)) + tx = ((cx - hw) / hw / sx) + cam[:, 1] + ty = ((cy - hh) / hh / sy) + cam[:, 2] orig_cam = np.stack([sx, sy, tx, ty]).T return orig_cam @@ -272,19 +285,24 @@ def prepare_rendering_results(results_dict, nframes): for person_id, person_data in results_dict.items(): for idx, frame_id in enumerate(person_data['frame_ids']): frame_results[frame_id][person_id] = { - 'verts': person_data['verts'][idx], - 'smplx_verts': person_data['smplx_verts'][idx] if 'smplx_verts' in person_data else None, - 'cam': person_data['orig_cam'][idx], - 'cam_t': person_data['orig_cam_t'][idx] if 'orig_cam_t' in person_data else None, - # 'cam': person_data['pred_cam'][idx], + 'verts': + person_data['verts'][idx], + 'smplx_verts': + person_data['smplx_verts'][idx] if 'smplx_verts' in person_data else None, + 'cam': + person_data['orig_cam'][idx], + 'cam_t': + person_data['orig_cam_t'][idx] if 'orig_cam_t' in person_data else None, + # 'cam': person_data['pred_cam'][idx], } # naive depth ordering based on the scale of the weak perspective camera for frame_id, frame_data in enumerate(frame_results): # sort based on y-scale of the cam in original image coords - sort_idx = np.argsort([v['cam'][1] for k,v in frame_data.items()]) + sort_idx = np.argsort([v['cam'][1] for k, v in frame_data.items()]) frame_results[frame_id] = OrderedDict( - {list(frame_data.keys())[i]:frame_data[list(frame_data.keys())[i]] for i in sort_idx} + {list(frame_data.keys())[i]: frame_data[list(frame_data.keys())[i]] + for i in sort_idx} ) return frame_results diff --git a/lib/pymafx/utils/densepose_methods.py b/lib/pymafx/utils/densepose_methods.py index 3d12827674b12784a83da625b5e5fc50c1481e28..93fdf66a6651dcfe05f6e95c55379eaa00c52cb0 100644 --- a/lib/pymafx/utils/densepose_methods.py +++ b/lib/pymafx/utils/densepose_methods.py @@ -23,8 +23,9 @@ class DensePoseMethods: self.All_vertices = ALP_UV['All_vertices'][0] ## Info to compute symmetries. 
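# Illustrative sketch (not part of the patch): the weak-perspective rescaling done by
# convert_crop_cam_to_orig_img above, written out for a single box.
img_width, img_height = 640, 480
s, tx_c, ty_c = 0.9, 0.05, -0.02            # crop-space camera (s, tx, ty)
cx, cy, h = 320.0, 240.0, 200.0             # bbox centre and size in the original image

hw, hh = img_width / 2., img_height / 2.
sx = s * (1. / (img_width / h))             # 0.9 * 200 / 640
sy = s * (1. / (img_height / h))            # 0.9 * 200 / 480
tx = ((cx - hw) / hw / sx) + tx_c           # box is centred here, so only tx_c remains
ty = ((cy - hh) / hh / sy) + ty_c
orig_cam = [sx, sy, tx, ty]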
self.SemanticMaskSymmetries = [0, 1, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 14] - self.Index_Symmetry_List = [1, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24, - 23]; + self.Index_Symmetry_List = [ + 1, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24, 23 + ] UV_symmetry_filename = os.path.join('./data/UV_data', 'UV_symmetry_transforms.mat') self.UV_symmetry_transformations = loadmat(UV_symmetry_filename) @@ -65,17 +66,17 @@ class DensePoseMethods: vCrossW = np.cross(v, w) vCrossU = np.cross(v, u) if (np.dot(vCrossW, vCrossU) < 0): - return False; + return False # uCrossW = np.cross(u, w) uCrossV = np.cross(u, v) # if (np.dot(uCrossW, uCrossV) < 0): - return False; + return False # - denom = np.sqrt((uCrossV ** 2).sum()) - r = np.sqrt((vCrossW ** 2).sum()) / denom - t = np.sqrt((uCrossW ** 2).sum()) / denom + denom = np.sqrt((uCrossV**2).sum()) + r = np.sqrt((vCrossW**2).sum()) / denom + t = np.sqrt((uCrossW**2).sum()) / denom # return ((r <= 1) & (t <= 1) & (r + t <= 1)) @@ -90,9 +91,9 @@ class DensePoseMethods: uCrossW = np.cross(u, w) uCrossV = np.cross(u, v) # - denom = np.sqrt((uCrossV ** 2).sum()) - r = np.sqrt((vCrossW ** 2).sum()) / denom - t = np.sqrt((uCrossW ** 2).sum()) / denom + denom = np.sqrt((uCrossV**2).sum()) + r = np.sqrt((vCrossW**2).sum()) / denom + t = np.sqrt((uCrossW**2).sum()) / denom # return (1 - (r + t), r, t) @@ -101,12 +102,24 @@ class DensePoseMethods: FaceIndicesNow = np.where(self.FaceIndices == I_point) FacesNow = self.FacesDensePose[FaceIndicesNow] # - P_0 = np.vstack((self.U_norm[FacesNow][:, 0], self.V_norm[FacesNow][:, 0], - np.zeros(self.U_norm[FacesNow][:, 0].shape))).transpose() - P_1 = np.vstack((self.U_norm[FacesNow][:, 1], self.V_norm[FacesNow][:, 1], - np.zeros(self.U_norm[FacesNow][:, 1].shape))).transpose() - P_2 = np.vstack((self.U_norm[FacesNow][:, 2], self.V_norm[FacesNow][:, 2], - np.zeros(self.U_norm[FacesNow][:, 2].shape))).transpose() + P_0 = np.vstack( + ( + self.U_norm[FacesNow][:, 0], self.V_norm[FacesNow][:, 0], + np.zeros(self.U_norm[FacesNow][:, 0].shape) + ) + ).transpose() + P_1 = np.vstack( + ( + self.U_norm[FacesNow][:, 1], self.V_norm[FacesNow][:, 1], + np.zeros(self.U_norm[FacesNow][:, 1].shape) + ) + ).transpose() + P_2 = np.vstack( + ( + self.U_norm[FacesNow][:, 2], self.V_norm[FacesNow][:, 2], + np.zeros(self.U_norm[FacesNow][:, 2].shape) + ) + ).transpose() # for i, [P0, P1, P2] in enumerate(zip(P_0, P_1, P_2)): @@ -116,9 +129,12 @@ class DensePoseMethods: # # If the found UV is not inside any faces, select the vertex that is closest! 
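# Illustrative sketch (not part of the patch): the cross-product barycentric weights
# computed by DensePoseMethods above, for one triangle in the UV plane (z = 0 padded, as in
# the hunk).
import numpy as np

P0, P1, P2 = np.array([0., 0., 0.]), np.array([1., 0., 0.]), np.array([0., 1., 0.])
P = np.array([0.25, 0.25, 0.])

u, v, w = P1 - P0, P2 - P0, P - P0
denom = np.sqrt((np.cross(u, v) ** 2).sum())
r = np.sqrt((np.cross(v, w) ** 2).sum()) / denom   # weight of P1
t = np.sqrt((np.cross(u, w) ** 2).sum()) / denom   # weight of P2
bary = (1 - (r + t), r, t)                         # (0.5, 0.25, 0.25)
inside = (r <= 1) and (t <= 1) and (r + t <= 1)    # True: P lies inside the triangle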
# - D1 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], P_0[:, 0:2]).squeeze() - D2 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], P_1[:, 0:2]).squeeze() - D3 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], P_2[:, 0:2]).squeeze() + D1 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], + P_0[:, 0:2]).squeeze() + D2 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], + P_1[:, 0:2]).squeeze() + D3 = scipy.spatial.distance.cdist(np.array([U_point, V_point])[np.newaxis, :], + P_2[:, 0:2]).squeeze() # minD1 = D1.min() minD2 = D2.min() diff --git a/lib/pymafx/utils/geometry.py b/lib/pymafx/utils/geometry.py index 804b08cf8dc3acc0f69b55ae672f3000b6e878b5..608288fc4d73a4918ab95938a7bf5dbe98ce606f 100644 --- a/lib/pymafx/utils/geometry.py +++ b/lib/pymafx/utils/geometry.py @@ -43,11 +43,13 @@ def quat_to_rotmat(quat): wx, wy, wz = w * x, w * y, w * z xy, xz, yz = x * y, x * z, y * z - rotMat = torch.stack([ - w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2, - 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2 - ], - dim=1).view(B, 3, 3) + rotMat = torch.stack( + [ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2, + 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2 + ], + dim=1 + ).view(B, 3, 3) return rotMat @@ -74,7 +76,8 @@ def rotation_matrix_to_angle_axis(rotation_matrix): if rotation_matrix.shape[1:] == (3, 3): rot_mat = rotation_matrix.reshape(-1, 3, 3) hom = torch.tensor([0, 0, 1], dtype=torch.float32, device=rotation_matrix.device).reshape( - 1, 3, 1).expand(rot_mat.shape[0], -1, -1) + 1, 3, 1 + ).expand(rot_mat.shape[0], -1, -1) rotation_matrix = torch.cat([rot_mat, hom], dim=-1) quaternion = rotation_matrix_to_quaternion(rotation_matrix) @@ -109,8 +112,9 @@ def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor: raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion))) if not quaternion.shape[-1] == 4: - raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}".format( - quaternion.shape)) + raise ValueError( + "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape) + ) # unpack input and compute conversion q1: torch.Tensor = quaternion[..., 1] q2: torch.Tensor = quaternion[..., 2] @@ -119,8 +123,9 @@ def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor: sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) cos_theta: torch.Tensor = quaternion[..., 0] - two_theta: torch.Tensor = 2.0 * torch.where(cos_theta < 0.0, torch.atan2( - -sin_theta, -cos_theta), torch.atan2(sin_theta, cos_theta)) + two_theta: torch.Tensor = 2.0 * torch.where( + cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), torch.atan2(sin_theta, cos_theta) + ) k_pos: torch.Tensor = two_theta / sin_theta k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta) @@ -155,8 +160,9 @@ def quaternion_to_angle(quaternion: torch.Tensor) -> torch.Tensor: raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion))) if not quaternion.shape[-1] == 4: - raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}".format( - quaternion.shape)) + raise ValueError( + "Input must be a tensor of shape Nx4 or 4. 
Got {}".format(quaternion.shape) + ) # unpack input and compute conversion q1: torch.Tensor = quaternion[..., 1] q2: torch.Tensor = quaternion[..., 2] @@ -165,8 +171,9 @@ def quaternion_to_angle(quaternion: torch.Tensor) -> torch.Tensor: sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) cos_theta: torch.Tensor = quaternion[..., 0] - theta: torch.Tensor = 2.0 * torch.where(cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), - torch.atan2(sin_theta, cos_theta)) + theta: torch.Tensor = 2.0 * torch.where( + cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), torch.atan2(sin_theta, cos_theta) + ) # theta: torch.Tensor = 2.0 * torch.atan2(sin_theta, cos_theta) @@ -202,8 +209,9 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix))) if len(rotation_matrix.shape) > 3: - raise ValueError("Input size must be a three dimensional tensor. Got {}".format( - rotation_matrix.shape)) + raise ValueError( + "Input size must be a three dimensional tensor. Got {}".format(rotation_matrix.shape) + ) # if not rotation_matrix.shape[-2:] == (3, 4): # raise ValueError( # "Input size must be a N x 3 x 4 tensor. Got {}".format( @@ -217,31 +225,39 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] - q0 = torch.stack([ - rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], - rmat_t[:, 2, 0] + rmat_t[:, 0, 2] - ], -1) + q0 = torch.stack( + [ + rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2] + ], -1 + ) t0_rep = t0.repeat(4, 1).t() t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] - q1 = torch.stack([ - rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], t1, - rmat_t[:, 1, 2] + rmat_t[:, 2, 1] - ], -1) + q1 = torch.stack( + [ + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], t1, + rmat_t[:, 1, 2] + rmat_t[:, 2, 1] + ], -1 + ) t1_rep = t1.repeat(4, 1).t() t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] - q2 = torch.stack([ - rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], - rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2 - ], -1) + q2 = torch.stack( + [ + rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2 + ], -1 + ) t2_rep = t2.repeat(4, 1).t() t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] - q3 = torch.stack([ - t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], rmat_t[:, 2, 0] - rmat_t[:, 0, 2], - rmat_t[:, 0, 1] - rmat_t[:, 1, 0] - ], -1) + q3 = torch.stack( + [ + t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] - rmat_t[:, 1, 0] + ], -1 + ) t3_rep = t3.repeat(4, 1).t() mask_c0 = mask_d2 * mask_d0_d1 @@ -254,8 +270,10 @@ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): mask_c3 = mask_c3.view(-1, 1).type_as(q3) q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 - q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa - t2_rep * mask_c2 + t3_rep * mask_c3) # noqa + q /= torch.sqrt( + t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa + t2_rep * mask_c2 + t3_rep * mask_c3 + ) # noqa q *= 0.5 return q @@ -303,11 +321,13 @@ def quaternion_to_rotation_matrix(quat): wx, wy, wz = w * x, w * y, w * z xy, xz, yz = x * y, x * z, y * z - rotMat = torch.stack([ - w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz 
+ 2 * xy, w2 - x2 + y2 - z2, - 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2 - ], - dim=1).view(B, 3, 3) + rotMat = torch.stack( + [ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, w2 - x2 + y2 - z2, + 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2 + ], + dim=1 + ).view(B, 3, 3) return rotMat @@ -386,16 +406,18 @@ def projection(pred_joints, pred_camera, retain_z=False, iwp_mode=True): if iwp_mode: cam_sxy = pred_camera['cam_sxy'] pred_cam_t = torch.stack( - [cam_sxy[:, 1], cam_sxy[:, 2], 2 * 5000. / (224. * cam_sxy[:, 0] + 1e-9)], dim=-1) + [cam_sxy[:, 1], cam_sxy[:, 2], 2 * 5000. / (224. * cam_sxy[:, 0] + 1e-9)], dim=-1 + ) camera_center = torch.zeros(batch_size, 2) - pred_keypoints_2d = perspective_projection(pred_joints, - rotation=torch.eye(3).unsqueeze(0).expand( - batch_size, -1, -1).to(pred_joints.device), - translation=pred_cam_t, - focal_length=5000., - camera_center=camera_center, - retain_z=retain_z) + pred_keypoints_2d = perspective_projection( + pred_joints, + rotation=torch.eye(3).unsqueeze(0).expand(batch_size, -1, -1).to(pred_joints.device), + translation=pred_cam_t, + focal_length=5000., + camera_center=camera_center, + retain_z=retain_z + ) # # Normalize keypoints to [-1,1] # pred_keypoints_2d = pred_keypoints_2d / (224. / 2.) else: @@ -427,13 +449,15 @@ def projection(pred_joints, pred_camera, retain_z=False, iwp_mode=True): return pred_keypoints_2d -def perspective_projection(points, - rotation, - translation, - focal_length=None, - camera_center=None, - cam_intrinsics=None, - retain_z=False): +def perspective_projection( + points, + rotation, + translation, + focal_length=None, + camera_center=None, + cam_intrinsics=None, + retain_z=False +): """ This function computes the perspective projection of a set of points. 
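# Illustrative sketch (not part of the patch): the weak-perspective -> translation
# conversion used in projection() above under IWP mode - the depth is chosen so that a
# focal-length-5000 camera over a 224 px crop reproduces the scale s.
import torch

cam_sxy = torch.tensor([[0.8, 0.1, -0.05]])    # (s, tx, ty)
pred_cam_t = torch.stack(
    [cam_sxy[:, 1], cam_sxy[:, 2], 2 * 5000. / (224. * cam_sxy[:, 0] + 1e-9)], dim=-1
)
# pred_cam_t ~ [[0.10, -0.05, 55.80]]: x/y shifts pass through, z encodes 1 / scale.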
Input: @@ -513,10 +537,12 @@ def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000, img_si weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1) # least squares - Q = np.array([ - F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints), - O - np.reshape(joints_2d, -1) - ]).T + Q = np.array( + [ + F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints), + O - np.reshape(joints_2d, -1) + ] + ).T c = (np.reshape(joints_2d, -1) - O) * Z - F * XY # weighted least squares @@ -570,11 +596,9 @@ def estimate_translation(S, joints_2d, focal_length=5000., img_size=224., use_al S_i = S[i] joints_i = joints_2d[i] conf_i = joints_conf[i] - trans[i] = estimate_translation_np(S_i, - joints_i, - conf_i, - focal_length=focal_length[i], - img_size=img_size[i]) + trans[i] = estimate_translation_np( + S_i, joints_i, conf_i, focal_length=focal_length[i], img_size=img_size[i] + ) return torch.from_numpy(trans).to(device) @@ -585,8 +609,10 @@ def Rot_y(angle, category='torch', prepend_dim=True, device=None): prepend_dim: prepend an extra dimension Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True) ''' - m = np.array([[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.], - [-np.sin(angle), 0., np.cos(angle)]]) + m = np.array( + [[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.], [-np.sin(angle), 0., + np.cos(angle)]] + ) if category == 'torch': if prepend_dim: return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0) @@ -608,8 +634,10 @@ def Rot_x(angle, category='torch', prepend_dim=True, device=None): prepend_dim: prepend an extra dimension Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True) ''' - m = np.array([[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)], - [0., np.sin(angle), np.cos(angle)]]) + m = np.array( + [[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)], [0., np.sin(angle), + np.cos(angle)]] + ) if category == 'torch': if prepend_dim: return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0) @@ -631,8 +659,9 @@ def Rot_z(angle, category='torch', prepend_dim=True, device=None): prepend_dim: prepend an extra dimension Return: Rotation matrix with shape [1, 3, 3] (prepend_dim=True) ''' - m = np.array([[np.cos(angle), -np.sin(angle), 0.], [np.sin(angle), - np.cos(angle), 0.], [0., 0., 1.]]) + m = np.array( + [[np.cos(angle), -np.sin(angle), 0.], [np.sin(angle), np.cos(angle), 0.], [0., 0., 1.]] + ) if category == 'torch': if prepend_dim: return torch.tensor(m, dtype=torch.float, device=device).unsqueeze(0) @@ -674,7 +703,7 @@ def compute_twist_rotation(rotation_matrix, twist_axis): twist_rotation = quaternion_to_rotation_matrix(twist_quaternion) twist_aa = quaternion_to_angle_axis(twist_quaternion) - twist_angle = torch.sum(twist_aa, dim=1, keepdim=True) / torch.sum( - twist_axis, dim=1, keepdim=True) + twist_angle = torch.sum(twist_aa, dim=1, + keepdim=True) / torch.sum(twist_axis, dim=1, keepdim=True) return twist_rotation, twist_angle diff --git a/lib/pymafx/utils/imutils.py b/lib/pymafx/utils/imutils.py index ad75ff1d206ef344a98abcc2dc32fa4c7fcb6012..b3522fee118cf47c5101bfd8e16991e5c30f58ad 100644 --- a/lib/pymafx/utils/imutils.py +++ b/lib/pymafx/utils/imutils.py @@ -9,6 +9,7 @@ from PIL import Image from lib.pymafx.core import constants + def get_transform(center, scale, res, rot=0): """Generate transformation matrix.""" h = 200 * scale @@ -19,29 +20,31 @@ def get_transform(center, scale, res, rot=0): t[1, 2] = res[0] * (-float(center[1]) / h + .5) t[2, 2] = 1 if 
not rot == 0: - t = np.dot(get_rot_transf(res, rot),t) + t = np.dot(get_rot_transf(res, rot), t) return t + def get_rot_transf(res, rot): """Generate rotation transformation matrix.""" if rot == 0: return np.identity(3) - rot = -rot # To match direction of rotation from cropping - rot_mat = np.zeros((3,3)) + rot = -rot # To match direction of rotation from cropping + rot_mat = np.zeros((3, 3)) rot_rad = rot * np.pi / 180 - sn,cs = np.sin(rot_rad), np.cos(rot_rad) - rot_mat[0,:2] = [cs, -sn] - rot_mat[1,:2] = [sn, cs] - rot_mat[2,2] = 1 + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0, :2] = [cs, -sn] + rot_mat[1, :2] = [sn, cs] + rot_mat[2, 2] = 1 # Need to rotate around center t_mat = np.eye(3) - t_mat[0,2] = -res[1]/2 - t_mat[1,2] = -res[0]/2 + t_mat[0, 2] = -res[1] / 2 + t_mat[1, 2] = -res[0] / 2 t_inv = t_mat.copy() - t_inv[:2,2] *= -1 - rot_transf = np.dot(t_inv,np.dot(rot_mat,t_mat)) + t_inv[:2, 2] *= -1 + rot_transf = np.dot(t_inv, np.dot(rot_mat, t_mat)) return rot_transf + def transform(pt, center, scale, res, invert=0, rot=0): """Transform pixel location to different reference.""" t = get_transform(center, scale, res, rot=rot) @@ -51,6 +54,7 @@ def transform(pt, center, scale, res, invert=0, rot=0): new_pt = np.dot(t, new_pt) return new_pt[:2].astype(int) + 1 + def transform_pts(coords, center, scale, res, invert=0, rot=0): """Transform coordinates (N x 2) to different reference.""" new_coords = coords.copy() @@ -58,14 +62,14 @@ def transform_pts(coords, center, scale, res, invert=0, rot=0): new_coords[p, 0:2] = transform(coords[p, 0:2], center, scale, res, invert, rot) return new_coords + def crop(img, center, scale, res, rot=0): """Crop image according to the supplied bounding box.""" # Upper left point - ul = np.array(transform([1, 1], center, scale, res, invert=1))-1 + ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1 # Bottom right point - br = np.array(transform([res[0]+1, - res[1]+1], center, scale, res, invert=1))-1 - + br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1 + # Padding so that when rotated proper amount of context is included pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) if not rot == 0: @@ -84,8 +88,7 @@ def crop(img, center, scale, res, rot=0): old_x = max(0, ul[0]), min(len(img[0]), br[0]) old_y = max(0, ul[1]), min(len(img), br[1]) - new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], - old_x[0]:old_x[1]] + new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] if not rot == 0: # Remove padding @@ -95,15 +98,16 @@ def crop(img, center, scale, res, rot=0): new_img_resized = np.array(Image.fromarray(new_img.astype(np.uint8)).resize(res)) return new_img_resized, new_img, new_shape + def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True): """'Undo' the image cropping/resizing. This function is used when evaluating mask/part segmentation. 
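# Illustrative sketch (not part of the patch): the 3x3 affine built by get_transform above -
# it maps image coordinates into a res x res crop centred on `center`, where the crop covers
# 200 * scale source pixels.
import numpy as np

center, scale, res = np.array([320., 240.]), 1.2, (224, 224)
h = 200 * scale
t = np.eye(3)
t[0, 0], t[1, 1] = res[1] / h, res[0] / h
t[0, 2] = res[1] * (-center[0] / h + .5)
t[1, 2] = res[0] * (-center[1] / h + .5)

mapped = t @ np.array([center[0], center[1], 1.])
# mapped[:2] -> [112., 112.] (up to float rounding): the bbox centre lands at the crop centre.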
""" res = img.shape[:2] # Upper left point - ul = np.array(transform([1, 1], center, scale, res, invert=1))-1 + ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1 # Bottom right point - br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1 + br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1 # size of cropped image crop_shape = [br[1] - ul[1], br[0] - ul[0]] @@ -121,19 +125,24 @@ def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True): new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]] return new_img + def rot_aa(aa, rot): """Rotate axis angle parameters.""" # pose parameters - R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], - [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], - [0, 0, 1]]) + R = np.array( + [ + [np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], + [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], [0, 0, 1] + ] + ) # find the rotation of the body in camera frame per_rdg, _ = cv2.Rodrigues(aa) # apply the global rotation to the global orientation - resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg)) + resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg)) aa = (resrot.T)[0] return aa + def flip_img(img): """Flip rgb images or masks. channels come last, e.g. (256,256,3). @@ -141,6 +150,7 @@ def flip_img(img): img = np.fliplr(img) return img + def flip_kp(kp, is_smpl=False, type='body'): """Flip keypoints.""" assert type in ['body', 'hand', 'face', 'feet'] @@ -164,11 +174,12 @@ def flip_kp(kp, is_smpl=False, type='body'): flipped_parts = constants.FACE_FLIP_PERM elif type == 'feet': flipped_parts = constants.FEEF_FLIP_PERM - + kp = kp[flipped_parts] - kp[:,0] = - kp[:,0] + kp[:, 0] = -kp[:, 0] return kp + def flip_pose(pose): """Flip pose. The flipping is based on SMPL parameters. @@ -180,6 +191,7 @@ def flip_pose(pose): pose[2::3] = -pose[2::3] return pose + def flip_aa(pose): """Flip aa. """ @@ -194,6 +206,7 @@ def flip_aa(pose): raise NotImplementedError return pose + def normalize_2d_kp(kp_2d, crop_size=224, inv=False): # Normalize keypoints between -1, 1 if not inv: @@ -201,10 +214,11 @@ def normalize_2d_kp(kp_2d, crop_size=224, inv=False): kp_2d = 2.0 * kp_2d * ratio - 1.0 else: ratio = 1.0 / crop_size - kp_2d = (kp_2d + 1.0)/(2*ratio) + kp_2d = (kp_2d + 1.0) / (2 * ratio) return kp_2d + def j2d_processing(kp, transf): """Process gt 2D keypoints and apply transforms.""" # nparts = kp.shape[1] @@ -212,9 +226,10 @@ def j2d_processing(kp, transf): kp_pad = torch.cat([kp, torch.ones((bs, npart, 1)).to(kp)], dim=-1) kp_new = torch.bmm(transf, kp_pad.transpose(1, 2)) kp_new = kp_new.transpose(1, 2) - kp_new[:, :, :-1] = 2.*kp_new[:, :, :-1] / constants.IMG_RES - 1. + kp_new[:, :, :-1] = 2. * kp_new[:, :, :-1] / constants.IMG_RES - 1. 
return kp_new[:, :, :2] + def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None): ''' param joints: [num_joints, 3] @@ -231,11 +246,9 @@ def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None): target_weight = np.ones((num_joints, 1), dtype=np.float32) if joints_vis is not None: target_weight[:, 0] = joints_vis[:, 0] - target = torch.zeros((num_joints, - heatmap_size[1], - heatmap_size[0]), - dtype=torch.float32, - device=cur_device) + target = torch.zeros( + (num_joints, heatmap_size[1], heatmap_size[0]), dtype=torch.float32, device=cur_device + ) tmp_size = sigma * 3 @@ -264,7 +277,7 @@ def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None): y = x.unsqueeze(-1) x0 = y0 = size // 2 # The gaussian is not normalized, we want the center value to equal 1 - g = torch.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) + g = torch.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2)) # Usable gaussian range g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0] diff --git a/lib/pymafx/utils/io.py b/lib/pymafx/utils/io.py index 3edb5227c3c58c060646770b8757b1bc61687a6b..0926624ddeb1eccf2e9c6393595acfd34a62e84d 100644 --- a/lib/pymafx/utils/io.py +++ b/lib/pymafx/utils/io.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. ############################################################################## - """IO utilities.""" from __future__ import absolute_import @@ -28,7 +27,7 @@ import re import sys try: from urllib.request import urlopen -except ImportError: #python2 +except ImportError: #python2 from urllib2 import urlopen logger = logging.getLogger(__name__) @@ -59,8 +58,8 @@ def cache_url(url_or_file, cache_dir): # 'bucket: {}').format(_DETECTRON_S3_BASE_URL) # # cache_file_path = url.replace(_DETECTRON_S3_BASE_URL, cache_dir) - Len_filename = len(url.split('/')[-1]) - BASE_URL = url[0:-Len_filename-1] + Len_filename = len(url.split('/')[-1]) + BASE_URL = url[0:-Len_filename - 1] # cache_file_path = url.replace(BASE_URL, cache_dir) if os.path.exists(cache_file_path): @@ -102,18 +101,13 @@ def _progress_bar(count, total): percents = round(100.0 * count / float(total), 1) bar = '=' * filled_len + '-' * (bar_len - filled_len) - sys.stdout.write( - ' [{}] {}% of {:.1f}MB file \r'. - format(bar, percents, total / 1024 / 1024) - ) + sys.stdout.write(' [{}] {}% of {:.1f}MB file \r'.format(bar, percents, total / 1024 / 1024)) sys.stdout.flush() if count >= total: sys.stdout.write('\n') -def download_url( - url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar -): +def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar): """Download url and write it to dst_file_path. 
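# Illustrative sketch (not part of the patch): the unnormalised Gaussian blob that
# generate_heatmap above writes around each visible joint (peak value 1, support of roughly
# 3 * sigma pixels on each side).
import torch

sigma = 1
size = 2 * sigma * 3 + 1                        # 7 x 7 window
x = torch.arange(0, size, dtype=torch.float32)
y = x.unsqueeze(-1)
x0 = y0 = size // 2
g = torch.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
assert g[y0, x0] == 1.0                         # centre equals 1, as in the hunk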
Credit: https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook diff --git a/lib/pymafx/utils/iuvmap.py b/lib/pymafx/utils/iuvmap.py index 7da02e6cdc1e656493f7566080f3108ca38cac87..7f7c25398e04e30b2b244d44badc83415d583852 100644 --- a/lib/pymafx/utils/iuvmap.py +++ b/lib/pymafx/utils/iuvmap.py @@ -9,11 +9,13 @@ def iuvmap_clean(U_uv, V_uv, Index_UV, AnnIndex=None): recon_Index_UV = [] for i in range(Index_UV.size(1)): if i == 0: - recon_Index_UV_i = torch.min(F.threshold(Index_UV_max + 1, 0.5, 0), - -F.threshold(-Index_UV_max - 1, -1.5, 0)) + recon_Index_UV_i = torch.min( + F.threshold(Index_UV_max + 1, 0.5, 0), -F.threshold(-Index_UV_max - 1, -1.5, 0) + ) else: - recon_Index_UV_i = torch.min(F.threshold(Index_UV_max, i - 0.5, 0), - -F.threshold(-Index_UV_max, -i - 0.5, 0)) / float(i) + recon_Index_UV_i = torch.min( + F.threshold(Index_UV_max, i - 0.5, 0), -F.threshold(-Index_UV_max, -i - 0.5, 0) + ) / float(i) recon_Index_UV.append(recon_Index_UV_i) recon_Index_UV = torch.stack(recon_Index_UV, dim=1) @@ -24,11 +26,13 @@ def iuvmap_clean(U_uv, V_uv, Index_UV, AnnIndex=None): recon_Ann_Index = [] for i in range(AnnIndex.size(1)): if i == 0: - recon_Ann_Index_i = torch.min(F.threshold(AnnIndex_max + 1, 0.5, 0), - -F.threshold(-AnnIndex_max - 1, -1.5, 0)) + recon_Ann_Index_i = torch.min( + F.threshold(AnnIndex_max + 1, 0.5, 0), -F.threshold(-AnnIndex_max - 1, -1.5, 0) + ) else: - recon_Ann_Index_i = torch.min(F.threshold(AnnIndex_max, i - 0.5, 0), - -F.threshold(-AnnIndex_max, -i - 0.5, 0)) / float(i) + recon_Ann_Index_i = torch.min( + F.threshold(AnnIndex_max, i - 0.5, 0), -F.threshold(-AnnIndex_max, -i - 0.5, 0) + ) / float(i) recon_Ann_Index.append(recon_Ann_Index_i) recon_Ann_Index = torch.stack(recon_Ann_Index, dim=1) @@ -66,8 +70,10 @@ def iuv_map2img(U_uv, V_uv, Index_UV, AnnIndex=None, uv_rois=None, ind_mapping=N for part_id in range(0, K): CurrentU = U_uv[batch_id, part_id] CurrentV = V_uv[batch_id, part_id] - output[1, Index_UV_max[batch_id] == part_id] = CurrentU[Index_UV_max[batch_id] == part_id] - output[2, Index_UV_max[batch_id] == part_id] = CurrentV[Index_UV_max[batch_id] == part_id] + output[1, + Index_UV_max[batch_id] == part_id] = CurrentU[Index_UV_max[batch_id] == part_id] + output[2, + Index_UV_max[batch_id] == part_id] = CurrentV[Index_UV_max[batch_id] == part_id] if uv_rois is None: outputs.append(output.unsqueeze(0)) @@ -88,12 +94,16 @@ def iuv_map2img(U_uv, V_uv, Index_UV, AnnIndex=None, uv_rois=None, ind_mapping=N new_size = [heatmap_size, max(int(heatmap_size * aspect_ratio), 1)] output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest') paddingleft = int(0.5 * (heatmap_size - new_size[1])) - output = F.pad(output, pad=(paddingleft, heatmap_size - new_size[1] - paddingleft, 0, 0)) + output = F.pad( + output, pad=(paddingleft, heatmap_size - new_size[1] - paddingleft, 0, 0) + ) else: new_size = [max(int(heatmap_size / aspect_ratio), 1), heatmap_size] output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest') paddingtop = int(0.5 * (heatmap_size - new_size[0])) - output = F.pad(output, pad=(0, 0, paddingtop, heatmap_size - new_size[0] - paddingtop)) + output = F.pad( + output, pad=(0, 0, paddingtop, heatmap_size - new_size[0] - paddingtop) + ) outputs.append(output) @@ -105,8 +115,10 @@ def iuv_img2map(uvimages, uv_rois=None, new_size=None, n_part=24): batch_size = uvimages.size(0) uvimg_size = uvimages.size(-1) - Index2mask = [[0], [1, 2], [3], [4], [5], [6], [7, 9], [8, 10], [11, 13], [12, 14], [15, 17], [16, 18], 
[19, 21], - [20, 22], [23, 24]] + Index2mask = [ + [0], [1, 2], [3], [4], [5], [6], [7, 9], [8, 10], [11, 13], [12, 14], [15, 17], [16, 18], + [19, 21], [20, 22], [23, 24] + ] part_ind = torch.round(uvimages[:, 0, :, :] * n_part) part_u = uvimages[:, 1, :, :] @@ -117,12 +129,15 @@ def iuv_img2map(uvimages, uv_rois=None, new_size=None, n_part=24): recon_Index_UV = [] recon_Ann_Index = [] - for i in range(n_part+1): + for i in range(n_part + 1): if i == 0: - recon_Index_UV_i = torch.min(F.threshold(part_ind + 1, 0.5, 0), -F.threshold(-part_ind - 1, -1.5, 0)) + recon_Index_UV_i = torch.min( + F.threshold(part_ind + 1, 0.5, 0), -F.threshold(-part_ind - 1, -1.5, 0) + ) else: - recon_Index_UV_i = torch.min(F.threshold(part_ind, i - 0.5, 0), - -F.threshold(-part_ind, -i - 0.5, 0)) / float(i) + recon_Index_UV_i = torch.min( + F.threshold(part_ind, i - 0.5, 0), -F.threshold(-part_ind, -i - 0.5, 0) + ) / float(i) recon_U_i = recon_Index_UV_i * part_u recon_V_i = recon_Index_UV_i * part_v @@ -192,8 +207,12 @@ def iuv_img2map(uvimages, uv_rois=None, new_size=None, n_part=24): recon_U_roi_i = F.interpolate(recon_U_roi_i.unsqueeze(0), size=(M, M), mode='nearest') recon_V_roi_i = F.interpolate(recon_V_roi_i.unsqueeze(0), size=(M, M), mode='nearest') - recon_Index_UV_roi_i = F.interpolate(recon_Index_UV_roi_i.unsqueeze(0), size=(M, M), mode='nearest') - recon_Ann_Index_roi_i = F.interpolate(recon_Ann_Index_roi_i.unsqueeze(0), size=(M, M), mode='nearest') + recon_Index_UV_roi_i = F.interpolate( + recon_Index_UV_roi_i.unsqueeze(0), size=(M, M), mode='nearest' + ) + recon_Ann_Index_roi_i = F.interpolate( + recon_Ann_Index_roi_i.unsqueeze(0), size=(M, M), mode='nearest' + ) recon_U_roi.append(recon_U_roi_i) recon_V_roi.append(recon_V_roi_i) @@ -217,12 +236,15 @@ def seg_img2map(segimages, uv_rois=None, new_size=None, n_part=24): recon_Index_UV = [] - for i in range(n_part+1): + for i in range(n_part + 1): if i == 0: - recon_Index_UV_i = torch.min(F.threshold(part_ind + 1, 0.5, 0), -F.threshold(-part_ind - 1, -1.5, 0)) + recon_Index_UV_i = torch.min( + F.threshold(part_ind + 1, 0.5, 0), -F.threshold(-part_ind - 1, -1.5, 0) + ) else: - recon_Index_UV_i = torch.min(F.threshold(part_ind, i - 0.5, 0), - -F.threshold(-part_ind, -i - 0.5, 0)) / float(i) + recon_Index_UV_i = torch.min( + F.threshold(part_ind, i - 0.5, 0), -F.threshold(-part_ind, -i - 0.5, 0) + ) / float(i) recon_Index_UV.append(recon_Index_UV_i) @@ -262,7 +284,9 @@ def seg_img2map(segimages, uv_rois=None, new_size=None, n_part=24): recon_Index_UV_roi_i = recon_Index_UV[i, :, h_margin:h_margin + h_size, :] - recon_Index_UV_roi_i = F.interpolate(recon_Index_UV_roi_i.unsqueeze(0), size=(M, M), mode='nearest') + recon_Index_UV_roi_i = F.interpolate( + recon_Index_UV_roi_i.unsqueeze(0), size=(M, M), mode='nearest' + ) recon_Index_UV_roi.append(recon_Index_UV_roi_i) diff --git a/lib/pymafx/utils/keypoints.py b/lib/pymafx/utils/keypoints.py index b505616e14436bcfecdaf3b65a18255ce98b86b0..2ab223c2bef79518adc523da1606cfc331ef8251 100644 --- a/lib/pymafx/utils/keypoints.py +++ b/lib/pymafx/utils/keypoints.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. ############################################################################## - """Keypoint utilities (somewhat specific to COCO keypoints).""" from __future__ import absolute_import @@ -35,23 +34,9 @@ def get_keypoints(): # Keypoints are not available in the COCO json for the test split, so we # provide them here. 
keypoints = [ - 'nose', - 'left_eye', - 'right_eye', - 'left_ear', - 'right_ear', - 'left_shoulder', - 'right_shoulder', - 'left_elbow', - 'right_elbow', - 'left_wrist', - 'right_wrist', - 'left_hip', - 'right_hip', - 'left_knee', - 'right_knee', - 'left_ankle', - 'right_ankle' + 'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear', 'left_shoulder', 'right_shoulder', + 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist', 'left_hip', 'right_hip', + 'left_knee', 'right_knee', 'left_ankle', 'right_ankle' ] keypoint_flip_map = { 'left_eye': 'right_eye', @@ -126,8 +111,7 @@ def heatmaps_to_keypoints(maps, rois): # NCHW to NHWC for use with OpenCV maps = np.transpose(maps, [0, 2, 3, 1]) min_size = cfg.KRCNN.INFERENCE_MIN_SIZE - xy_preds = np.zeros( - (len(rois), 4, cfg.KRCNN.NUM_KEYPOINTS), dtype=np.float32) + xy_preds = np.zeros((len(rois), 4, cfg.KRCNN.NUM_KEYPOINTS), dtype=np.float32) for i in range(len(rois)): if min_size > 0: roi_map_width = int(np.maximum(widths_ceil[i], min_size)) @@ -138,8 +122,8 @@ def heatmaps_to_keypoints(maps, rois): width_correction = widths[i] / roi_map_width height_correction = heights[i] / roi_map_height roi_map = cv2.resize( - maps[i], (roi_map_width, roi_map_height), - interpolation=cv2.INTER_CUBIC) + maps[i], (roi_map_width, roi_map_height), interpolation=cv2.INTER_CUBIC + ) # Bring back to CHW roi_map = np.transpose(roi_map, [2, 0, 1]) roi_map_probs = scores_to_probs(roi_map.copy()) @@ -148,8 +132,7 @@ def heatmaps_to_keypoints(maps, rois): pos = roi_map[k, :, :].argmax() x_int = pos % w y_int = (pos - x_int) // w - assert (roi_map_probs[k, y_int, x_int] == - roi_map_probs[k, :, :].max()) + assert (roi_map_probs[k, y_int, x_int] == roi_map_probs[k, :, :].max()) x = (x_int + 0.5) * width_correction y = (y_int + 0.5) * height_correction xy_preds[i, 0, k] = x + offset_x[i] @@ -201,8 +184,8 @@ def keypoints_to_heatmap_labels(keypoints, rois): valid_loc = np.logical_and( np.logical_and(x >= 0, y >= 0), - np.logical_and( - x < cfg.KRCNN.HEATMAP_SIZE, y < cfg.KRCNN.HEATMAP_SIZE)) + np.logical_and(x < cfg.KRCNN.HEATMAP_SIZE, y < cfg.KRCNN.HEATMAP_SIZE) + ) valid = np.logical_and(valid_loc, vis) valid = valid.astype(np.int32) @@ -234,9 +217,7 @@ def nms_oks(kp_predictions, rois, thresh): while order.size > 0: i = order[0] keep.append(i) - ovr = compute_oks( - kp_predictions[i], rois[i], kp_predictions[order[1:]], - rois[order[1:]]) + ovr = compute_oks(kp_predictions[i], rois[i], kp_predictions[order[1:]], rois[order[1:]]) inds = np.where(ovr <= thresh)[0] order = order[inds + 1] @@ -251,9 +232,9 @@ def compute_oks(src_keypoints, src_roi, dst_keypoints, dst_roi): dst_roi: Nx4 """ - sigmas = np.array([ - .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, - .87, .89, .89]) / 10.0 + sigmas = np.array( + [.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89] + ) / 10.0 vars = (sigmas * 2)**2 # area @@ -313,9 +294,15 @@ def generate_3d_integral_preds_tensor(heatmaps, num_joints, x_dim, y_dim, z_dim) accu_z = heatmaps.sum(dim=3) accu_z = accu_z.sum(dim=3) - accu_x = accu_x * torch.cuda.comm.broadcast(torch.arange(x_dim, dtype=torch.float32), devices=[accu_x.device.index])[0] - accu_y = accu_y * torch.cuda.comm.broadcast(torch.arange(y_dim, dtype=torch.float32), devices=[accu_y.device.index])[0] - accu_z = accu_z * torch.cuda.comm.broadcast(torch.arange(z_dim, dtype=torch.float32), devices=[accu_z.device.index])[0] + accu_x = accu_x * torch.cuda.comm.broadcast( + torch.arange(x_dim, dtype=torch.float32), 
devices=[accu_x.device.index] + )[0] + accu_y = accu_y * torch.cuda.comm.broadcast( + torch.arange(y_dim, dtype=torch.float32), devices=[accu_y.device.index] + )[0] + accu_z = accu_z * torch.cuda.comm.broadcast( + torch.arange(z_dim, dtype=torch.float32), devices=[accu_z.device.index] + )[0] accu_x = accu_x.sum(dim=2, keepdim=True) accu_y = accu_y.sum(dim=2, keepdim=True) @@ -326,8 +313,12 @@ def generate_3d_integral_preds_tensor(heatmaps, num_joints, x_dim, y_dim, z_dim) accu_x = heatmaps.sum(dim=2) accu_y = heatmaps.sum(dim=3) - accu_x = accu_x * torch.cuda.comm.broadcast(torch.arange(x_dim, dtype=torch.float32), devices=[accu_x.device.index])[0] - accu_y = accu_y * torch.cuda.comm.broadcast(torch.arange(y_dim, dtype=torch.float32), devices=[accu_y.device.index])[0] + accu_x = accu_x * torch.cuda.comm.broadcast( + torch.arange(x_dim, dtype=torch.float32), devices=[accu_x.device.index] + )[0] + accu_y = accu_y * torch.cuda.comm.broadcast( + torch.arange(y_dim, dtype=torch.float32), devices=[accu_y.device.index] + )[0] accu_x = accu_x.sum(dim=2, keepdim=True) accu_y = accu_y.sum(dim=2, keepdim=True) @@ -347,14 +338,18 @@ def softmax_integral_tensor(preds, num_joints, hm_width, hm_height, hm_depth=Non # integrate heatmap into joint location if output_3d: - x, y, z = generate_3d_integral_preds_tensor(preds, num_joints, hm_width, hm_height, hm_depth) + x, y, z = generate_3d_integral_preds_tensor( + preds, num_joints, hm_width, hm_height, hm_depth + ) # x = x / float(hm_width) - 0.5 # y = y / float(hm_height) - 0.5 # z = z / float(hm_depth) - 0.5 preds = torch.cat((x, y, z), dim=2) # preds = preds.reshape((preds.shape[0], num_joints * 3)) else: - x, y, _ = generate_3d_integral_preds_tensor(preds, num_joints, hm_width, hm_height, z_dim=None) + x, y, _ = generate_3d_integral_preds_tensor( + preds, num_joints, hm_width, hm_height, z_dim=None + ) # x = x / float(hm_width) - 0.5 # y = y / float(hm_height) - 0.5 preds = torch.cat((x, y), dim=2) diff --git a/lib/pymafx/utils/mesh_generation.py b/lib/pymafx/utils/mesh_generation.py index 94943f2b2fb6dc33427418f406d65a1b96efafdd..2876209e7678d2906a84850208f6c288103d07c5 100644 --- a/lib/pymafx/utils/mesh_generation.py +++ b/lib/pymafx/utils/mesh_generation.py @@ -33,13 +33,21 @@ class Generator3D(object): size for refinement process (we added this functionality in this work) ''' - - def __init__(self, model, points_batch_size=100000, - threshold=0.5, refinement_step=0, device=None, - resolution0=16, upsampling_steps=3, - with_normals=False, padding=0.1, - simplify_nfaces=None, with_color=False, - refine_max_faces=10000): + def __init__( + self, + model, + points_batch_size=100000, + threshold=0.5, + refinement_step=0, + device=None, + resolution0=16, + upsampling_steps=3, + with_normals=False, + padding=0.1, + simplify_nfaces=None, + with_color=False, + refine_max_faces=10000 + ): self.model = model.to(device) self.points_batch_size = points_batch_size self.refinement_step = refinement_step @@ -68,8 +76,7 @@ class Generator3D(object): kwargs = {} c = self.model.encode_inputs(inputs) - mesh = self.generate_from_latent(c, stats_dict=stats_dict, - data=data, **kwargs) + mesh = self.generate_from_latent(c, stats_dict=stats_dict, data=data, **kwargs) return mesh, stats_dict @@ -95,8 +102,7 @@ class Generator3D(object): return meshes - def generate_pointcloud(self, mesh, data=None, n_points=2000000, - scale_back=True): + def generate_pointcloud(self, mesh, data=None, n_points=2000000, scale_back=True): ''' Generates a point cloud from the mesh. 
Args: @@ -117,8 +123,7 @@ class Generator3D(object): pcl_out = trimesh.Trimesh(vertices=pcl, process=False) return pcl_out - def generate_from_latent(self, c=None, pl=None, stats_dict={}, data=None, - **kwargs): + def generate_from_latent(self, c=None, pl=None, stats_dict={}, data=None, **kwargs): ''' Generates mesh from latent. Args: @@ -135,14 +140,11 @@ class Generator3D(object): # Shortcut if self.upsampling_steps == 0: nx = self.resolution0 - pointsf = box_size * make_3d_grid( - (-0.5,)*3, (0.5,)*3, (nx,)*3 - ) + pointsf = box_size * make_3d_grid((-0.5, ) * 3, (0.5, ) * 3, (nx, ) * 3) values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy() value_grid = values.reshape(nx, nx, nx) else: - mesh_extractor = MISE( - self.resolution0, self.upsampling_steps, threshold) + mesh_extractor = MISE(self.resolution0, self.upsampling_steps, threshold) points = mesh_extractor.query() @@ -153,8 +155,7 @@ class Generator3D(object): pointsf = 2 * pointsf / mesh_extractor.resolution pointsf = box_size * (pointsf - 1.0) # Evaluate model and update - values = self.eval_points( - pointsf, c, pl, **kwargs).cpu().numpy() + values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy() values = values.astype(np.float64) mesh_extractor.update(points, values) @@ -203,17 +204,15 @@ class Generator3D(object): threshold = np.log(self.threshold) - np.log(1. - self.threshold) # Make sure that mesh is watertight t0 = time.time() - occ_hat_padded = np.pad( - occ_hat, 1, 'constant', constant_values=-1e6) - vertices, triangles = libmcubes.marching_cubes( - occ_hat_padded, threshold) + occ_hat_padded = np.pad(occ_hat, 1, 'constant', constant_values=-1e6) + vertices, triangles = libmcubes.marching_cubes(occ_hat_padded, threshold) stats_dict['time (marching cubes)'] = time.time() - t0 # Strange behaviour in libmcubes: vertices are shifted by 0.5 vertices -= 0.5 # Undo padding vertices -= 1 # Normalize to bounding box - vertices /= np.array([n_x-1, n_y-1, n_z-1]) + vertices /= np.array([n_x - 1, n_y - 1, n_z - 1]) vertices *= 2 vertices = box_size * (vertices - 1) @@ -228,10 +227,13 @@ class Generator3D(object): else: normals = None # Create mesh - mesh = trimesh.Trimesh(vertices, triangles, - vertex_normals=normals, - # vertex_colors=vertex_colors, - process=False) + mesh = trimesh.Trimesh( + vertices, + triangles, + vertex_normals=normals, + # vertex_colors=vertex_colors, + process=False + ) # Directly return if mesh is empty if vertices.shape[0] == 0: @@ -255,9 +257,12 @@ class Generator3D(object): vertex_colors = self.estimate_colors(np.array(mesh.vertices), c) stats_dict['time (color)'] = time.time() - t0 mesh = trimesh.Trimesh( - vertices=mesh.vertices, faces=mesh.faces, + vertices=mesh.vertices, + faces=mesh.faces, vertex_normals=mesh.vertex_normals, - vertex_colors=vertex_colors, process=False) + vertex_colors=vertex_colors, + process=False + ) return mesh @@ -275,16 +280,15 @@ class Generator3D(object): for vi in vertices_split: vi = vi.to(device) with torch.no_grad(): - ci = self.model.decode_color( - vi.unsqueeze(0), c).squeeze(0).cpu() + ci = self.model.decode_color(vi.unsqueeze(0), c).squeeze(0).cpu() colors.append(ci) colors = np.concatenate(colors, axis=0) colors = np.clip(colors, 0, 1) colors = (colors * 255).astype(np.uint8) - colors = np.concatenate([ - colors, np.full((colors.shape[0], 1), 255, dtype=np.uint8)], - axis=1) + colors = np.concatenate( + [colors, np.full((colors.shape[0], 1), 255, dtype=np.uint8)], axis=1 + ) return colors def estimate_normals(self, vertices, c=None): @@ -328,7 
+332,7 @@ class Generator3D(object): # Some shorthands n_x, n_y, n_z = occ_hat.shape - assert(n_x == n_y == n_z) + assert (n_x == n_y == n_z) # threshold = np.log(self.threshold) - np.log(1. - self.threshold) threshold = self.threshold @@ -348,8 +352,7 @@ class Generator3D(object): # Dataset ds_faces = TensorDataset(faces) - dataloader = DataLoader(ds_faces, batch_size=self.refine_max_faces, - shuffle=True) + dataloader = DataLoader(ds_faces, batch_size=self.refine_max_faces, shuffle=True) # We updated the refinement algorithm to subsample faces; this is # usefull when using a high extraction resolution / when working on @@ -372,13 +375,16 @@ class Generator3D(object): face_normal = face_normal / \ (face_normal.norm(dim=1, keepdim=True) + 1e-10) - face_value = torch.cat([ - torch.sigmoid(self.model.decode(p_split, c).logits) - for p_split in torch.split( - face_point.unsqueeze(0), 20000, dim=1)], dim=1) + face_value = torch.cat( + [ + torch.sigmoid(self.model.decode(p_split, c).logits) + for p_split in torch.split(face_point.unsqueeze(0), 20000, dim=1) + ], + dim=1 + ) - normal_target = -autograd.grad( - [face_value.sum()], [face_point], create_graph=True)[0] + normal_target = -autograd.grad([face_value.sum()], [face_point], + create_graph=True)[0] normal_target = \ normal_target / \ diff --git a/lib/pymafx/utils/part_utils.py b/lib/pymafx/utils/part_utils.py index 88bcf06d713ee0cbf8d673c16e1acaed9313ccab..12f0de443fa11e90674761816a644cf82a48a786 100644 --- a/lib/pymafx/utils/part_utils.py +++ b/lib/pymafx/utils/part_utils.py @@ -5,6 +5,7 @@ from core import path_config from models import SMPL + class PartRenderer(): """Renderer used to render segmentation masks and part segmentations. Internally it uses the Neural 3D Mesh Renderer @@ -14,40 +15,57 @@ class PartRenderer(): self.focal_length = focal_length self.render_res = render_res # We use Neural 3D mesh renderer for rendering masks and part segmentations - self.neural_renderer = nr.Renderer(dist_coeffs=None, orig_size=self.render_res, - image_size=render_res, - light_intensity_ambient=1, - light_intensity_directional=0, - anti_aliasing=False) - self.faces = torch.from_numpy(SMPL(path_config.SMPL_MODEL_DIR).faces.astype(np.int32)).cuda() + self.neural_renderer = nr.Renderer( + dist_coeffs=None, + orig_size=self.render_res, + image_size=render_res, + light_intensity_ambient=1, + light_intensity_directional=0, + anti_aliasing=False + ) + self.faces = torch.from_numpy(SMPL(path_config.SMPL_MODEL_DIR).faces.astype(np.int32) + ).cuda() textures = np.load(path_config.VERTEX_TEXTURE_FILE) self.textures = torch.from_numpy(textures).cuda().float() self.cube_parts = torch.cuda.FloatTensor(np.load(path_config.CUBE_PARTS_FILE)) def get_parts(self, parts, mask): """Process renderer part image to get body part indices.""" - bn,c,h,w = parts.shape - mask = mask.view(-1,1) - parts_index = torch.floor(100*parts.permute(0,2,3,1).contiguous().view(-1,3)).long() - parts = self.cube_parts[parts_index[:,0], parts_index[:,1], parts_index[:,2], None] + bn, c, h, w = parts.shape + mask = mask.view(-1, 1) + parts_index = torch.floor(100 * parts.permute(0, 2, 3, 1).contiguous().view(-1, 3)).long() + parts = self.cube_parts[parts_index[:, 0], parts_index[:, 1], parts_index[:, 2], None] parts *= mask - parts = parts.view(bn,h,w).long() + parts = parts.view(bn, h, w).long() return parts def __call__(self, vertices, camera): """Wrapper function for rendering process.""" # Estimate camera parameters given a fixed focal length - cam_t = torch.stack([camera[:,1], 
camera[:,2], 2*self.focal_length/(self.render_res * camera[:,0] +1e-9)],dim=-1) + cam_t = torch.stack( + [ + camera[:, 1], camera[:, 2], 2 * self.focal_length / + (self.render_res * camera[:, 0] + 1e-9) + ], + dim=-1 + ) batch_size = vertices.shape[0] K = torch.eye(3, device=vertices.device) - K[0,0] = self.focal_length - K[1,1] = self.focal_length - K[2,2] = 1 - K[0,2] = self.render_res / 2. - K[1,2] = self.render_res / 2. + K[0, 0] = self.focal_length + K[1, 1] = self.focal_length + K[2, 2] = 1 + K[0, 2] = self.render_res / 2. + K[1, 2] = self.render_res / 2. K = K[None, :, :].expand(batch_size, -1, -1) R = torch.eye(3, device=vertices.device)[None, :, :].expand(batch_size, -1, -1) faces = self.faces[None, :, :].expand(batch_size, -1, -1) - parts, _, mask = self.neural_renderer(vertices, faces, textures=self.textures.expand(batch_size, -1, -1, -1, -1, -1), K=K, R=R, t=cam_t.unsqueeze(1)) + parts, _, mask = self.neural_renderer( + vertices, + faces, + textures=self.textures.expand(batch_size, -1, -1, -1, -1, -1), + K=K, + R=R, + t=cam_t.unsqueeze(1) + ) parts = self.get_parts(parts, mask) - return mask, parts \ No newline at end of file + return mask, parts diff --git a/lib/pymafx/utils/pose_tracker.py b/lib/pymafx/utils/pose_tracker.py index 5028bb5706b9f1e9e3ccf656734650434b8c82ab..92c383cdb3dba6053a0595b9f03305c02e9fc277 100644 --- a/lib/pymafx/utils/pose_tracker.py +++ b/lib/pymafx/utils/pose_tracker.py @@ -23,10 +23,10 @@ import os.path as osp def run_openpose( - video_file, - output_folder, - staf_folder, - vis=False, + video_file, + output_folder, + staf_folder, + vis=False, ): pwd = os.getcwd() @@ -35,13 +35,10 @@ def run_openpose( render = 1 if vis else 0 display = 2 if vis else 0 cmd = [ - 'build/examples/openpose/openpose.bin', - '--model_pose', 'BODY_21A', - '--tracking', '1', - '--render_pose', str(render), - '--video', video_file, - '--write_json', output_folder, - '--display', str(display) + 'build/examples/openpose/openpose.bin', '--model_pose', 'BODY_21A', '--tracking', '1', + '--render_pose', + str(render), '--video', video_file, '--write_json', output_folder, '--display', + str(display) ] print('Executing', ' '.join(cmd)) @@ -59,7 +56,7 @@ def read_posetrack_keypoints(output_folder): # print(idx, data) for person in data['people']: person_id = person['person_id'][0] - joints2d = person['pose_keypoints_2d'] + joints2d = person['pose_keypoints_2d'] if person_id in people.keys(): people[person_id]['joints2d'].append(joints2d) people[person_id]['frames'].append(idx) @@ -72,7 +69,9 @@ def read_posetrack_keypoints(output_folder): people[person_id]['frames'].append(idx) for k in people.keys(): - people[k]['joints2d'] = np.array(people[k]['joints2d']).reshape((len(people[k]['joints2d']), -1, 3)) + people[k]['joints2d'] = np.array(people[k]['joints2d']).reshape( + (len(people[k]['joints2d']), -1, 3) + ) people[k]['frames'] = np.array(people[k]['frames']) return people @@ -80,20 +79,14 @@ def read_posetrack_keypoints(output_folder): def run_posetracker(video_file, staf_folder, posetrack_output_folder='/tmp', display=False): posetrack_output_folder = os.path.join( - posetrack_output_folder, - f'{os.path.basename(video_file)}_posetrack' + posetrack_output_folder, f'{os.path.basename(video_file)}_posetrack' ) # run posetrack on video - run_openpose( - video_file, - posetrack_output_folder, - vis=display, - staf_folder=staf_folder - ) + run_openpose(video_file, posetrack_output_folder, vis=display, staf_folder=staf_folder) people_dict = 
read_posetrack_keypoints(posetrack_output_folder) shutil.rmtree(posetrack_output_folder) - return people_dict \ No newline at end of file + return people_dict diff --git a/lib/pymafx/utils/pose_utils.py b/lib/pymafx/utils/pose_utils.py index f74bfd6668cb6214e4414cab095c00aee26e7314..55eb1d771376da71c864a715d1dd6b5d66e9894e 100644 --- a/lib/pymafx/utils/pose_utils.py +++ b/lib/pymafx/utils/pose_utils.py @@ -7,6 +7,7 @@ from __future__ import print_function import numpy as np import torch + def compute_similarity_transform(S1, S2): """ Computes a similarity transform (sR, t) that takes @@ -19,7 +20,7 @@ def compute_similarity_transform(S1, S2): S1 = S1.T S2 = S2.T transposed = True - assert(S2.shape[1] == S1.shape[1]) + assert (S2.shape[1] == S1.shape[1]) # 1. Remove mean. mu1 = S1.mean(axis=1, keepdims=True) @@ -47,16 +48,17 @@ def compute_similarity_transform(S1, S2): scale = np.trace(R.dot(K)) / var1 # 6. Recover translation. - t = mu2 - scale*(R.dot(mu1)) + t = mu2 - scale * (R.dot(mu1)) # 7. Error: - S1_hat = scale*R.dot(S1) + t + S1_hat = scale * R.dot(S1) + t if transposed: S1_hat = S1_hat.T return S1_hat + def compute_similarity_transform_batch(S1, S2): """Batched version of compute_similarity_transform.""" S1_hat = np.zeros_like(S1) @@ -64,10 +66,11 @@ def compute_similarity_transform_batch(S1, S2): S1_hat[i] = compute_similarity_transform(S1[i], S2[i]) return S1_hat + def reconstruction_error(S1, S2, reduction='mean'): """Do Procrustes alignment and compute reconstruction error.""" S1_hat = compute_similarity_transform_batch(S1, S2) - re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1) + re = np.sqrt(((S1_hat - S2)**2).sum(axis=-1)).mean(axis=-1) if reduction == 'mean': re = re.mean() elif reduction == 'sum': @@ -113,6 +116,7 @@ def axis_angle_add(theta, roll_axis, alpha): return c_n + def axis_angle_add_np(theta, roll_axis, alpha): """Composition of two axis-angle rotations (NumPy version) Args: @@ -145,4 +149,4 @@ def axis_angle_add_np(theta, roll_axis, alpha): c_sin = np.sin(c_angle * 0.5) c_n = (c_angle / c_sin) * c_sin_n - return c_n \ No newline at end of file + return c_n diff --git a/lib/pymafx/utils/renderer.py b/lib/pymafx/utils/renderer.py index 032deb76ef2690cdb046e67ff5d1680741dfab3a..9fb19568680b839f93c00a5288c94a5a52025242 100644 --- a/lib/pymafx/utils/renderer.py +++ b/lib/pymafx/utils/renderer.py @@ -34,33 +34,20 @@ from pytorch3d.structures.meshes import Meshes # from pytorch3d.renderer.mesh.renderer import MeshRendererWithFragments from pytorch3d.renderer import ( - look_at_view_transform, - FoVPerspectiveCameras, - PerspectiveCameras, - AmbientLights, - PointLights, - RasterizationSettings, - BlendParams, - MeshRenderer, - MeshRasterizer, - SoftPhongShader, - SoftSilhouetteShader, - HardPhongShader, - HardGouraudShader, - HardFlatShader, - TexturesVertex + look_at_view_transform, FoVPerspectiveCameras, PerspectiveCameras, AmbientLights, PointLights, + RasterizationSettings, BlendParams, MeshRenderer, MeshRasterizer, SoftPhongShader, + SoftSilhouetteShader, HardPhongShader, HardGouraudShader, HardFlatShader, TexturesVertex ) import logging + logger = logging.getLogger(__name__) + class WeakPerspectiveCamera(pyrender.Camera): - def __init__(self, - scale, - translation, - znear=pyrender.camera.DEFAULT_Z_NEAR, - zfar=None, - name=None): + def __init__( + self, scale, translation, znear=pyrender.camera.DEFAULT_Z_NEAR, zfar=None, name=None + ): super(WeakPerspectiveCamera, self).__init__( znear=znear, zfar=zfar, @@ -80,21 +67,22 @@ class 
WeakPerspectiveCamera(pyrender.Camera): class PyRenderer: - def __init__(self, resolution=(224,224), orig_img=False, wireframe=False, scale_ratio=1., vis_ratio=1.): + def __init__( + self, resolution=(224, 224), orig_img=False, wireframe=False, scale_ratio=1., vis_ratio=1. + ): self.resolution = (resolution[0] * scale_ratio, resolution[1] * scale_ratio) # self.scale_ratio = scale_ratio - self.faces = {'smplx': get_model_faces('smplx'), - 'smpl': get_model_faces('smpl'), - # 'mano': get_model_faces('mano'), - # 'flame': get_model_faces('flame'), - } + self.faces = { + 'smplx': get_model_faces('smplx'), + 'smpl': get_model_faces('smpl'), + # 'mano': get_model_faces('mano'), + # 'flame': get_model_faces('flame'), + } self.orig_img = orig_img self.wireframe = wireframe self.renderer = pyrender.OffscreenRenderer( - viewport_width=self.resolution[0], - viewport_height=self.resolution[1], - point_size=1.0 + viewport_width=self.resolution[0], viewport_height=self.resolution[1], point_size=1.0 ) self.vis_ratio = vis_ratio @@ -104,7 +92,7 @@ class PyRenderer: light = pyrender.PointLight(color=np.array([1.0, 1.0, 1.0]) * 0.2, intensity=1) - yrot = np.radians(120) # angle of lights + yrot = np.radians(120) # angle of lights light_pose = np.eye(4) light_pose[:3, 3] = [0, -1, 1] @@ -116,8 +104,9 @@ class PyRenderer: light_pose[:3, 3] = [1, 1, 2] self.scene.add(light, pose=light_pose) - spot_l = pyrender.SpotLight(color=np.ones(3), intensity=15.0, - innerConeAngle=np.pi/3, outerConeAngle=np.pi/2) + spot_l = pyrender.SpotLight( + color=np.ones(3), intensity=15.0, innerConeAngle=np.pi / 3, outerConeAngle=np.pi / 2 + ) light_pose[:3, 3] = [1, 2, 2] self.scene.add(spot_l, pose=light_pose) @@ -135,17 +124,34 @@ class PyRenderer: 'red': np.array([0.5, 0.2, 0.2]), 'pink': np.array([0.7, 0.5, 0.5]), 'neutral': np.array([0.7, 0.7, 0.6]), - # 'purple': np.array([0.5, 0.5, 0.7]), + # 'purple': np.array([0.5, 0.5, 0.7]), 'purple': np.array([0.55, 0.4, 0.9]), 'green': np.array([0.5, 0.55, 0.3]), 'sky': np.array([0.3, 0.5, 0.55]), 'white': np.array([1.0, 0.98, 0.94]), } - def __call__(self, verts, faces=None, img=np.zeros((224, 224, 3)), cam=np.array([1, 0, 0]), - focal_length=[5000, 5000], camera_rotation=np.eye(3), crop_info=None, - angle=None, axis=None, mesh_filename=None, color_type=None, color=[1.0, 1.0, 0.9], iwp_mode=True, crop_img=True, mesh_type='smpl', scale_ratio=1., rgba_mode=False): - + def __call__( + self, + verts, + faces=None, + img=np.zeros((224, 224, 3)), + cam=np.array([1, 0, 0]), + focal_length=[5000, 5000], + camera_rotation=np.eye(3), + crop_info=None, + angle=None, + axis=None, + mesh_filename=None, + color_type=None, + color=[1.0, 1.0, 0.9], + iwp_mode=True, + crop_img=True, + mesh_type='smpl', + scale_ratio=1., + rgba_mode=False + ): + if faces is None: faces = self.faces[mesh_type] mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False) @@ -166,24 +172,28 @@ class PyRenderer: if len(cam) == 4: sx, sy, tx, ty = cam # sy = sx - camera_translation = np.array([tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)]) + camera_translation = np.array( + [tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)] + ) elif len(cam) == 3: sx, tx, ty = cam sy = sx - camera_translation = np.array([- tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)]) + camera_translation = np.array( + [-tx, ty, 2 * focal_length[0] / (resolution[0] * sy + 1e-9)] + ) render_res = resolution self.renderer.viewport_width = render_res[1] self.renderer.viewport_height = render_res[0] else: if 
crop_info['opt_cam_t'] is None: camera_translation = convert_to_full_img_cam( - pare_cam=cam[None], - bbox_height=crop_info['bbox_scale'] * 200., - bbox_center=crop_info['bbox_center'], - img_w=crop_info['img_w'], - img_h=crop_info['img_h'], - focal_length=focal_length[0], - ) + pare_cam=cam[None], + bbox_height=crop_info['bbox_scale'] * 200., + bbox_center=crop_info['bbox_center'], + img_w=crop_info['img_w'], + img_h=crop_info['img_h'], + focal_length=focal_length[0], + ) else: camera_translation = crop_info['opt_cam_t'] if torch.is_tensor(camera_translation): @@ -197,8 +207,9 @@ class PyRenderer: self.renderer.viewport_width = render_res[1] self.renderer.viewport_height = render_res[0] camera_rotation = camera_rotation.T - camera = pyrender.IntrinsicsCamera(fx=focal_length[0], fy=focal_length[1], - cx=render_res[1]/2., cy=render_res[0]/2.) + camera = pyrender.IntrinsicsCamera( + fx=focal_length[0], fy=focal_length[1], cx=render_res[1] / 2., cy=render_res[0] / 2. + ) if color_type != None: color = self.colors_dict[color_type] @@ -237,9 +248,14 @@ class PyRenderer: for item in image_list: if scale_ratio != 1: orig_size = item.shape[:2] - item = resize(item, (orig_size[0] * scale_ratio, orig_size[1] * scale_ratio), anti_aliasing=True) + item = resize( + item, (orig_size[0] * scale_ratio, orig_size[1] * scale_ratio), + anti_aliasing=True + ) item = (item * 255).astype(np.uint8) - output_img = rgb[:, :, :-1] * valid_mask * self.vis_ratio + (1 - valid_mask * self.vis_ratio) * item + output_img = rgb[:, :, :-1] * valid_mask * self.vis_ratio + ( + 1 - valid_mask * self.vis_ratio + ) * item # output_img[valid_mask < 0.5] = item[valid_mask < 0.5] # if scale_ratio != 1: # output_img = resize(output_img, (orig_size[0], orig_size[1]), anti_aliasing=True) @@ -253,7 +269,7 @@ class PyRenderer: return_img.append(item) if type(img) is not list: - # if scale_ratio == 1: + # if scale_ratio == 1: return_img = return_img[0] self.scene.remove_node(mesh_node) @@ -267,9 +283,12 @@ class OpenDRenderer: self.resolution = (resolution[0] * ratio, resolution[1] * ratio) self.ratio = ratio self.focal_length = 5000. - self.K = np.array([[self.focal_length, 0., self.resolution[1] / 2.], - [0., self.focal_length, self.resolution[0] / 2.], - [0., 0., 1.]]) + self.K = np.array( + [ + [self.focal_length, 0., self.resolution[1] / 2.], + [0., self.focal_length, self.resolution[0] / 2.], [0., 0., 1.] + ] + ) self.colors_dict = { 'red': np.array([0.5, 0.2, 0.2]), 'pink': np.array([0.7, 0.5, 0.5]), @@ -281,16 +300,29 @@ class OpenDRenderer: } self.renderer = ColoredRenderer() self.faces = get_smpl_faces() - + def reset_res(self, resolution): self.resolution = (resolution[0] * self.ratio, resolution[1] * self.ratio) - self.K = np.array([[self.focal_length, 0., self.resolution[1] / 2.], - [0., self.focal_length, self.resolution[0] / 2.], - [0., 0., 1.]]) + self.K = np.array( + [ + [self.focal_length, 0., self.resolution[1] / 2.], + [0., self.focal_length, self.resolution[0] / 2.], [0., 0., 1.] 
+ ] + ) - def __call__(self, verts, faces=None, color=None, color_type='white', R=None, mesh_filename=None, - img=np.zeros((224, 224, 3)), cam=np.array([1, 0, 0]), - rgba=False, addlight=True): + def __call__( + self, + verts, + faces=None, + color=None, + color_type='white', + R=None, + mesh_filename=None, + img=np.zeros((224, 224, 3)), + cam=np.array([1, 0, 0]), + rgba=False, + addlight=True + ): '''Render mesh using OpenDR verts: shape - (V, 3) faces: shape - (F, 3) @@ -307,18 +339,18 @@ class OpenDRenderer: f = np.array([K[0, 0], K[1, 1]]) c = np.array([K[0, 2], K[1, 2]]) - + if faces is None: faces = self.faces if len(cam) == 4: t = np.array([cam[2], cam[3], 2 * K[0, 0] / (w * cam[0] + 1e-9)]) elif len(cam) == 3: t = np.array([cam[1], cam[2], 2 * K[0, 0] / (w * cam[0] + 1e-9)]) - + rn.camera = ProjectPoints(rt=np.array([0, 0, 0]), t=t, f=f, c=c, k=np.zeros(5)) rn.frustum = {'near': 1., 'far': 1000., 'width': w, 'height': h} - albedo = np.ones_like(verts)*.9 + albedo = np.ones_like(verts) * .9 if color is not None: color0 = np.array(color) @@ -343,7 +375,7 @@ class OpenDRenderer: rn.set(v=verts, f=faces, vc=color, bgcolor=np.zeros(3)) if addlight: - yrot = np.radians(120) # angle of lights + yrot = np.radians(120) # angle of lights # # 1. 1. 0.7 rn.vc = LambertianPointLight( f=rn.f, @@ -351,7 +383,8 @@ class OpenDRenderer: num_verts=len(rn.v), light_pos=rotateY(np.array([-200, -100, -100]), yrot), vc=albedo, - light_color=color0) + light_color=color0 + ) # Construct Left Light rn.vc += LambertianPointLight( @@ -360,7 +393,8 @@ class OpenDRenderer: num_verts=len(rn.v), light_pos=rotateY(np.array([800, 10, 300]), yrot), vc=albedo, - light_color=color1) + light_color=color1 + ) # Construct Right Light rn.vc += LambertianPointLight( @@ -369,7 +403,8 @@ class OpenDRenderer: num_verts=len(rn.v), light_pos=rotateY(np.array([-500, 500, 1000]), yrot), vc=albedo, - light_color=color2) + light_color=color2 + ) rendered_image = rn.r visibility_image = rn.visibility_image @@ -379,12 +414,16 @@ class OpenDRenderer: return_img = [] for item in image_list: if self.ratio != 1: - img_resized = resize(item, (item.shape[0] * self.ratio, item.shape[1] * self.ratio), anti_aliasing=True) + img_resized = resize( + item, (item.shape[0] * self.ratio, item.shape[1] * self.ratio), + anti_aliasing=True + ) else: img_resized = item / 255. try: - img_resized[visibility_image != (2**32 - 1)] = rendered_image[visibility_image != (2**32 - 1)] + img_resized[visibility_image != (2**32 - 1) + ] = rendered_image[visibility_image != (2**32 - 1)] except: logger.warning('Can not render mesh.') @@ -407,34 +446,40 @@ class OpenDRenderer: # https://github.com/classner/up/blob/master/up_tools/camera.py def rotateY(points, angle): """Rotate all points in a 2D array around the y axis.""" - ry = np.array([ - [np.cos(angle), 0., np.sin(angle)], - [0., 1., 0. ], - [-np.sin(angle), 0., np.cos(angle)] - ]) + ry = np.array( + [[np.cos(angle), 0., np.sin(angle)], [0., 1., 0.], [-np.sin(angle), 0., + np.cos(angle)]] + ) return np.dot(points, ry) -def rotateX( points, angle ): + +def rotateX(points, angle): """Rotate all points in a 2D array around the x axis.""" - rx = np.array([ - [1., 0., 0. 
], - [0., np.cos(angle), -np.sin(angle)], - [0., np.sin(angle), np.cos(angle) ] - ]) + rx = np.array( + [[1., 0., 0.], [0., np.cos(angle), -np.sin(angle)], [0., np.sin(angle), + np.cos(angle)]] + ) return np.dot(points, rx) -def rotateZ( points, angle ): + +def rotateZ(points, angle): """Rotate all points in a 2D array around the z axis.""" - rz = np.array([ - [np.cos(angle), -np.sin(angle), 0. ], - [np.sin(angle), np.cos(angle), 0. ], - [0., 0., 1. ] - ]) + rz = np.array( + [[np.cos(angle), -np.sin(angle), 0.], [np.sin(angle), np.cos(angle), 0.], [0., 0., 1.]] + ) return np.dot(points, rz) class IUV_Renderer(object): - def __init__(self, focal_length=5000., orig_size=224, output_size=56, mode='iuv', device=torch.device('cuda'), mesh_type='smpl'): + def __init__( + self, + focal_length=5000., + orig_size=224, + output_size=56, + mode='iuv', + device=torch.device('cuda'), + mesh_type='smpl' + ): self.focal_length = focal_length self.orig_size = orig_size @@ -449,7 +494,9 @@ class IUV_Renderer(object): faces = DP.FacesDensePose faces = faces[None, :, :] - self.faces = torch.from_numpy(faces.astype(np.int32)) # [1, 13774, 3], torch.int32 + self.faces = torch.from_numpy( + faces.astype(np.int32) + ) # [1, 13774, 3], torch.int32 num_part = float(np.max(DP.FaceIndices)) self.num_part = num_part @@ -468,13 +515,22 @@ class IUV_Renderer(object): np.save(dp_vert_pid_fname, np.array(dp_vert_pid)) textures_vts = np.array( - [(dp_vert_pid[i] / num_part, DP.U_norm[i], DP.V_norm[i]) for i in - range(len(vert_mapping))]) - self.textures_vts = torch.from_numpy(textures_vts[None].astype(np.float32)) # (1, 7829, 3) + [ + (dp_vert_pid[i] / num_part, DP.U_norm[i], DP.V_norm[i]) + for i in range(len(vert_mapping)) + ] + ) + self.textures_vts = torch.from_numpy( + textures_vts[None].astype(np.float32) + ) # (1, 7829, 3) elif mode == 'pncc': self.vert_mapping = None - self.faces = torch.from_numpy(get_model_faces(mesh_type)[None].astype(np.int32)) # mano: torch.Size([1, 1538, 3]) - textures_vts = get_model_tpose(mesh_type).unsqueeze(0) # mano: torch.Size([1, 778, 3]) + self.faces = torch.from_numpy( + get_model_faces(mesh_type)[None].astype(np.int32) + ) # mano: torch.Size([1, 1538, 3]) + textures_vts = get_model_tpose(mesh_type).unsqueeze( + 0 + ) # mano: torch.Size([1, 778, 3]) texture_min = torch.min(textures_vts) - 0.001 texture_range = torch.max(textures_vts) - texture_min + 0.001 @@ -485,7 +541,11 @@ class IUV_Renderer(object): self.faces = torch.from_numpy(get_smpl_faces().astype(np.int32)[None]) - with open(os.path.join(path_config.SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)), 'rb') as json_file: + with open( + os.path.join( + path_config.SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model) + ), 'rb' + ) as json_file: smpl_part_id = json.load(json_file) v_id = [] @@ -509,9 +569,12 @@ class IUV_Renderer(object): # range(n_verts)]) self.textures_vts = torch.from_numpy(textures_vts[None].astype(np.float32)) - K = np.array([[self.focal_length, 0., self.orig_size / 2.], - [0., self.focal_length, self.orig_size / 2.], - [0., 0., 1.]]) + K = np.array( + [ + [self.focal_length, 0., self.orig_size / 2.], + [0., self.focal_length, self.orig_size / 2.], [0., 0., 1.] 
+ ] + ) R = np.array([[-1., 0., 0.], [0., -1., 0.], [0., 0., 1.]]) @@ -540,26 +603,27 @@ class IUV_Renderer(object): raster_settings = RasterizationSettings( image_size=output_size, - blur_radius=0, + blur_radius=0, faces_per_pixel=1, ) self.renderer = MeshRenderer( - rasterizer=MeshRasterizer( - raster_settings=raster_settings - ), - shader=HardFlatShader( - device=self.device, - lights=lights, - blend_params=BlendParams(background_color=[0, 0, 0], sigma=0.0, gamma=0.0) - ) + rasterizer=MeshRasterizer(raster_settings=raster_settings), + shader=HardFlatShader( + device=self.device, + lights=lights, + blend_params=BlendParams(background_color=[0, 0, 0], sigma=0.0, gamma=0.0) ) + ) def camera_matrix(self, cam): batch_size = cam.size(0) K = self.K.repeat(batch_size, 1, 1) R = self.R.repeat(batch_size, 1, 1) - t = torch.stack([-cam[:, 1], -cam[:, 2], 2 * self.focal_length/(self.orig_size * cam[:, 0] + 1e-9)], dim=-1) + t = torch.stack( + [-cam[:, 1], -cam[:, 2], 2 * self.focal_length / (self.orig_size * cam[:, 0] + 1e-9)], + dim=-1 + ) if cam.is_cuda: # device_id = cam.get_device() @@ -580,9 +644,18 @@ class IUV_Renderer(object): vertices = verts[:, self.vert_mapping, :] mesh = Meshes(vertices, self.faces.to(verts.device).expand(batch_size, -1, -1)) - mesh.textures = TexturesVertex(verts_features=self.textures_vts.to(verts.device).expand(batch_size, -1, -1)) + mesh.textures = TexturesVertex( + verts_features=self.textures_vts.to(verts.device).expand(batch_size, -1, -1) + ) - cameras = PerspectiveCameras(device=verts.device, R=R, T=t, K=K, in_ndc=False, image_size=[(self.orig_size, self.orig_size)]) + cameras = PerspectiveCameras( + device=verts.device, + R=R, + T=t, + K=K, + in_ndc=False, + image_size=[(self.orig_size, self.orig_size)] + ) iuv_image = self.renderer(mesh, cameras=cameras) iuv_image = iuv_image[..., :3].permute(0, 3, 1, 2) diff --git a/lib/pymafx/utils/sample_mesh.py b/lib/pymafx/utils/sample_mesh.py index 9d8833cf0642394ba6ff9ba86bad11c278ebdd01..2599bee12d2577b6826ea8bfad8c937f2bcc2db2 100644 --- a/lib/pymafx/utils/sample_mesh.py +++ b/lib/pymafx/utils/sample_mesh.py @@ -3,7 +3,17 @@ import trimesh import numpy as np from .utils.libmesh import check_mesh_contains -def get_occ_gt(in_path=None, vertices=None, faces=None, pts_num=1000, points_sigma=0.01, with_dp=False, points=None, extra_points=None): + +def get_occ_gt( + in_path=None, + vertices=None, + faces=None, + pts_num=1000, + points_sigma=0.01, + with_dp=False, + points=None, + extra_points=None +): if in_path is not None: mesh = trimesh.load(in_path, process=False) print(type(mesh.vertices), mesh.vertices.shape, mesh.faces.shape) @@ -27,7 +37,7 @@ def get_occ_gt(in_path=None, vertices=None, faces=None, pts_num=1000, points_sig points_surface, index_surface = mesh.sample(n_points_surface, return_index=True) points_surface += points_sigma * np.random.randn(n_points_surface, 3) points = np.concatenate([points_uniform, points_surface], axis=0) - + if extra_points is not None: extra_points += points_sigma * np.random.randn(len(extra_points), 3) points = np.concatenate([points, extra_points], axis=0) diff --git a/lib/pymafx/utils/saver.py b/lib/pymafx/utils/saver.py index 417db9bc59684579a4f0d9778b8c5fd251a2d8f1..6a6bd3a184cc658dbc666ad2dcf3bc15d8cc427b 100644 --- a/lib/pymafx/utils/saver.py +++ b/lib/pymafx/utils/saver.py @@ -3,8 +3,10 @@ import os import torch import datetime import logging + logger = logging.getLogger(__name__) + class CheckpointSaver(): """Class that handles saving and loading checkpoints during 
training.""" def __init__(self, save_dir, save_steps=1000, overwrite=False): @@ -22,26 +24,41 @@ class CheckpointSaver(): return False if self.latest_checkpoint is None else True else: return os.path.isfile(checkpoint_file) - - def save_checkpoint(self, models, optimizers, epoch, batch_idx, batch_size, - total_step_count, is_best=False, save_by_step=False, interval=5, with_optimizer=True): + + def save_checkpoint( + self, + models, + optimizers, + epoch, + batch_idx, + batch_size, + total_step_count, + is_best=False, + save_by_step=False, + interval=5, + with_optimizer=True + ): """Save checkpoint.""" timestamp = datetime.datetime.now() if self.overwrite: checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_latest.pt')) elif save_by_step: - checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, '{:08d}.pt'.format(total_step_count))) + checkpoint_filename = os.path.abspath( + os.path.join(self.save_dir, '{:08d}.pt'.format(total_step_count)) + ) else: if epoch % interval == 0: - checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, f'model_epoch_{epoch:02d}.pt')) + checkpoint_filename = os.path.abspath( + os.path.join(self.save_dir, f'model_epoch_{epoch:02d}.pt') + ) else: checkpoint_filename = None - + checkpoint = {} for model in models: model_dict = models[model].state_dict() for k in list(model_dict.keys()): - if '.smpl.' in k: + if '.smpl.' in k: del model_dict[k] checkpoint[model] = model_dict if with_optimizer: @@ -56,7 +73,7 @@ class CheckpointSaver(): if checkpoint_filename is not None: torch.save(checkpoint, checkpoint_filename) print('Saving checkpoint file [' + checkpoint_filename + ']') - if is_best: # save the best + if is_best: # save the best checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_best.pt')) torch.save(checkpoint, checkpoint_filename) print(timestamp, 'Epoch:', epoch, 'Iteration:', batch_idx) @@ -64,7 +81,6 @@ class CheckpointSaver(): torch.save(checkpoint, checkpoint_filename) print('Saved checkpoint file [' + checkpoint_filename + ']') - def load_checkpoint(self, models, optimizers, checkpoint_file=None): """Load a checkpoint.""" if checkpoint_file is None: @@ -74,8 +90,10 @@ class CheckpointSaver(): for model in models: if model in checkpoint: model_dict = models[model].state_dict() - pretrained_dict = {k: v for k, v in checkpoint[model].items() - if k in model_dict.keys()} + pretrained_dict = { + k: v + for k, v in checkpoint[model].items() if k in model_dict.keys() + } model_dict.update(pretrained_dict) models[model].load_state_dict(model_dict) @@ -83,20 +101,23 @@ class CheckpointSaver(): for optimizer in optimizers: if optimizer in checkpoint: optimizers[optimizer].load_state_dict(checkpoint[optimizer]) - return {'epoch': checkpoint['epoch'], - 'batch_idx': checkpoint['batch_idx'], - 'batch_size': checkpoint['batch_size'], - 'total_step_count': checkpoint['total_step_count']} + return { + 'epoch': checkpoint['epoch'], + 'batch_idx': checkpoint['batch_idx'], + 'batch_size': checkpoint['batch_size'], + 'total_step_count': checkpoint['total_step_count'] + } def get_latest_checkpoint(self): """Get filename of latest checkpoint if it exists.""" - checkpoint_list = [] + checkpoint_list = [] for dirpath, dirnames, filenames in os.walk(self.save_dir): for filename in filenames: if filename.endswith('.pt'): checkpoint_list.append(os.path.abspath(os.path.join(dirpath, filename))) # sort import re + def atof(text): try: retval = float(text) @@ -111,8 +132,8 @@ class CheckpointSaver(): (See Toothy's 
implementation in the comments) float regex comes from https://stackoverflow.com/a/12643073/190597 ''' - return [ atof(c) for c in re.split(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)', text) ] - + return [atof(c) for c in re.split(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)', text)] + checkpoint_list.sort(key=natural_keys) - self.latest_checkpoint = None if (len(checkpoint_list) == 0) else checkpoint_list[-1] + self.latest_checkpoint = None if (len(checkpoint_list) == 0) else checkpoint_list[-1] return diff --git a/lib/pymafx/utils/segms.py b/lib/pymafx/utils/segms.py index 651dd0072e93660bdd0faf565c3358353ef0664b..44c617529d67323a8664c3e00872e5db091b8be6 100644 --- a/lib/pymafx/utils/segms.py +++ b/lib/pymafx/utils/segms.py @@ -32,249 +32,237 @@ import pycocotools.mask as mask_util def GetDensePoseMask(Polys): - MaskGen = np.zeros([256, 256]) - for i in range(1, 15): - if (Polys[i - 1]): - current_mask = mask_util.decode(Polys[i - 1]) - MaskGen[current_mask > 0] = i - return MaskGen + MaskGen = np.zeros([256, 256]) + for i in range(1, 15): + if (Polys[i - 1]): + current_mask = mask_util.decode(Polys[i - 1]) + MaskGen[current_mask > 0] = i + return MaskGen def flip_segms(segms, height, width): - """Left/right flip each mask in a list of masks.""" - - def _flip_poly(poly, width): - flipped_poly = np.array(poly) - flipped_poly[0::2] = width - np.array(poly[0::2]) - 1 - return flipped_poly.tolist() - - def _flip_rle(rle, height, width): - if 'counts' in rle and type(rle['counts']) == list: - # Magic RLE format handling painfully discovered by looking at the - # COCO API showAnns function. - rle = mask_util.frPyObjects([rle], height, width) - mask = mask_util.decode(rle) - mask = mask[:, ::-1, :] - rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8)) - return rle - - flipped_segms = [] - for segm in segms: - if type(segm) == list: - # Polygon format - flipped_segms.append([_flip_poly(poly, width) for poly in segm]) - else: - # RLE format - assert type(segm) == dict - flipped_segms.append(_flip_rle(segm, height, width)) - return flipped_segms + """Left/right flip each mask in a list of masks.""" + def _flip_poly(poly, width): + flipped_poly = np.array(poly) + flipped_poly[0::2] = width - np.array(poly[0::2]) - 1 + return flipped_poly.tolist() + + def _flip_rle(rle, height, width): + if 'counts' in rle and type(rle['counts']) == list: + # Magic RLE format handling painfully discovered by looking at the + # COCO API showAnns function. + rle = mask_util.frPyObjects([rle], height, width) + mask = mask_util.decode(rle) + mask = mask[:, ::-1, :] + rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8)) + return rle + + flipped_segms = [] + for segm in segms: + if type(segm) == list: + # Polygon format + flipped_segms.append([_flip_poly(poly, width) for poly in segm]) + else: + # RLE format + assert type(segm) == dict + flipped_segms.append(_flip_rle(segm, height, width)) + return flipped_segms def polys_to_mask(polygons, height, width): - """Convert from the COCO polygon segmentation format to a binary mask + """Convert from the COCO polygon segmentation format to a binary mask encoded as a 2D array of data type numpy.float32. The polygon segmentation is understood to be enclosed inside a height x width image. The resulting mask is therefore of shape (height, width). 
""" - rle = mask_util.frPyObjects(polygons, height, width) - mask = np.array(mask_util.decode(rle), dtype=np.float32) - # Flatten in case polygons was a list - mask = np.sum(mask, axis=2) - mask = np.array(mask > 0, dtype=np.float32) - return mask + rle = mask_util.frPyObjects(polygons, height, width) + mask = np.array(mask_util.decode(rle), dtype=np.float32) + # Flatten in case polygons was a list + mask = np.sum(mask, axis=2) + mask = np.array(mask > 0, dtype=np.float32) + return mask def mask_to_bbox(mask): - """Compute the tight bounding box of a binary mask.""" - xs = np.where(np.sum(mask, axis=0) > 0)[0] - ys = np.where(np.sum(mask, axis=1) > 0)[0] + """Compute the tight bounding box of a binary mask.""" + xs = np.where(np.sum(mask, axis=0) > 0)[0] + ys = np.where(np.sum(mask, axis=1) > 0)[0] - if len(xs) == 0 or len(ys) == 0: - return None + if len(xs) == 0 or len(ys) == 0: + return None - x0 = xs[0] - x1 = xs[-1] - y0 = ys[0] - y1 = ys[-1] - return np.array((x0, y0, x1, y1), dtype=np.float32) + x0 = xs[0] + x1 = xs[-1] + y0 = ys[0] + y1 = ys[-1] + return np.array((x0, y0, x1, y1), dtype=np.float32) def polys_to_mask_wrt_box(polygons, box, M): - """Convert from the COCO polygon segmentation format to a binary mask + """Convert from the COCO polygon segmentation format to a binary mask encoded as a 2D array of data type numpy.float32. The polygon segmentation is understood to be enclosed in the given box and rasterized to an M x M mask. The resulting mask is therefore of shape (M, M). """ - w = box[2] - box[0] - h = box[3] - box[1] + w = box[2] - box[0] + h = box[3] - box[1] - w = np.maximum(w, 1) - h = np.maximum(h, 1) + w = np.maximum(w, 1) + h = np.maximum(h, 1) - polygons_norm = [] - for poly in polygons: - p = np.array(poly, dtype=np.float32) - p[0::2] = (p[0::2] - box[0]) * M / w - p[1::2] = (p[1::2] - box[1]) * M / h - polygons_norm.append(p) + polygons_norm = [] + for poly in polygons: + p = np.array(poly, dtype=np.float32) + p[0::2] = (p[0::2] - box[0]) * M / w + p[1::2] = (p[1::2] - box[1]) * M / h + polygons_norm.append(p) - rle = mask_util.frPyObjects(polygons_norm, M, M) - mask = np.array(mask_util.decode(rle), dtype=np.float32) - # Flatten in case polygons was a list - mask = np.sum(mask, axis=2) - mask = np.array(mask > 0, dtype=np.float32) - return mask + rle = mask_util.frPyObjects(polygons_norm, M, M) + mask = np.array(mask_util.decode(rle), dtype=np.float32) + # Flatten in case polygons was a list + mask = np.sum(mask, axis=2) + mask = np.array(mask > 0, dtype=np.float32) + return mask def polys_to_boxes(polys): - """Convert a list of polygons into an array of tight bounding boxes.""" - boxes_from_polys = np.zeros((len(polys), 4), dtype=np.float32) - for i in range(len(polys)): - poly = polys[i] - x0 = min(min(p[::2]) for p in poly) - x1 = max(max(p[::2]) for p in poly) - y0 = min(min(p[1::2]) for p in poly) - y1 = max(max(p[1::2]) for p in poly) - boxes_from_polys[i, :] = [x0, y0, x1, y1] - - return boxes_from_polys - - -def rle_mask_voting(top_masks, - all_masks, - all_dets, - iou_thresh, - binarize_thresh, - method='AVG'): - """Returns new masks (in correspondence with `top_masks`) by combining + """Convert a list of polygons into an array of tight bounding boxes.""" + boxes_from_polys = np.zeros((len(polys), 4), dtype=np.float32) + for i in range(len(polys)): + poly = polys[i] + x0 = min(min(p[::2]) for p in poly) + x1 = max(max(p[::2]) for p in poly) + y0 = min(min(p[1::2]) for p in poly) + y1 = max(max(p[1::2]) for p in poly) + boxes_from_polys[i, :] = [x0, 
y0, x1, y1] + + return boxes_from_polys + + +def rle_mask_voting(top_masks, all_masks, all_dets, iou_thresh, binarize_thresh, method='AVG'): + """Returns new masks (in correspondence with `top_masks`) by combining multiple overlapping masks coming from the pool of `all_masks`. Two methods for combining masks are supported: 'AVG' uses a weighted average of overlapping mask pixels; 'UNION' takes the union of all mask pixels. """ - if len(top_masks) == 0: - return - - all_not_crowd = [False] * len(all_masks) - top_to_all_overlaps = mask_util.iou(top_masks, all_masks, all_not_crowd) - decoded_all_masks = [ - np.array(mask_util.decode(rle), dtype=np.float32) for rle in all_masks - ] - decoded_top_masks = [ - np.array(mask_util.decode(rle), dtype=np.float32) for rle in top_masks - ] - all_boxes = all_dets[:, :4].astype(np.int32) - all_scores = all_dets[:, 4] - - # Fill box support with weights - mask_shape = decoded_all_masks[0].shape - mask_weights = np.zeros((len(all_masks), mask_shape[0], mask_shape[1])) - for k in range(len(all_masks)): - ref_box = all_boxes[k] - x_0 = max(ref_box[0], 0) - x_1 = min(ref_box[2] + 1, mask_shape[1]) - y_0 = max(ref_box[1], 0) - y_1 = min(ref_box[3] + 1, mask_shape[0]) - mask_weights[k, y_0:y_1, x_0:x_1] = all_scores[k] - mask_weights = np.maximum(mask_weights, 1e-5) - - top_segms_out = [] - for k in range(len(top_masks)): - # Corner case of empty mask - if decoded_top_masks[k].sum() == 0: - top_segms_out.append(top_masks[k]) - continue - - inds_to_vote = np.where(top_to_all_overlaps[k] >= iou_thresh)[0] - # Only matches itself - if len(inds_to_vote) == 1: - top_segms_out.append(top_masks[k]) - continue - - masks_to_vote = [decoded_all_masks[i] for i in inds_to_vote] - if method == 'AVG': - ws = mask_weights[inds_to_vote] - soft_mask = np.average(masks_to_vote, axis=0, weights=ws) - mask = np.array(soft_mask > binarize_thresh, dtype=np.uint8) - elif method == 'UNION': - # Any pixel that's on joins the mask - soft_mask = np.sum(masks_to_vote, axis=0) - mask = np.array(soft_mask > 1e-5, dtype=np.uint8) - else: - raise NotImplementedError('Method {} is unknown'.format(method)) - rle = mask_util.encode(np.array(mask[:, :, np.newaxis], order='F'))[0] - top_segms_out.append(rle) - - return top_segms_out + if len(top_masks) == 0: + return + + all_not_crowd = [False] * len(all_masks) + top_to_all_overlaps = mask_util.iou(top_masks, all_masks, all_not_crowd) + decoded_all_masks = [np.array(mask_util.decode(rle), dtype=np.float32) for rle in all_masks] + decoded_top_masks = [np.array(mask_util.decode(rle), dtype=np.float32) for rle in top_masks] + all_boxes = all_dets[:, :4].astype(np.int32) + all_scores = all_dets[:, 4] + + # Fill box support with weights + mask_shape = decoded_all_masks[0].shape + mask_weights = np.zeros((len(all_masks), mask_shape[0], mask_shape[1])) + for k in range(len(all_masks)): + ref_box = all_boxes[k] + x_0 = max(ref_box[0], 0) + x_1 = min(ref_box[2] + 1, mask_shape[1]) + y_0 = max(ref_box[1], 0) + y_1 = min(ref_box[3] + 1, mask_shape[0]) + mask_weights[k, y_0:y_1, x_0:x_1] = all_scores[k] + mask_weights = np.maximum(mask_weights, 1e-5) + + top_segms_out = [] + for k in range(len(top_masks)): + # Corner case of empty mask + if decoded_top_masks[k].sum() == 0: + top_segms_out.append(top_masks[k]) + continue + + inds_to_vote = np.where(top_to_all_overlaps[k] >= iou_thresh)[0] + # Only matches itself + if len(inds_to_vote) == 1: + top_segms_out.append(top_masks[k]) + continue + + masks_to_vote = [decoded_all_masks[i] for i in inds_to_vote] + if 
method == 'AVG': + ws = mask_weights[inds_to_vote] + soft_mask = np.average(masks_to_vote, axis=0, weights=ws) + mask = np.array(soft_mask > binarize_thresh, dtype=np.uint8) + elif method == 'UNION': + # Any pixel that's on joins the mask + soft_mask = np.sum(masks_to_vote, axis=0) + mask = np.array(soft_mask > 1e-5, dtype=np.uint8) + else: + raise NotImplementedError('Method {} is unknown'.format(method)) + rle = mask_util.encode(np.array(mask[:, :, np.newaxis], order='F'))[0] + top_segms_out.append(rle) + + return top_segms_out def rle_mask_nms(masks, dets, thresh, mode='IOU'): - """Performs greedy non-maximum suppression based on an overlap measurement + """Performs greedy non-maximum suppression based on an overlap measurement between masks. The type of measurement is determined by `mode` and can be either 'IOU' (standard intersection over union) or 'IOMA' (intersection over mininum area). """ - if len(masks) == 0: - return [] - if len(masks) == 1: - return [0] - - if mode == 'IOU': - # Computes ious[m1, m2] = area(intersect(m1, m2)) / area(union(m1, m2)) - all_not_crowds = [False] * len(masks) - ious = mask_util.iou(masks, masks, all_not_crowds) - elif mode == 'IOMA': - # Computes ious[m1, m2] = area(intersect(m1, m2)) / min(area(m1), area(m2)) - all_crowds = [True] * len(masks) - # ious[m1, m2] = area(intersect(m1, m2)) / area(m2) - ious = mask_util.iou(masks, masks, all_crowds) - # ... = max(area(intersect(m1, m2)) / area(m2), - # area(intersect(m2, m1)) / area(m1)) - ious = np.maximum(ious, ious.transpose()) - elif mode == 'CONTAINMENT': - # Computes ious[m1, m2] = area(intersect(m1, m2)) / area(m2) - # Which measures how much m2 is contained inside m1 - all_crowds = [True] * len(masks) - ious = mask_util.iou(masks, masks, all_crowds) - else: - raise NotImplementedError('Mode {} is unknown'.format(mode)) - - scores = dets[:, 4] - order = np.argsort(-scores) - - keep = [] - while order.size > 0: - i = order[0] - keep.append(i) - ovr = ious[i, order[1:]] - inds_to_keep = np.where(ovr <= thresh)[0] - order = order[inds_to_keep + 1] - - return keep + if len(masks) == 0: + return [] + if len(masks) == 1: + return [0] + + if mode == 'IOU': + # Computes ious[m1, m2] = area(intersect(m1, m2)) / area(union(m1, m2)) + all_not_crowds = [False] * len(masks) + ious = mask_util.iou(masks, masks, all_not_crowds) + elif mode == 'IOMA': + # Computes ious[m1, m2] = area(intersect(m1, m2)) / min(area(m1), area(m2)) + all_crowds = [True] * len(masks) + # ious[m1, m2] = area(intersect(m1, m2)) / area(m2) + ious = mask_util.iou(masks, masks, all_crowds) + # ... 
= max(area(intersect(m1, m2)) / area(m2), + # area(intersect(m2, m1)) / area(m1)) + ious = np.maximum(ious, ious.transpose()) + elif mode == 'CONTAINMENT': + # Computes ious[m1, m2] = area(intersect(m1, m2)) / area(m2) + # Which measures how much m2 is contained inside m1 + all_crowds = [True] * len(masks) + ious = mask_util.iou(masks, masks, all_crowds) + else: + raise NotImplementedError('Mode {} is unknown'.format(mode)) + + scores = dets[:, 4] + order = np.argsort(-scores) + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + ovr = ious[i, order[1:]] + inds_to_keep = np.where(ovr <= thresh)[0] + order = order[inds_to_keep + 1] + + return keep def rle_masks_to_boxes(masks): - """Computes the bounding box of each mask in a list of RLE encoded masks.""" - if len(masks) == 0: - return [] - - decoded_masks = [ - np.array(mask_util.decode(rle), dtype=np.float32) for rle in masks - ] - - def get_bounds(flat_mask): - inds = np.where(flat_mask > 0)[0] - return inds.min(), inds.max() - - boxes = np.zeros((len(decoded_masks), 4)) - keep = [True] * len(decoded_masks) - for i, mask in enumerate(decoded_masks): - if mask.sum() == 0: - keep[i] = False - continue - flat_mask = mask.sum(axis=0) - x0, x1 = get_bounds(flat_mask) - flat_mask = mask.sum(axis=1) - y0, y1 = get_bounds(flat_mask) - boxes[i, :] = (x0, y0, x1, y1) - - return boxes, np.where(keep)[0] + """Computes the bounding box of each mask in a list of RLE encoded masks.""" + if len(masks) == 0: + return [] + + decoded_masks = [np.array(mask_util.decode(rle), dtype=np.float32) for rle in masks] + + def get_bounds(flat_mask): + inds = np.where(flat_mask > 0)[0] + return inds.min(), inds.max() + + boxes = np.zeros((len(decoded_masks), 4)) + keep = [True] * len(decoded_masks) + for i, mask in enumerate(decoded_masks): + if mask.sum() == 0: + keep[i] = False + continue + flat_mask = mask.sum(axis=0) + x0, x1 = get_bounds(flat_mask) + flat_mask = mask.sum(axis=1) + y0, y1 = get_bounds(flat_mask) + boxes[i, :] = (x0, y0, x1, y1) + + return boxes, np.where(keep)[0] diff --git a/lib/pymafx/utils/smooth_bbox.py b/lib/pymafx/utils/smooth_bbox.py index 1d31f74dbfad1cfc5eb4e32da31490106d16d510..4393320e7f50128d6838d99c76b5d0f8f45f6efc 100644 --- a/lib/pymafx/utils/smooth_bbox.py +++ b/lib/pymafx/utils/smooth_bbox.py @@ -94,8 +94,11 @@ def get_all_bbox_params(kps, vis_thresh=2): previous = bbox_params[-1] # This will be 3x(n+2) interpolated = np.array( - [np.linspace(prev, curr, num_to_interpolate + 2) - for prev, curr in zip(previous, bbox_param)]) + [ + np.linspace(prev, curr, num_to_interpolate + 2) + for prev, curr in zip(previous, bbox_param) + ] + ) bbox_params = np.vstack((bbox_params, interpolated.T[1:-1])) num_to_interpolate = 0 bbox_params = np.vstack((bbox_params, bbox_param)) @@ -116,6 +119,5 @@ def smooth_bbox_params(bbox_params, kernel_size=11, sigma=8): Returns: Smoothed bounding box parameters (Nx3). 
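A minimal usage sketch for the RLE mask helpers touched above (the lib.pymafx.utils.segms import path and the toy data are assumptions for illustration; pycocotools must be installed):

    import numpy as np
    import pycocotools.mask as mask_util

    from lib.pymafx.utils.segms import rle_mask_nms, rle_mask_voting, rle_masks_to_boxes

    # Two overlapping binary masks, RLE-encoded the way the helpers expect.
    m = np.zeros((64, 64, 2), dtype=np.uint8, order='F')
    m[10:40, 10:40, 0] = 1
    m[12:42, 12:42, 1] = 1
    rles = mask_util.encode(m)                                   # two RLE dicts
    dets = np.array([[10, 10, 39, 39, 0.9],
                     [12, 12, 41, 41, 0.6]], dtype=np.float32)   # [x0, y0, x1, y1, score]

    keep = rle_mask_nms(rles, dets, thresh=0.5, mode='IOU')      # greedy NMS -> kept indices
    voted = rle_mask_voting(rles, rles, dets, iou_thresh=0.5,
                            binarize_thresh=0.5, method='AVG')   # score-weighted mask average
    boxes, valid = rle_masks_to_boxes(rles)                      # tight boxes of non-empty masks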
""" - smoothed = np.array([signal.medfilt(param, kernel_size) - for param in bbox_params.T]).T + smoothed = np.array([signal.medfilt(param, kernel_size) for param in bbox_params.T]).T return np.array([gaussian_filter1d(traj, sigma) for traj in smoothed.T]).T diff --git a/lib/pymafx/utils/transforms.py b/lib/pymafx/utils/transforms.py index a283e5122bff1b9eb61b1a9092ead15200a355f8..25534674631d40b8b263b242d05339443b169dcb 100644 --- a/lib/pymafx/utils/transforms.py +++ b/lib/pymafx/utils/transforms.py @@ -43,7 +43,7 @@ def fliplr_joints(joints, joints_vis, width, matched_parts): joints_vis[pair[0], :], joints_vis[pair[1], :] = \ joints_vis[pair[1], :], joints_vis[pair[0], :].copy() - return joints*joints_vis, joints_vis + return joints * joints_vis, joints_vis def transform_preds(coords, center, scale, output_size): @@ -55,8 +55,7 @@ def transform_preds(coords, center, scale, output_size): def get_affine_transform( - center, scale, rot, output_size, - shift=np.array([0, 0], dtype=np.float32), inv=0 + center, scale, rot, output_size, shift=np.array([0, 0], dtype=np.float32), inv=0 ): if not isinstance(scale, np.ndarray) and not isinstance(scale, list): # print(scale) @@ -114,8 +113,7 @@ def crop(img, center, scale, output_size, rot=0): trans = get_affine_transform(center, scale, rot, output_size) dst_img = cv2.warpAffine( - img, trans, (int(output_size[0]), int(output_size[1])), - flags=cv2.INTER_LINEAR + img, trans, (int(output_size[0]), int(output_size[1])), flags=cv2.INTER_LINEAR ) return dst_img diff --git a/lib/pymafx/utils/uv_vis.py b/lib/pymafx/utils/uv_vis.py index 8dd0d3fd75cfa06c9cd2b6fabac8ab47b8eae833..86fdd33ddee774c2bbe02478b2d74f53f8522256 100644 --- a/lib/pymafx/utils/uv_vis.py +++ b/lib/pymafx/utils/uv_vis.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from skimage.transform import resize # Use a non-interactive backend import matplotlib + matplotlib.use('Agg') from .renderer import OpenDRenderer, PyRenderer @@ -37,8 +38,10 @@ def iuv_map2img(U_uv, V_uv, Index_UV, AnnIndex=None, uv_rois=None, ind_mapping=N for part_id in range(1, K): CurrentU = U_uv[batch_id, part_id] CurrentV = V_uv[batch_id, part_id] - output[1, Index_UV_max[batch_id] == part_id] = CurrentU[Index_UV_max[batch_id] == part_id] - output[2, Index_UV_max[batch_id] == part_id] = CurrentV[Index_UV_max[batch_id] == part_id] + output[1, + Index_UV_max[batch_id] == part_id] = CurrentU[Index_UV_max[batch_id] == part_id] + output[2, + Index_UV_max[batch_id] == part_id] = CurrentV[Index_UV_max[batch_id] == part_id] if uv_rois is None: outputs.append(output.unsqueeze(0)) @@ -53,19 +56,34 @@ def iuv_map2img(U_uv, V_uv, Index_UV, AnnIndex=None, uv_rois=None, ind_mapping=N new_size = [heatmap_size, max(int(heatmap_size * aspect_ratio), 1)] output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest') paddingleft = int(0.5 * (heatmap_size - new_size[1])) - output = F.pad(output, pad=(paddingleft, heatmap_size - new_size[1] - paddingleft, 0, 0)) + output = F.pad( + output, pad=(paddingleft, heatmap_size - new_size[1] - paddingleft, 0, 0) + ) else: new_size = [max(int(heatmap_size / aspect_ratio), 1), heatmap_size] output = F.interpolate(output.unsqueeze(0), size=new_size, mode='nearest') paddingtop = int(0.5 * (heatmap_size - new_size[0])) - output = F.pad(output, pad=(0, 0, paddingtop, heatmap_size - new_size[0] - paddingtop)) + output = F.pad( + output, pad=(0, 0, paddingtop, heatmap_size - new_size[0] - paddingtop) + ) outputs.append(output) return torch.cat(outputs, dim=0) -def vis_smpl_iuv(image, 
cam_pred, vert_pred, face, pred_uv, vert_errors_batch, image_name, save_path, opt, ratio=1): +def vis_smpl_iuv( + image, + cam_pred, + vert_pred, + face, + pred_uv, + vert_errors_batch, + image_name, + save_path, + opt, + ratio=1 +): # save_path = os.path.join('./notebooks/output/demo_results-wild', ids[f_id][0]) if not os.path.exists(save_path): @@ -82,9 +100,9 @@ def vis_smpl_iuv(image, cam_pred, vert_pred, face, pred_uv, vert_errors_batch, i for draw_i in range(len(cam_pred)): err_val = '{:06d}_'.format(int(10 * vert_errors_batch[draw_i])) draw_name = err_val + image_name[draw_i] - K = np.array([[focal_length, 0., orig_size / 2.], - [0., focal_length, orig_size / 2.], - [0., 0., 1.]]) + K = np.array( + [[focal_length, 0., orig_size / 2.], [0., focal_length, orig_size / 2.], [0., 0., 1.]] + ) # img_orig, img_resized, img_smpl, render_smpl_rgba = dr_render( # image[draw_i], @@ -100,13 +118,14 @@ def vis_smpl_iuv(image, cam_pred, vert_pred, face, pred_uv, vert_errors_batch, i mesh_filename = None img_orig = np.moveaxis(image[draw_i], 0, -1) - img_smpl, img_resized = dr_render(vert_pred[draw_i], - img=img_orig, - cam=cam_pred[draw_i], - iwp_mode=True, - scale_ratio=4., - mesh_filename=mesh_filename, - ) + img_smpl, img_resized = dr_render( + vert_pred[draw_i], + img=img_orig, + cam=cam_pred[draw_i], + iwp_mode=True, + scale_ratio=4., + mesh_filename=mesh_filename, + ) ones_img = np.ones(img_smpl.shape[:2]) * 255 ones_img = ones_img[:, :, None] @@ -117,7 +136,9 @@ def vis_smpl_iuv(image, cam_pred, vert_pred, face, pred_uv, vert_errors_batch, i render_img = np.concatenate((img_resized_rgba, img_smpl_rgba), axis=1) render_img[render_img < 0] = 0 render_img[render_img > 255] = 255 - matplotlib.image.imsave(os.path.join(save_path, draw_name[:-4] + '.png'), render_img.astype(np.uint8)) + matplotlib.image.imsave( + os.path.join(save_path, draw_name[:-4] + '.png'), render_img.astype(np.uint8) + ) if pred_uv is not None: # estimated global IUV @@ -126,4 +147,6 @@ def vis_smpl_iuv(image, cam_pred, vert_pred, face, pred_uv, vert_errors_batch, i global_iuv = resize(global_iuv, img_resized.shape[:2]) global_iuv[global_iuv > 1] = 1 global_iuv[global_iuv < 0] = 0 - matplotlib.image.imsave(os.path.join(save_path, 'pred_uv_' + draw_name[:-4] + '.png'), global_iuv) \ No newline at end of file + matplotlib.image.imsave( + os.path.join(save_path, 'pred_uv_' + draw_name[:-4] + '.png'), global_iuv + ) diff --git a/lib/pymafx/utils/vis.py b/lib/pymafx/utils/vis.py index 873ee694b266e7ba875d548b13f63f65513d2ff7..5273707c05f66275150e7cb2d86f44dcf4c92223 100644 --- a/lib/pymafx/utils/vis.py +++ b/lib/pymafx/utils/vis.py @@ -17,7 +17,6 @@ # limitations under the License. 
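The resize-and-pad step in iuv_map2img above keeps the IUV map's aspect ratio and letterboxes it into a square heatmap; a standalone sketch of that step (toy tensor, heatmap_size chosen arbitrarily):

    import torch
    import torch.nn.functional as F

    heatmap_size = 56
    iuv = torch.rand(1, 3, 120, 80)              # (N, C, H, W), taller than wide
    aspect_ratio = iuv.shape[3] / iuv.shape[2]   # W / H < 1

    new_size = [heatmap_size, max(int(heatmap_size * aspect_ratio), 1)]
    out = F.interpolate(iuv, size=new_size, mode='nearest')
    pad_left = int(0.5 * (heatmap_size - new_size[1]))
    out = F.pad(out, pad=(pad_left, heatmap_size - new_size[1] - pad_left, 0, 0))
    assert out.shape[-2:] == (heatmap_size, heatmap_size)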
############################################################################## - from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -36,14 +35,14 @@ from .imutils import normalize_2d_kp # Use a non-interactive backend import matplotlib + matplotlib.use('Agg') import matplotlib.pyplot as plt from matplotlib.patches import Polygon from mpl_toolkits.mplot3d import Axes3D from skimage.transform import resize -plt.rcParams['pdf.fonttype'] = 42 # For editing in Adobe Illustrator - +plt.rcParams['pdf.fonttype'] = 42 # For editing in Adobe Illustrator _GRAY = (218, 227, 218) _GREEN = (18, 127, 15) @@ -52,24 +51,23 @@ _WHITE = (255, 255, 255) def get_colors(): colors = { - 'pink': np.array([197, 27, 125]), # L lower leg - 'light_pink': np.array([233, 163, 201]), # L upper leg - 'light_green': np.array([161, 215, 106]), # L lower arm - 'green': np.array([77, 146, 33]), # L upper arm - 'red': np.array([215, 48, 39]), # head - 'light_red': np.array([252, 146, 114]), # head - 'light_orange': np.array([252, 141, 89]), # chest - 'purple': np.array([118, 42, 131]), # R lower leg - 'light_purple': np.array([175, 141, 195]), # R upper - 'light_blue': np.array([145, 191, 219]), # R lower arm - 'blue': np.array([69, 117, 180]), # R upper arm - 'gray': np.array([130, 130, 130]), # - 'white': np.array([255, 255, 255]), # + 'pink': np.array([197, 27, 125]), # L lower leg + 'light_pink': np.array([233, 163, 201]), # L upper leg + 'light_green': np.array([161, 215, 106]), # L lower arm + 'green': np.array([77, 146, 33]), # L upper arm + 'red': np.array([215, 48, 39]), # head + 'light_red': np.array([252, 146, 114]), # head + 'light_orange': np.array([252, 141, 89]), # chest + 'purple': np.array([118, 42, 131]), # R lower leg + 'light_purple': np.array([175, 141, 195]), # R upper + 'light_blue': np.array([145, 191, 219]), # R lower arm + 'blue': np.array([69, 117, 180]), # R upper arm + 'gray': np.array([130, 130, 130]), # + 'white': np.array([255, 255, 255]), # } return colors - def kp_connections(keypoints): kp_lines = [ [keypoints.index('left_eye'), keypoints.index('right_eye')], @@ -77,15 +75,21 @@ def kp_connections(keypoints): [keypoints.index('right_eye'), keypoints.index('nose')], [keypoints.index('right_eye'), keypoints.index('right_ear')], [keypoints.index('left_eye'), keypoints.index('left_ear')], - [keypoints.index('right_shoulder'), keypoints.index('right_elbow')], - [keypoints.index('right_elbow'), keypoints.index('right_wrist')], - [keypoints.index('left_shoulder'), keypoints.index('left_elbow')], - [keypoints.index('left_elbow'), keypoints.index('left_wrist')], + [keypoints.index('right_shoulder'), + keypoints.index('right_elbow')], + [keypoints.index('right_elbow'), + keypoints.index('right_wrist')], + [keypoints.index('left_shoulder'), + keypoints.index('left_elbow')], + [keypoints.index('left_elbow'), + keypoints.index('left_wrist')], [keypoints.index('right_hip'), keypoints.index('right_knee')], - [keypoints.index('right_knee'), keypoints.index('right_ankle')], + [keypoints.index('right_knee'), + keypoints.index('right_ankle')], [keypoints.index('left_hip'), keypoints.index('left_knee')], [keypoints.index('left_knee'), keypoints.index('left_ankle')], - [keypoints.index('right_shoulder'), keypoints.index('left_shoulder')], + [keypoints.index('right_shoulder'), + keypoints.index('left_shoulder')], [keypoints.index('right_hip'), keypoints.index('left_hip')], ] return kp_lines @@ -130,16 +134,27 @@ def get_class_string(class_index, score, 
dataset): def vis_one_image( - im, im_name, output_dir, boxes, segms=None, keypoints=None, body_uv=None, thresh=0.9, - kp_thresh=2, dpi=200, box_alpha=0.0, dataset=None, show_class=False, - ext='pdf'): + im, + im_name, + output_dir, + boxes, + segms=None, + keypoints=None, + body_uv=None, + thresh=0.9, + kp_thresh=2, + dpi=200, + box_alpha=0.0, + dataset=None, + show_class=False, + ext='pdf' +): """Visual debugging of detections.""" if not os.path.exists(output_dir): os.makedirs(output_dir) if isinstance(boxes, list): - boxes, segms, keypoints, classes = convert_from_cls_format( - boxes, segms, keypoints) + boxes, segms, keypoints, classes = convert_from_cls_format(boxes, segms, keypoints) if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh: return @@ -176,21 +191,27 @@ def vis_one_image( print(dataset.classes[classes[i]], score) # show box (off by default, box_alpha=0.0) ax.add_patch( - plt.Rectangle((bbox[0], bbox[1]), - bbox[2] - bbox[0], - bbox[3] - bbox[1], - fill=False, edgecolor='g', - linewidth=0.5, alpha=box_alpha)) + plt.Rectangle( + (bbox[0], bbox[1]), + bbox[2] - bbox[0], + bbox[3] - bbox[1], + fill=False, + edgecolor='g', + linewidth=0.5, + alpha=box_alpha + ) + ) if show_class: ax.text( - bbox[0], bbox[1] - 2, + bbox[0], + bbox[1] - 2, get_class_string(classes[i], score, dataset), fontsize=3, family='serif', - bbox=dict( - facecolor='g', alpha=0.4, pad=0, edgecolor='none'), - color='white') + bbox=dict(facecolor='g', alpha=0.4, pad=0, edgecolor='none'), + color='white' + ) # show mask if segms is not None and len(segms) > i: @@ -205,15 +226,17 @@ def vis_one_image( img[:, :, c] = color_mask[c] e = masks[:, :, i] - _, contour, hier = cv2.findContours( - e.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) + _, contour, hier = cv2.findContours(e.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) for c in contour: polygon = Polygon( c.reshape((-1, 2)), - fill=True, facecolor=color_mask, - edgecolor='w', linewidth=1.2, - alpha=0.5) + fill=True, + facecolor=color_mask, + edgecolor='w', + linewidth=1.2, + alpha=0.5 + ) ax.add_patch(polygon) # show keypoints @@ -229,41 +252,39 @@ def vis_one_image( line = ax.plot(x, y) plt.setp(line, color=colors[l], linewidth=1.0, alpha=0.7) if kps[2, i1] > kp_thresh: - ax.plot( - kps[0, i1], kps[1, i1], '.', color=colors[l], - markersize=3.0, alpha=0.7) + ax.plot(kps[0, i1], kps[1, i1], '.', color=colors[l], markersize=3.0, alpha=0.7) if kps[2, i2] > kp_thresh: - ax.plot( - kps[0, i2], kps[1, i2], '.', color=colors[l], - markersize=3.0, alpha=0.7) + ax.plot(kps[0, i2], kps[1, i2], '.', color=colors[l], markersize=3.0, alpha=0.7) # add mid shoulder / mid hip for better visualization mid_shoulder = ( kps[:2, dataset_keypoints.index('right_shoulder')] + - kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0 + kps[:2, dataset_keypoints.index('left_shoulder')] + ) / 2.0 sc_mid_shoulder = np.minimum( kps[2, dataset_keypoints.index('right_shoulder')], - kps[2, dataset_keypoints.index('left_shoulder')]) + kps[2, dataset_keypoints.index('left_shoulder')] + ) mid_hip = ( kps[:2, dataset_keypoints.index('right_hip')] + - kps[:2, dataset_keypoints.index('left_hip')]) / 2.0 + kps[:2, dataset_keypoints.index('left_hip')] + ) / 2.0 sc_mid_hip = np.minimum( kps[2, dataset_keypoints.index('right_hip')], - kps[2, dataset_keypoints.index('left_hip')]) - if (sc_mid_shoulder > kp_thresh and - kps[2, dataset_keypoints.index('nose')] > kp_thresh): + kps[2, dataset_keypoints.index('left_hip')] + ) + if ( + sc_mid_shoulder > kp_thresh and kps[2, 
dataset_keypoints.index('nose')] > kp_thresh + ): x = [mid_shoulder[0], kps[0, dataset_keypoints.index('nose')]] y = [mid_shoulder[1], kps[1, dataset_keypoints.index('nose')]] line = ax.plot(x, y) - plt.setp( - line, color=colors[len(kp_lines)], linewidth=1.0, alpha=0.7) + plt.setp(line, color=colors[len(kp_lines)], linewidth=1.0, alpha=0.7) if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh: x = [mid_shoulder[0], mid_hip[0]] y = [mid_shoulder[1], mid_hip[1]] line = ax.plot(x, y) - plt.setp( - line, color=colors[len(kp_lines) + 1], linewidth=1.0, - alpha=0.7) + plt.setp(line, color=colors[len(kp_lines) + 1], linewidth=1.0, alpha=0.7) # DensePose Visualization Starts!! ## Get full IUV image out @@ -283,14 +304,19 @@ def vis_one_image( #### output = IUV_fields[ind] #### - All_Coords_Old = All_Coords[entry[1]: entry[1] + output.shape[1], entry[0]:entry[0] + output.shape[2], :] - All_Coords_Old[All_Coords_Old == 0] = output.transpose([1, 2, 0])[All_Coords_Old == 0] - All_Coords[entry[1]: entry[1] + output.shape[1], entry[0]:entry[0] + output.shape[2], :] = All_Coords_Old + All_Coords_Old = All_Coords[entry[1]:entry[1] + output.shape[1], + entry[0]:entry[0] + output.shape[2], :] + All_Coords_Old[All_Coords_Old == 0] = output.transpose([1, 2, + 0])[All_Coords_Old == 0] + All_Coords[entry[1]:entry[1] + output.shape[1], + entry[0]:entry[0] + output.shape[2], :] = All_Coords_Old ### CurrentMask = (output[0, :, :] > 0).astype(np.float32) - All_inds_old = All_inds[entry[1]: entry[1] + output.shape[1], entry[0]:entry[0] + output.shape[2]] + All_inds_old = All_inds[entry[1]:entry[1] + output.shape[1], + entry[0]:entry[0] + output.shape[2]] All_inds_old[All_inds_old == 0] = CurrentMask[All_inds_old == 0] * i - All_inds[entry[1]: entry[1] + output.shape[1], entry[0]:entry[0] + output.shape[2]] = All_inds_old + All_inds[entry[1]:entry[1] + output.shape[1], + entry[0]:entry[0] + output.shape[2]] = All_inds_old # All_Coords[:, :, 1:3] = 255. * All_Coords[:, :, 1:3] All_Coords[All_Coords > 255] = 255. @@ -323,7 +349,7 @@ def vis_one_image( entry = boxes[ind, :] if entry[4] > 0.75: entry = entry[0:4].astype(int) - center_roi = [(entry[2]+entry[0]) / 2., (entry[3]+entry[1]) / 2.] + center_roi = [(entry[2] + entry[0]) / 2., (entry[3] + entry[1]) / 2.] 
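vis_one_image's drawing loop boils down to a matplotlib figure on the Agg backend with Rectangle patches and small text labels; a minimal sketch of that pattern (toy image and detection, arbitrary output path):

    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import numpy as np

    im = np.random.rand(240, 320, 3)
    bbox, score = [60, 40, 200, 180], 0.97       # x0, y0, x1, y1

    dpi = 200
    fig = plt.figure(frameon=False)
    fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.axis('off')
    fig.add_axes(ax)
    ax.imshow(im)
    ax.add_patch(
        plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
                      fill=False, edgecolor='g', linewidth=0.5, alpha=0.8)
    )
    ax.text(bbox[0], bbox[1] - 2, 'person {:.2f}'.format(score), fontsize=3, family='serif',
            bbox=dict(facecolor='g', alpha=0.4, pad=0, edgecolor='none'), color='white')
    fig.savefig('det_vis.png', dpi=dpi)
    plt.close('all')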
#### output, center_out = smpl_fields[ind] #### @@ -345,7 +371,8 @@ def vis_one_image( # All_Coords_Old = All_Coords[entry[1]: entry[1] + output.shape[1], entry[0]:entry[0] + output.shape[2], # :] - All_Coords_Old[All_Coords_Old == 0] = output.transpose([1, 2, 0])[All_Coords_Old == 0] + All_Coords_Old[All_Coords_Old == 0] = output.transpose([1, 2, + 0])[All_Coords_Old == 0] All_Coords[y1_img:y2_img, x1_img:x2_img, :] = All_Coords_Old ### # CurrentMask = (output[0, :, :] > 0).astype(np.float32) @@ -376,8 +403,16 @@ def vis_one_image( plt.close('all') -def vis_batch_image_with_joints(batch_image, batch_joints, batch_joints_vis, - file_name=None, nrow=8, padding=0, pad_value=1, add_text=True): +def vis_batch_image_with_joints( + batch_image, + batch_joints, + batch_joints_vis, + file_name=None, + nrow=8, + padding=0, + pad_value=1, + add_text=True +): ''' batch_image: [batch_size, channel, height, width] batch_joints: [batch_size, num_joints, 3], @@ -417,8 +452,10 @@ def vis_batch_image_with_joints(batch_image, batch_joints, batch_joints_vis, else: cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 0, [0, 255, 0], -1) if add_text: - cv2.putText(ndarr, str(count), (int(joint[0]), int(joint[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.5, - ( 0, 255, 0), 1) + cv2.putText( + ndarr, str(count), (int(joint[0]), int(joint[1])), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1 + ) except Exception as e: print(e) k = k + 1 @@ -436,6 +473,7 @@ def vis_img_3Djoint(batch_img, joints, pairs=None, joint_group=None): n_sample = max_show color = ['#00B0F0', '#00B050', '#DC6464', '#207070', '#BC4484'] + # color = ['g', 'b', 'r'] def m_l_r(idx): @@ -452,7 +490,7 @@ def vis_img_3Djoint(batch_img, joints, pairs=None, joint_group=None): # ax_img = plt.subplot(n_sample, 2, i * 2 + 1) ax_img = plt.subplot(2, n_sample, i + 1) img_np = batch_img[i].cpu().numpy() - img_np = np.transpose(img_np, (1, 2, 0)) # H*W*C + img_np = np.transpose(img_np, (1, 2, 0)) # H*W*C ax_img.imshow(img_np) ax_img.set_axis_off() ax_pred = plt.subplot(2, n_sample, n_sample + i + 1, projection='3d') @@ -464,14 +502,29 @@ def vis_img_3Djoint(batch_img, joints, pairs=None, joint_group=None): if plot_kps.shape[1] > 2: if joint_group is None: ax_pred.scatter(plot_kps[:, 2], plot_kps[:, 0], plot_kps[:, 1], s=10, marker='.') - ax_pred.scatter(plot_kps[0, 2], plot_kps[0, 0], plot_kps[0, 1], s=10, c='g', marker='.') + ax_pred.scatter( + plot_kps[0, 2], plot_kps[0, 0], plot_kps[0, 1], s=10, c='g', marker='.' 
+ ) else: for j in range(len(joint_group)): - ax_pred.scatter(plot_kps[joint_group[j], 2], plot_kps[joint_group[j], 0], plot_kps[joint_group[j], 1], s=30, c=color[j], marker='s') + ax_pred.scatter( + plot_kps[joint_group[j], 2], + plot_kps[joint_group[j], 0], + plot_kps[joint_group[j], 1], + s=30, + c=color[j], + marker='s' + ) if pairs is not None: for p in pairs: - ax_pred.plot(plot_kps[p, 2], plot_kps[p, 0], plot_kps[p, 1], c=color[m_l_r(p[1])], linewidth=2) + ax_pred.plot( + plot_kps[p, 2], + plot_kps[p, 0], + plot_kps[p, 1], + c=color[m_l_r(p[1])], + linewidth=2 + ) # ax_pred.set_axis_off() @@ -483,7 +536,6 @@ def vis_img_3Djoint(batch_img, joints, pairs=None, joint_group=None): ax_pred.zaxis.set_ticks([]) - def vis_img_2Djoint(batch_img, joints, pairs=None, joint_group=None): n_sample = joints.shape[0] max_show = 2 @@ -494,6 +546,7 @@ def vis_img_2Djoint(batch_img, joints, pairs=None, joint_group=None): n_sample = max_show color = ['#00B0F0', '#00B050', '#DC6464', '#207070', '#BC4484'] + # color = ['g', 'b', 'r'] def m_l_r(idx): @@ -510,7 +563,7 @@ def vis_img_2Djoint(batch_img, joints, pairs=None, joint_group=None): # ax_img = plt.subplot(n_sample, 2, i * 2 + 1) ax_img = plt.subplot(2, n_sample, i + 1) img_np = batch_img[i].cpu().numpy() - img_np = np.transpose(img_np, (1, 2, 0)) # H*W*C + img_np = np.transpose(img_np, (1, 2, 0)) # H*W*C ax_img.imshow(img_np) ax_img.set_axis_off() ax_pred = plt.subplot(2, n_sample, n_sample + i + 1) @@ -526,11 +579,23 @@ def vis_img_2Djoint(batch_img, joints, pairs=None, joint_group=None): # ax_pred.scatter(plot_kps[0, 0], plot_kps[0, 1], s=10, c='g', marker='.') else: for j in range(len(joint_group)): - ax_pred.scatter(plot_kps[joint_group[j], 0], plot_kps[joint_group[j], 1], s=100, c=color[j], marker='o') + ax_pred.scatter( + plot_kps[joint_group[j], 0], + plot_kps[joint_group[j], 1], + s=100, + c=color[j], + marker='o' + ) if pairs is not None: for p in pairs: - ax_pred.plot(plot_kps[p, 0], plot_kps[p, 1], c=color[m_l_r(p[1])], linestyle=':', linewidth=3) + ax_pred.plot( + plot_kps[p, 0], + plot_kps[p, 1], + c=color[m_l_r(p[1])], + linestyle=':', + linewidth=3 + ) ax_pred.set_axis_off() @@ -542,34 +607,35 @@ def vis_img_2Djoint(batch_img, joints, pairs=None, joint_group=None): ax_pred.yaxis.set_ticks([]) # ax_pred.zaxis.set_ticks([]) + def draw_skeleton(image, kp_2d, dataset='common', unnormalize=True, thickness=2): if unnormalize: - kp_2d[:,:2] = normalize_2d_kp(kp_2d[:,:2], 224, inv=True) + kp_2d[:, :2] = normalize_2d_kp(kp_2d[:, :2], 224, inv=True) - kp_2d[:,2] = kp_2d[:,2] > 0.3 + kp_2d[:, 2] = kp_2d[:, 2] > 0.3 kp_2d = np.array(kp_2d, dtype=int) rcolor = get_colors()['red'].tolist() pcolor = get_colors()['green'].tolist() lcolor = get_colors()['blue'].tolist() - common_lr = [0,0,1,1,0,0,0,0,1,0,0,1,1,1,0] - for idx,pt in enumerate(kp_2d): - if pt[2] > 0: # if visible + common_lr = [0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0] + for idx, pt in enumerate(kp_2d): + if pt[2] > 0: # if visible if idx % 2 == 0: color = rcolor else: color = pcolor cv2.circle(image, (pt[0], pt[1]), 4, color, -1) # cv2.putText(image, f'{idx}', (pt[0]+1, pt[1]), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 255, 0)) - + if dataset == 'common' and len(kp_2d) != 15: return image skeleton = eval(f'kp_utils.get_{dataset}_skeleton')() - for i,(j1,j2) in enumerate(skeleton): - if kp_2d[j1, 2] > 0 and kp_2d[j2, 2] > 0: # if visible + for i, (j1, j2) in enumerate(skeleton): + if kp_2d[j1, 2] > 0 and kp_2d[j2, 2] > 0: # if visible if dataset == 'common': color = rcolor if common_lr[i] == 
0 else lcolor else: @@ -579,6 +645,7 @@ def draw_skeleton(image, kp_2d, dataset='common', unnormalize=True, thickness=2) return image + # https://stackoverflow.com/questions/13685386/matplotlib-equal-unit-length-with-equal-aspect-ratio-z-axis-is-not-equal-to def set_axes_equal(ax): '''Make axes of 3D plot have equal scale so that spheres appear as spheres, @@ -602,8 +669,8 @@ def set_axes_equal(ax): # The plot bounding box is a sphere in the sense of the infinity # norm, hence I call half the max range the plot radius. - plot_radius = 0.5*max([x_range, y_range, z_range]) + plot_radius = 0.5 * max([x_range, y_range, z_range]) ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius]) ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius]) - ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius]) \ No newline at end of file + ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius]) diff --git a/lib/smplx/body_models.py b/lib/smplx/body_models.py index bad1f542e62a4c2765e56893fd78156122409992..b98adb635a3f9296c102e2bb6ca93bcdb14ab57d 100644 --- a/lib/smplx/body_models.py +++ b/lib/smplx/body_models.py @@ -61,7 +61,7 @@ ModelOutput = namedtuple( "jaw_pose", ], ) -ModelOutput.__new__.__defaults__ = (None,) * len(ModelOutput._fields) +ModelOutput.__new__.__defaults__ = (None, ) * len(ModelOutput._fields) class SMPL(nn.Module): @@ -234,7 +234,9 @@ class SMPL(nn.Module): default_body_pose = body_pose.clone().detach() else: default_body_pose = torch.tensor(body_pose, dtype=dtype) - self.register_parameter("body_pose", nn.Parameter(default_body_pose, requires_grad=True)) + self.register_parameter( + "body_pose", nn.Parameter(default_body_pose, requires_grad=True) + ) if create_transl: if transl is None: @@ -403,7 +405,6 @@ class SMPL(nn.Module): class SMPLLayer(SMPL): - def __init__(self, *args, **kwargs) -> None: # Just create a SMPL module without any member variables super(SMPLLayer, self).__init__( @@ -465,11 +466,16 @@ class SMPLLayer(SMPL): device, dtype = self.shapedirs.device, self.shapedirs.dtype if global_orient is None: global_orient = ( - torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + ) if body_pose is None: body_pose = ( - torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, self.NUM_BODY_JOINTS, -1, - -1).contiguous()) + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, + 3).expand(batch_size, self.NUM_BODY_JOINTS, -1, + -1).contiguous() + ) if betas is None: betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device) if transl is None: @@ -630,7 +636,9 @@ class SMPLH(SMPL): self.np_left_hand_components = left_hand_components self.np_right_hand_components = right_hand_components if self.use_pca: - self.register_buffer("left_hand_components", torch.tensor(left_hand_components, dtype=dtype)) + self.register_buffer( + "left_hand_components", torch.tensor(left_hand_components, dtype=dtype) + ) self.register_buffer( "right_hand_components", torch.tensor(right_hand_components, dtype=dtype), @@ -733,7 +741,9 @@ class SMPLH(SMPL): if self.use_pca: left_hand_pose = torch.einsum("bi,ij->bj", [left_hand_pose, self.left_hand_components]) - right_hand_pose = torch.einsum("bi,ij->bj", [right_hand_pose, self.right_hand_components]) + right_hand_pose = torch.einsum( + "bi,ij->bj", [right_hand_pose, self.right_hand_components] + ) full_pose = 
torch.cat([global_orient, body_pose, left_hand_pose, right_hand_pose], dim=1) @@ -775,7 +785,6 @@ class SMPLH(SMPL): class SMPLHLayer(SMPLH): - def __init__(self, *args, **kwargs) -> None: """SMPL+H as a layer model constructor""" super(SMPLHLayer, self).__init__( @@ -857,15 +866,24 @@ class SMPLHLayer(SMPLH): device, dtype = self.shapedirs.device, self.shapedirs.dtype if global_orient is None: global_orient = ( - torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + ) if body_pose is None: - body_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 21, -1, -1).contiguous()) + body_pose = ( + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 21, -1, -1).contiguous() + ) if left_hand_pose is None: left_hand_pose = ( - torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()) + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + ) if right_hand_pose is None: right_hand_pose = ( - torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()) + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + ) if betas is None: betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device) if transl is None: @@ -926,7 +944,7 @@ class SMPLX(SMPLH): which includes joints for the neck, jaw, eyeballs and fingers. """ - NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS # 21 + NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS # 21 NUM_HAND_JOINTS = 15 NUM_FACE_JOINTS = 3 NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS @@ -1092,7 +1110,9 @@ class SMPLX(SMPLH): if create_expression: if expression is None: - default_expression = torch.zeros([batch_size, self.num_expression_coeffs], dtype=dtype) + default_expression = torch.zeros( + [batch_size, self.num_expression_coeffs], dtype=dtype + ) else: default_expression = torch.tensor(expression, dtype=dtype) expression_param = nn.Parameter(default_expression, requires_grad=True) @@ -1226,7 +1246,9 @@ class SMPLX(SMPLH): if self.use_pca: left_hand_pose = torch.einsum("bi,ij->bj", [left_hand_pose, self.left_hand_components]) - right_hand_pose = torch.einsum("bi,ij->bj", [right_hand_pose, self.right_hand_components]) + right_hand_pose = torch.einsum( + "bi,ij->bj", [right_hand_pose, self.right_hand_components] + ) full_pose = torch.cat( [ @@ -1315,7 +1337,9 @@ class SMPLX(SMPLH): dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) - lmk_bary_coords = torch.cat([lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1) + lmk_bary_coords = torch.cat( + [lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1 + ) landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords) @@ -1350,7 +1374,6 @@ class SMPLX(SMPLH): class SMPLXLayer(SMPLX): - def __init__(self, *args, **kwargs) -> None: # Just create a SMPLX module without any member variables super(SMPLXLayer, self).__init__( @@ -1454,25 +1477,45 @@ class SMPLXLayer(SMPLX): if global_orient is None: global_orient = ( - torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 
3).expand(batch_size, -1, -1, -1).contiguous() + ) if body_pose is None: body_pose = ( - torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, self.NUM_BODY_JOINTS, -1, - -1).contiguous()) + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, + 3).expand(batch_size, self.NUM_BODY_JOINTS, -1, + -1).contiguous() + ) if left_hand_pose is None: left_hand_pose = ( - torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()) + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + ) if right_hand_pose is None: right_hand_pose = ( - torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()) + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + ) if jaw_pose is None: - jaw_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) + jaw_pose = ( + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + ) if leye_pose is None: - leye_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) + leye_pose = ( + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + ) if reye_pose is None: - reye_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) + reye_pose = ( + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + ) if expression is None: - expression = torch.zeros([batch_size, self.num_expression_coeffs], dtype=dtype, device=device) + expression = torch.zeros( + [batch_size, self.num_expression_coeffs], dtype=dtype, device=device + ) if betas is None: betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device) if transl is None: @@ -1521,7 +1564,9 @@ class SMPLXLayer(SMPLX): dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) - lmk_bary_coords = torch.cat([lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1) + lmk_bary_coords = torch.cat( + [lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1 + ) landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords) @@ -1646,7 +1691,9 @@ class MANO(SMPL): ) # add only MANO tips to the extra joints - self.vertex_joint_selector.extra_joints_idxs = to_tensor(list(VERTEX_IDS["mano"].values()), dtype=torch.long) + self.vertex_joint_selector.extra_joints_idxs = to_tensor( + list(VERTEX_IDS["mano"].values()), dtype=torch.long + ) self.use_pca = use_pca self.num_pca_comps = num_pca_comps @@ -1765,7 +1812,6 @@ class MANO(SMPL): class MANOLayer(MANO): - def __init__(self, *args, **kwargs) -> None: """MANO as a layer model constructor""" super(MANOLayer, self).__init__( @@ -1795,11 +1841,16 @@ class MANOLayer(MANO): if global_orient is None: batch_size = 1 global_orient = ( - torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + ) else: batch_size = global_orient.shape[0] if hand_pose is None: - hand_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous()) + hand_pose = ( + 
torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1, -1).contiguous() + ) if betas is None: betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device) if transl is None: @@ -1993,7 +2044,9 @@ class FLAME(SMPL): if create_expression: if expression is None: - default_expression = torch.zeros([batch_size, self.num_expression_coeffs], dtype=dtype) + default_expression = torch.zeros( + [batch_size, self.num_expression_coeffs], dtype=dtype + ) else: default_expression = torch.tensor(expression, dtype=dtype) expression_param = nn.Parameter(default_expression, requires_grad=True) @@ -2012,7 +2065,8 @@ class FLAME(SMPL): self.register_buffer("lmk_bary_coords", torch.tensor(lmk_bary_coords, dtype=dtype)) if self.use_face_contour: face_contour_path = os.path.join(model_path, "flame_dynamic_embedding.npy") - contour_embeddings = np.load(face_contour_path, allow_pickle=True, encoding="latin1")[()] + contour_embeddings = np.load(face_contour_path, allow_pickle=True, + encoding="latin1")[()] dynamic_lmk_faces_idx = np.array(contour_embeddings["lmk_face_idx"], dtype=np.int64) dynamic_lmk_faces_idx = torch.tensor(dynamic_lmk_faces_idx, dtype=torch.long) @@ -2148,7 +2202,9 @@ class FLAME(SMPL): ) dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) - lmk_bary_coords = torch.cat([lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1) + lmk_bary_coords = torch.cat( + [lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1 + ) landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords) @@ -2179,7 +2235,6 @@ class FLAME(SMPL): class FLAMELayer(FLAME): - def __init__(self, *args, **kwargs) -> None: """ FLAME as a layer model constructor """ super(FLAMELayer, self).__init__( @@ -2248,21 +2303,37 @@ class FLAMELayer(FLAME): if global_orient is None: batch_size = 1 global_orient = ( - torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + ) else: batch_size = global_orient.shape[0] if neck_pose is None: - neck_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 1, -1, -1).contiguous()) + neck_pose = ( + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 1, -1, -1).contiguous() + ) if jaw_pose is None: - jaw_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) + jaw_pose = ( + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + ) if leye_pose is None: - leye_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) + leye_pose = ( + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + ) if reye_pose is None: - reye_pose = (torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()) + reye_pose = ( + torch.eye(3, device=device, + dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + ) if betas is None: betas = torch.zeros([batch_size, self.num_betas], dtype=dtype, device=device) if expression is None: - expression = torch.zeros([batch_size, self.num_expression_coeffs], dtype=dtype, device=device) + expression = torch.zeros( + [batch_size, 
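Every *Layer variant above fills a missing pose with identity rotation matrices through the same torch.eye / view / expand expression; standalone, that default looks like this:

    import torch

    batch_size, n_joints = 2, 15    # e.g. the 15 hand joints
    dtype, device = torch.float32, 'cpu'

    hand_pose = (
        torch.eye(3, device=device, dtype=dtype)
        .view(1, 1, 3, 3)
        .expand(batch_size, n_joints, -1, -1)
        .contiguous()
    )
    assert hand_pose.shape == (2, 15, 3, 3)
    assert torch.allclose(hand_pose[0, 0], torch.eye(3))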
self.num_expression_coeffs], dtype=dtype, device=device + ) if transl is None: transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) @@ -2296,7 +2367,9 @@ class FLAMELayer(FLAME): ) dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) - lmk_bary_coords = torch.cat([lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1) + lmk_bary_coords = torch.cat( + [lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1 + ) landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords) @@ -2391,7 +2464,9 @@ def build_layer(model_path: str, raise ValueError(f"Unknown model type {model_type}, exiting!") -def create(model_path: str, model_type: str = "smpl", **kwargs) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]: +def create(model_path: str, + model_type: str = "smpl", + **kwargs) -> Union[SMPL, SMPLH, SMPLX, MANO, FLAME]: """Method for creating a model from a path and a model type Parameters diff --git a/lib/smplx/joint_names.py b/lib/smplx/joint_names.py index eadde0139deb4a62345d53cc1a8eb151bf83c1b5..a4f7cb0a3d2f9712de47e32e23ae36f918be6abc 100644 --- a/lib/smplx/joint_names.py +++ b/lib/smplx/joint_names.py @@ -129,8 +129,8 @@ JOINT_NAMES = [ "left_mouth_3", "left_mouth_2", "left_mouth_1", - "left_mouth_5", # 59 in OpenPose output - "left_mouth_4", # 58 in OpenPose output + "left_mouth_5", # 59 in OpenPose output + "left_mouth_4", # 58 in OpenPose output "mouth_bottom", "right_mouth_4", "right_mouth_5", diff --git a/lib/smplx/lbs.py b/lib/smplx/lbs.py index c74f480fd146db9f70e92a20baec2543a7c30ca2..ac64f4b41be569331d632bfeb50fef9c50dc3d71 100644 --- a/lib/smplx/lbs.py +++ b/lib/smplx/lbs.py @@ -79,11 +79,15 @@ def find_dynamic_lmk_idx_and_bcoords( else: rot_mats = torch.index_select(pose.view(batch_size, -1, 3, 3), 1, neck_kin_chain) - rel_rot_mat = (torch.eye(3, device=vertices.device, dtype=dtype).unsqueeze_(dim=0).repeat(batch_size, 1, 1)) + rel_rot_mat = ( + torch.eye(3, device=vertices.device, + dtype=dtype).unsqueeze_(dim=0).repeat(batch_size, 1, 1) + ) for idx in range(len(neck_kin_chain)): rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) - y_rot_angle = torch.round(torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, max=39)).to(dtype=torch.long) + y_rot_angle = torch.round(torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, + max=39)).to(dtype=torch.long) neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) mask = y_rot_angle.lt(-39).to(dtype=torch.long) neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) @@ -95,7 +99,9 @@ def find_dynamic_lmk_idx_and_bcoords( return dyn_lmk_faces_idx, dyn_lmk_b_coords -def vertices2landmarks(vertices: Tensor, faces: Tensor, lmk_faces_idx: Tensor, lmk_bary_coords: Tensor) -> Tensor: +def vertices2landmarks( + vertices: Tensor, faces: Tensor, lmk_faces_idx: Tensor, lmk_bary_coords: Tensor +) -> Tensor: """Calculates landmarks by barycentric interpolation Parameters @@ -123,7 +129,9 @@ def vertices2landmarks(vertices: Tensor, faces: Tensor, lmk_faces_idx: Tensor, l lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(batch_size, -1, 3) - lmk_faces += (torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts) + lmk_faces += ( + torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts + ) lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3) @@ -205,7 +213,8 @@ def lbs( pose_feature = pose[:, 1:].view(batch_size, -1, 
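A hedged usage sketch for the create() factory defined above; it assumes the SMPL-X model files have been downloaded locally and that lib.smplx re-exports create() the way the upstream smplx package does:

    import torch
    from lib.smplx import create

    model = create(
        model_path='./data/smpl_related/models',   # any folder holding smplx/SMPLX_NEUTRAL.npz
        model_type='smplx',
        gender='neutral',
        use_pca=False,
    )
    betas = torch.zeros(1, model.num_betas)
    output = model(betas=betas, return_verts=True)
    print(output.vertices.shape)    # (1, 10475, 3) for SMPL-X
    print(output.joints.shape)      # (1, num_joints, 3)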
3, 3) - ident rot_mats = pose.view(batch_size, -1, 3, 3) - pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), posedirs).view(batch_size, -1, 3) + pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), + posedirs).view(batch_size, -1, 3) v_posed = pose_offsets + v_shaped # 4. Get the global joint location @@ -292,7 +301,8 @@ def general_lbs( else: rot_mats = pose.view(batch_size, -1, 3, 3) pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident - pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), posedirs).view(batch_size, -1, 3) + pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), + posedirs).view(batch_size, -1, 3) v_posed = pose_offsets + v_template @@ -407,7 +417,9 @@ def transform_mat(R: Tensor, t: Tensor) -> Tensor: return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2) -def batch_rigid_transform(rot_mats: Tensor, joints: Tensor, parents: Tensor, dtype=torch.float32) -> Tensor: +def batch_rigid_transform( + rot_mats: Tensor, joints: Tensor, parents: Tensor, dtype=torch.float32 +) -> Tensor: """ Applies a batch of rigid transformations to the joints @@ -436,7 +448,8 @@ def batch_rigid_transform(rot_mats: Tensor, joints: Tensor, parents: Tensor, dty rel_joints = joints.clone() rel_joints[:, 1:] -= joints[:, parents[1:]] - transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) + transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3), + rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) transform_chain = [transforms_mat[:, 0]] for i in range(1, parents.shape[0]): @@ -452,6 +465,8 @@ def batch_rigid_transform(rot_mats: Tensor, joints: Tensor, parents: Tensor, dty joints_homogen = F.pad(joints, [0, 0, 0, 1]) - rel_transforms = transforms - F.pad(torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) + rel_transforms = transforms - F.pad( + torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0] + ) return posed_joints, rel_transforms diff --git a/lib/smplx/utils.py b/lib/smplx/utils.py index 6deb46a514bf97ed74ab5314b93b1148b1f376f8..d43a25217573f4c327adbf0411a76d1081632a69 100644 --- a/lib/smplx/utils.py +++ b/lib/smplx/utils.py @@ -105,7 +105,6 @@ def to_tensor(array: Union[Array, Tensor], dtype=torch.float32) -> Tensor: class Struct(object): - def __init__(self, **kwargs): for key, val in kwargs.items(): setattr(self, key, val) @@ -121,6 +120,5 @@ def rot_mat_to_euler(rot_mats): # Calculates rotation matrix to euler angles # Careful for extreme cases of eular angles like [0.0, pi, 0.0] - sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + - rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) + sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) return torch.atan2(-rot_mats[:, 2, 0], sy) diff --git a/lib/smplx/vertex_ids.py b/lib/smplx/vertex_ids.py index b45c0b84d5462f40adf776e6b615e8fb07f7be26..31ed146ed4b3529bfbe0c92450bd3b02559f338b 100644 --- a/lib/smplx/vertex_ids.py +++ b/lib/smplx/vertex_ids.py @@ -21,52 +21,54 @@ from __future__ import division # Joint name to vertex mapping. 
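The 4x4 rigid transforms that batch_rigid_transform() chains along the kinematic tree are built by transform_mat() in lbs.py above: pad R to 4x3, pad t to 4x1 with a trailing one, and concatenate. A standalone check:

    import torch
    import torch.nn.functional as F

    def transform_mat(R, t):
        # R: (B, 3, 3) rotations, t: (B, 3, 1) translations -> (B, 4, 4) homogeneous matrices
        return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2)

    R = torch.eye(3).unsqueeze(0)               # identity rotation
    t = torch.tensor([[[1.0], [2.0], [3.0]]])   # translation column vector
    T = transform_mat(R, t)
    assert T.shape == (1, 4, 4)
    assert torch.equal(T[0, :3, 3], torch.tensor([1.0, 2.0, 3.0]))
    assert torch.equal(T[0, 3], torch.tensor([0.0, 0.0, 0.0, 1.0]))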
SMPL/SMPL-H/SMPL-X vertices that correspond to # MSCOCO and OpenPose joints vertex_ids = { - "smplh": { - "nose": 332, - "reye": 6260, - "leye": 2800, - "rear": 4071, - "lear": 583, - "rthumb": 6191, - "rindex": 5782, - "rmiddle": 5905, - "rring": 6016, - "rpinky": 6133, - "lthumb": 2746, - "lindex": 2319, - "lmiddle": 2445, - "lring": 2556, - "lpinky": 2673, - "LBigToe": 3216, - "LSmallToe": 3226, - "LHeel": 3387, - "RBigToe": 6617, - "RSmallToe": 6624, - "RHeel": 6787, - }, - "smplx": { - "nose": 9120, - "reye": 9929, - "leye": 9448, - "rear": 616, - "lear": 6, - "rthumb": 8079, - "rindex": 7669, - "rmiddle": 7794, - "rring": 7905, - "rpinky": 8022, - "lthumb": 5361, - "lindex": 4933, - "lmiddle": 5058, - "lring": 5169, - "lpinky": 5286, - "LBigToe": 5770, - "LSmallToe": 5780, - "LHeel": 8846, - "RBigToe": 8463, - "RSmallToe": 8474, - "RHeel": 8635, - }, + "smplh": + { + "nose": 332, + "reye": 6260, + "leye": 2800, + "rear": 4071, + "lear": 583, + "rthumb": 6191, + "rindex": 5782, + "rmiddle": 5905, + "rring": 6016, + "rpinky": 6133, + "lthumb": 2746, + "lindex": 2319, + "lmiddle": 2445, + "lring": 2556, + "lpinky": 2673, + "LBigToe": 3216, + "LSmallToe": 3226, + "LHeel": 3387, + "RBigToe": 6617, + "RSmallToe": 6624, + "RHeel": 6787, + }, + "smplx": + { + "nose": 9120, + "reye": 9929, + "leye": 9448, + "rear": 616, + "lear": 6, + "rthumb": 8079, + "rindex": 7669, + "rmiddle": 7794, + "rring": 7905, + "rpinky": 8022, + "lthumb": 5361, + "lindex": 4933, + "lmiddle": 5058, + "lring": 5169, + "lpinky": 5286, + "LBigToe": 5770, + "LSmallToe": 5780, + "LHeel": 8846, + "RBigToe": 8463, + "RSmallToe": 8474, + "RHeel": 8635, + }, "mano": { "thumb": 744, "index": 320, diff --git a/lib/smplx/vertex_joint_selector.py b/lib/smplx/vertex_joint_selector.py index facf2afe433fde7f63a9978caa0258a7a38a30f3..1680e07acb03402a54fc0621ab36ec1d4de2c78e 100644 --- a/lib/smplx/vertex_joint_selector.py +++ b/lib/smplx/vertex_joint_selector.py @@ -27,12 +27,7 @@ from .utils import to_tensor class VertexJointSelector(nn.Module): - - def __init__(self, - vertex_ids=None, - use_hands=True, - use_feet_keypoints=True, - **kwargs): + def __init__(self, vertex_ids=None, use_hands=True, use_feet_keypoints=True, **kwargs): super(VertexJointSelector, self).__init__() extra_joints_idxs = [] @@ -63,8 +58,7 @@ class VertexJointSelector(nn.Module): dtype=np.int32, ) - extra_joints_idxs = np.concatenate( - [extra_joints_idxs, feet_keyp_idxs]) + extra_joints_idxs = np.concatenate([extra_joints_idxs, feet_keyp_idxs]) if use_hands: self.tip_names = ["thumb", "index", "middle", "ring", "pinky"] @@ -76,8 +70,7 @@ class VertexJointSelector(nn.Module): extra_joints_idxs = np.concatenate([extra_joints_idxs, tips_idxs]) - self.register_buffer("extra_joints_idxs", - to_tensor(extra_joints_idxs, dtype=torch.long)) + self.register_buffer("extra_joints_idxs", to_tensor(extra_joints_idxs, dtype=torch.long)) def forward(self, vertices, joints): extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs) diff --git a/lib/torch_utils/custom_ops.py b/lib/torch_utils/custom_ops.py index 4cc4e43fc6f6ce79f2bd68a44ba87990b9b8564e..2170f4732aba52f614b7cec09ac62465275ad90b 100644 --- a/lib/torch_utils/custom_ops.py +++ b/lib/torch_utils/custom_ops.py @@ -20,11 +20,12 @@ from torch.utils.file_baton import FileBaton #---------------------------------------------------------------------------- # Global options. 
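The vertex_ids table above is consumed by VertexJointSelector, which turns selected mesh vertices into extra joints with a single index_select; a toy-sized sketch:

    import torch

    nose_vertex = 9120                   # vertex_ids['smplx']['nose'] from the table above
    vertices = torch.rand(2, 10475, 3)   # batch of SMPL-X meshes (random values)
    joints = torch.rand(2, 55, 3)        # regressed skeleton joints (random values)

    extra_joints_idxs = torch.tensor([nose_vertex], dtype=torch.long)
    extra_joints = torch.index_select(vertices, 1, extra_joints_idxs)   # (2, 1, 3)
    joints = torch.cat([joints, extra_joints], dim=1)                   # (2, 56, 3)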
-verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' +verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' #---------------------------------------------------------------------------- # Internal helper funcs. + def _find_compiler_bindir(): patterns = [ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', @@ -38,11 +39,13 @@ def _find_compiler_bindir(): return matches[-1] return None + #---------------------------------------------------------------------------- # Main entry point for compiling and loading C++/CUDA plugins. _cached_plugins = dict() + def get_plugin(module_name, sources, **build_kwargs): assert verbosity in ['none', 'brief', 'full'] @@ -56,12 +59,14 @@ def get_plugin(module_name, sources, **build_kwargs): elif verbosity == 'brief': print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) - try: # pylint: disable=too-many-nested-blocks + try: # pylint: disable=too-many-nested-blocks # Make sure we can find the necessary compiler binaries. if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: compiler_bindir = _find_compiler_bindir() if compiler_bindir is None: - raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + raise RuntimeError( + f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".' + ) os.environ['PATH'] += ';' + compiler_bindir # Compile and load. @@ -79,7 +84,9 @@ def get_plugin(module_name, sources, **build_kwargs): # actually cares about this.) source_dirs_set = set(os.path.dirname(source) for source in sources) if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): - all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) + all_source_files = sorted( + list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()) + ) # Compute a combined hash digest for all source files in the same # custom op directory (usually .cu, .cpp, .py and .h files). @@ -87,7 +94,9 @@ def get_plugin(module_name, sources, **build_kwargs): for src in all_source_files: with open(src, 'rb') as f: hash_md5.update(f.read()) - build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + build_dir = torch.utils.cpp_extension._get_build_directory( + module_name, verbose=verbose_build + ) # pylint: disable=protected-access digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) if not os.path.isdir(digest_build_dir): @@ -96,7 +105,9 @@ def get_plugin(module_name, sources, **build_kwargs): if baton.try_acquire(): try: for src in all_source_files: - shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) + shutil.copyfile( + src, os.path.join(digest_build_dir, os.path.basename(src)) + ) finally: baton.release() else: @@ -104,10 +115,17 @@ def get_plugin(module_name, sources, **build_kwargs): # wait until done and continue. 
baton.wait() digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] - torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, - verbose=verbose_build, sources=digest_sources, **build_kwargs) + torch.utils.cpp_extension.load( + name=module_name, + build_directory=build_dir, + verbose=verbose_build, + sources=digest_sources, + **build_kwargs + ) else: - torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + torch.utils.cpp_extension.load( + name=module_name, verbose=verbose_build, sources=sources, **build_kwargs + ) module = importlib.import_module(module_name) except: @@ -123,4 +141,5 @@ def get_plugin(module_name, sources, **build_kwargs): _cached_plugins[module_name] = module return module + #---------------------------------------------------------------------------- diff --git a/lib/torch_utils/misc.py b/lib/torch_utils/misc.py index 7829f4d9f168557ce8a9a6dec289aa964234cb8c..61c266a84d83e9a486df52e725af1c51488951e4 100644 --- a/lib/torch_utils/misc.py +++ b/lib/torch_utils/misc.py @@ -19,6 +19,7 @@ import dnnlib _constant_cache = dict() + def constant(value, shape=None, dtype=None, device=None, memory_format=None): value = np.asarray(value) if shape is not None: @@ -40,13 +41,15 @@ def constant(value, shape=None, dtype=None, device=None, memory_format=None): _constant_cache[key] = tensor return tensor + #---------------------------------------------------------------------------- # Replace NaN/Inf with specified numerical values. try: - nan_to_num = torch.nan_to_num # 1.8.0a0 + nan_to_num = torch.nan_to_num # 1.8.0a0 except AttributeError: - def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin assert isinstance(input, torch.Tensor) if posinf is None: posinf = torch.finfo(input.dtype).max @@ -55,57 +58,73 @@ except AttributeError: assert nan == 0 return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + #---------------------------------------------------------------------------- # Symbolic assert. try: - symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access except AttributeError: - symbolic_assert = torch.Assert # 1.7.0 + symbolic_assert = torch.Assert # 1.7.0 #---------------------------------------------------------------------------- # Context manager to suppress known warnings in torch.jit.trace(). + class suppress_tracer_warnings(warnings.catch_warnings): def __enter__(self): super().__enter__() warnings.simplefilter('ignore', category=torch.jit.TracerWarning) return self + #---------------------------------------------------------------------------- # Assert that the shape of a tensor matches the given list of integers. # None indicates that the size of a dimension is allowed to vary. # Performs symbolic assertion when used in torch.jit.trace(). 
+ def assert_shape(tensor, ref_shape): if tensor.ndim != len(ref_shape): - raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + raise AssertionError( + f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}' + ) for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): if ref_size is None: pass elif isinstance(ref_size, torch.Tensor): - with suppress_tracer_warnings(): # as_tensor results are registered as constants - symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert( + torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}' + ) elif isinstance(size, torch.Tensor): - with suppress_tracer_warnings(): # as_tensor results are registered as constants - symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert( + torch.equal(size, torch.as_tensor(ref_size)), + f'Wrong size for dimension {idx}: expected {ref_size}' + ) elif size != ref_size: raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + #---------------------------------------------------------------------------- # Function decorator that calls torch.autograd.profiler.record_function(). + def profiled_function(fn): def decorator(*args, **kwargs): with torch.autograd.profiler.record_function(fn.__name__): return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ return decorator + #---------------------------------------------------------------------------- # Sampler for torch.utils.data.DataLoader that loops over the dataset # indefinitely, shuffling items as it goes. + class InfiniteSampler(torch.utils.data.Sampler): def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): assert len(dataset) > 0 @@ -139,17 +158,21 @@ class InfiniteSampler(torch.utils.data.Sampler): order[i], order[j] = order[j], order[i] idx += 1 + #---------------------------------------------------------------------------- # Utilities for operating with torch.nn.Module parameters and buffers. + def params_and_buffers(module): assert isinstance(module, torch.nn.Module) return list(module.parameters()) + list(module.buffers()) + def named_params_and_buffers(module): assert isinstance(module, torch.nn.Module) return list(module.named_parameters()) + list(module.named_buffers()) + def copy_params_and_buffers(src_module, dst_module, require_all=False): assert isinstance(src_module, torch.nn.Module) assert isinstance(dst_module, torch.nn.Module) @@ -159,10 +182,12 @@ def copy_params_and_buffers(src_module, dst_module, require_all=False): if name in src_tensors: tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + #---------------------------------------------------------------------------- # Context manager for easily enabling/disabling DistributedDataParallel # synchronization. + @contextlib.contextmanager def ddp_sync(module, sync): assert isinstance(module, torch.nn.Module) @@ -172,9 +197,11 @@ def ddp_sync(module, sync): with module.no_sync(): yield + #---------------------------------------------------------------------------- # Check DistributedDataParallel consistency across processes. 
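A usage sketch for assert_shape() from misc.py above (None means the dimension may be any size); the lib.torch_utils import path assumes the repository root and its dnnlib dependency are importable:

    import torch
    from lib.torch_utils.misc import assert_shape

    x = torch.zeros(4, 3, 256, 256)
    assert_shape(x, [None, 3, 256, 256])     # passes: the batch dimension is unconstrained
    try:
        assert_shape(x, [None, 1, 256, 256])
    except AssertionError as err:
        print(err)                           # Wrong size for dimension 1: got 3, expected 1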
+ def check_ddp_consistency(module, ignore_regex=None): assert isinstance(module, torch.nn.Module) for name, tensor in named_params_and_buffers(module): @@ -186,9 +213,11 @@ def check_ddp_consistency(module, ignore_regex=None): torch.distributed.broadcast(tensor=other, src=0) assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname + #---------------------------------------------------------------------------- # Print summary table of module hierarchy. + def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): assert isinstance(module, torch.nn.Module) assert not isinstance(module, torch.jit.ScriptModule) @@ -197,14 +226,17 @@ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): # Register hooks. entries = [] nesting = [0] + def pre_hook(_mod, _inputs): nesting[0] += 1 + def post_hook(mod, _inputs, outputs): nesting[0] -= 1 if nesting[0] <= max_nesting: outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] outputs = [t for t in outputs if isinstance(t, torch.Tensor)] entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] @@ -223,7 +255,10 @@ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): # Filter out redundant entries. if skip_redundant: - entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + entries = [ + e for e in entries + if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs) + ] # Construct table. rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] @@ -237,13 +272,15 @@ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): buffer_size = sum(t.numel() for t in e.unique_buffers) output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs] output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] - rows += [[ - name + (':0' if len(e.outputs) >= 2 else ''), - str(param_size) if param_size else '-', - str(buffer_size) if buffer_size else '-', - (output_shapes + ['-'])[0], - (output_dtypes + ['-'])[0], - ]] + rows += [ + [ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ] + ] for idx in range(1, len(e.outputs)): rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] param_total += param_size @@ -259,4 +296,5 @@ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): print() return outputs + #---------------------------------------------------------------------------- diff --git a/lib/torch_utils/ops/bias_act.py b/lib/torch_utils/ops/bias_act.py index 4bcb409a89ccf6c6f6ecfca5962683df2d280b1f..d8cfdb65d25ed077827862bc70e860c450fe929a 100644 --- a/lib/torch_utils/ops/bias_act.py +++ b/lib/torch_utils/ops/bias_act.py @@ -5,7 +5,6 @@ # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. - """Custom PyTorch ops for efficient bias and activation.""" import os @@ -21,15 +20,82 @@ from .. 
import misc #---------------------------------------------------------------------------- activation_funcs = { - 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), - 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), - 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), - 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), - 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), - 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), - 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), - 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), - 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), + 'linear': + dnnlib.EasyDict( + func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False + ), + 'relu': + dnnlib.EasyDict( + func=lambda x, **_: torch.nn.functional.relu(x), + def_alpha=0, + def_gain=np.sqrt(2), + cuda_idx=2, + ref='y', + has_2nd_grad=False + ), + 'lrelu': + dnnlib.EasyDict( + func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), + def_alpha=0.2, + def_gain=np.sqrt(2), + cuda_idx=3, + ref='y', + has_2nd_grad=False + ), + 'tanh': + dnnlib.EasyDict( + func=lambda x, **_: torch.tanh(x), + def_alpha=0, + def_gain=1, + cuda_idx=4, + ref='y', + has_2nd_grad=True + ), + 'sigmoid': + dnnlib.EasyDict( + func=lambda x, **_: torch.sigmoid(x), + def_alpha=0, + def_gain=1, + cuda_idx=5, + ref='y', + has_2nd_grad=True + ), + 'elu': + dnnlib.EasyDict( + func=lambda x, **_: torch.nn.functional.elu(x), + def_alpha=0, + def_gain=1, + cuda_idx=6, + ref='y', + has_2nd_grad=True + ), + 'selu': + dnnlib.EasyDict( + func=lambda x, **_: torch.nn.functional.selu(x), + def_alpha=0, + def_gain=1, + cuda_idx=7, + ref='y', + has_2nd_grad=True + ), + 'softplus': + dnnlib.EasyDict( + func=lambda x, **_: torch.nn.functional.softplus(x), + def_alpha=0, + def_gain=1, + cuda_idx=8, + ref='y', + has_2nd_grad=True + ), + 'swish': + dnnlib.EasyDict( + func=lambda x, **_: torch.sigmoid(x) * x, + def_alpha=0, + def_gain=np.sqrt(2), + cuda_idx=9, + ref='x', + has_2nd_grad=True + ), } #---------------------------------------------------------------------------- @@ -38,6 +104,7 @@ _inited = False _plugin = None _null_tensor = torch.empty([0]) + def _init(): global _inited, _plugin if not _inited: @@ -45,13 +112,20 @@ def _init(): sources = ['bias_act.cpp', 'bias_act.cu'] sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] try: - _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + _plugin = custom_ops.get_plugin( + 'bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'] + ) except: - warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. 
Details:\n\n' + traceback.format_exc()) + warnings.warn( + 'Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + + traceback.format_exc() + ) return _plugin is not None + #---------------------------------------------------------------------------- + def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): r"""Fused bias and activation function. @@ -88,8 +162,10 @@ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + #---------------------------------------------------------------------------- + @misc.profiled_function def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): """Slow reference implementation of `bias_act()` using standard TensorFlow ops. @@ -119,13 +195,15 @@ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=N # Clamp. if clamp >= 0: - x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type return x + #---------------------------------------------------------------------------- _bias_act_cuda_cache = dict() + def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): """Fast CUDA implementation of `bias_act()` using custom ops. """ @@ -144,21 +222,26 @@ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): # Forward op. class BiasActCuda(torch.autograd.Function): @staticmethod - def forward(ctx, x, b): # pylint: disable=arguments-differ - ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride( + )[1] == 1 else torch.contiguous_format x = x.contiguous(memory_format=ctx.memory_format) b = b.contiguous() if b is not None else _null_tensor y = x if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: - y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + y = _plugin.bias_act( + x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, + gain, clamp + ) ctx.save_for_backward( x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, - y if 'y' in spec.ref else _null_tensor) + y if 'y' in spec.ref else _null_tensor + ) return y @staticmethod - def backward(ctx, dy): # pylint: disable=arguments-differ + def backward(ctx, dy): # pylint: disable=arguments-differ dy = dy.contiguous(memory_format=ctx.memory_format) x, b, y = ctx.saved_tensors dx = None @@ -177,16 +260,17 @@ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): # Backward op. 
class BiasActCudaGrad(torch.autograd.Function): @staticmethod - def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ - ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format - dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) - ctx.save_for_backward( - dy if spec.has_2nd_grad else _null_tensor, - x, b, y) + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride( + )[1] == 1 else torch.contiguous_format + dx = _plugin.bias_act( + dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp + ) + ctx.save_for_backward(dy if spec.has_2nd_grad else _null_tensor, x, b, y) return dx @staticmethod - def backward(ctx, d_dx): # pylint: disable=arguments-differ + def backward(ctx, d_dx): # pylint: disable=arguments-differ d_dx = d_dx.contiguous(memory_format=ctx.memory_format) dy, x, b, y = ctx.saved_tensors d_dy = None @@ -209,4 +293,5 @@ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): _bias_act_cuda_cache[key] = BiasActCuda return BiasActCuda + #---------------------------------------------------------------------------- diff --git a/lib/torch_utils/ops/conv2d_gradfix.py b/lib/torch_utils/ops/conv2d_gradfix.py index e95e10d0b1d0315a63a76446fd4c5c293c8bbc6d..29c3d8f5a8a1e2816e225af3157fc1bb99a4fd33 100644 --- a/lib/torch_utils/ops/conv2d_gradfix.py +++ b/lib/torch_utils/ops/conv2d_gradfix.py @@ -5,7 +5,6 @@ # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. - """Custom replacement for `torch.nn.functional.conv2d` that supports arbitrarily high order gradients with zero performance penalty.""" @@ -19,8 +18,9 @@ import torch #---------------------------------------------------------------------------- -enabled = False # Enable the custom op by setting this to true. -weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. +enabled = False # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 
+ @contextlib.contextmanager def no_weight_gradients(): @@ -30,20 +30,60 @@ def no_weight_gradients(): yield weight_gradients_disabled = old + #---------------------------------------------------------------------------- + def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): if _should_use_custom_op(input): - return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) - return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) - -def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + return _conv2d_gradfix( + transpose=False, + weight_shape=weight.shape, + stride=stride, + padding=padding, + output_padding=0, + dilation=dilation, + groups=groups + ).apply(input, weight, bias) + return torch.nn.functional.conv2d( + input=input, + weight=weight, + bias=bias, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups + ) + + +def conv_transpose2d( + input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1 +): if _should_use_custom_op(input): - return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) - return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + return _conv2d_gradfix( + transpose=True, + weight_shape=weight.shape, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation + ).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d( + input=input, + weight=weight, + bias=bias, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation + ) + #---------------------------------------------------------------------------- + def _should_use_custom_op(input): assert isinstance(input, torch.Tensor) if (not enabled) or (not torch.backends.cudnn.enabled): @@ -52,19 +92,24 @@ def _should_use_custom_op(input): return False if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): return True - warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') + warnings.warn( + f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().' + ) return False + def _tuple_of_ints(xs, ndim): - xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs, ) * ndim assert len(xs) == ndim assert all(isinstance(x, int) for x in xs) return xs + #---------------------------------------------------------------------------- _conv2d_gradfix_cache = dict() + def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): # Parse arguments. ndim = 2 @@ -87,20 +132,18 @@ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, di assert all(dilation[i] >= 0 for i in range(ndim)) if not transpose: assert all(output_padding[i] == 0 for i in range(ndim)) - else: # transpose + else: # transpose assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) # Helpers. 
common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): if transpose: return [0, 0] return [ - input_shape[i + 2] - - (output_shape[i + 2] - 1) * stride[i] - - (1 - 2 * padding[i]) - - dilation[i] * (weight_shape[i + 2] - 1) - for i in range(ndim) + input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i] - (1 - 2 * padding[i]) - + dilation[i] * (weight_shape[i + 2] - 1) for i in range(ndim) ] # Forward & backward. @@ -109,9 +152,17 @@ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, di def forward(ctx, input, weight, bias): assert weight.shape == weight_shape if not transpose: - output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) - else: # transpose - output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + output = torch.nn.functional.conv2d( + input=input, weight=weight, bias=bias, **common_kwargs + ) + else: # transpose + output = torch.nn.functional.conv_transpose2d( + input=input, + weight=weight, + bias=bias, + output_padding=output_padding, + **common_kwargs + ) ctx.save_for_backward(input, weight) return output @@ -124,7 +175,12 @@ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, di if ctx.needs_input_grad[0]: p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) - grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) + grad_input = _conv2d_gradfix( + transpose=(not transpose), + weight_shape=weight_shape, + output_padding=p, + **common_kwargs + ).apply(grad_output, weight, None) assert grad_input.shape == input.shape if ctx.needs_input_grad[1] and not weight_gradients_disabled: @@ -140,9 +196,17 @@ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, di class Conv2dGradWeight(torch.autograd.Function): @staticmethod def forward(ctx, grad_output, input): - op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') - flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] - grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) + op = torch._C._jit_get_operation( + 'aten::cudnn_convolution_backward_weight' + if not transpose else 'aten::cudnn_convolution_transpose_backward_weight' + ) + flags = [ + torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, + torch.backends.cudnn.allow_tf32 + ] + grad_weight = op( + weight_shape, grad_output, input, padding, stride, dilation, groups, *flags + ) assert grad_weight.shape == weight_shape ctx.save_for_backward(grad_output, input) return grad_weight @@ -159,7 +223,12 @@ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, di if ctx.needs_input_grad[1]: p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) - grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) + grad2_input = _conv2d_gradfix( + transpose=(not transpose), + weight_shape=weight_shape, + output_padding=p, + **common_kwargs + ).apply(grad_output, grad2_grad_weight, None) assert grad2_input.shape == input.shape return 
grad2_grad_output, grad2_input @@ -167,4 +236,5 @@ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, di _conv2d_gradfix_cache[key] = Conv2d return Conv2d + #---------------------------------------------------------------------------- diff --git a/lib/torch_utils/ops/conv2d_resample.py b/lib/torch_utils/ops/conv2d_resample.py index cd4750744c83354bab78704d4ef51ad1070fcc4a..9f347c59165d1aceafee936b36281610b5a64e1b 100644 --- a/lib/torch_utils/ops/conv2d_resample.py +++ b/lib/torch_utils/ops/conv2d_resample.py @@ -5,7 +5,6 @@ # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. - """2D convolution with optional up/downsampling.""" import torch @@ -18,21 +17,24 @@ from .upfirdn2d import _get_filter_size #---------------------------------------------------------------------------- + def _get_weight_shape(w): - with misc.suppress_tracer_warnings(): # this value will be treated as a constant + with misc.suppress_tracer_warnings(): # this value will be treated as a constant shape = [int(sz) for sz in w.shape] misc.assert_shape(w, shape) return shape + #---------------------------------------------------------------------------- + def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. """ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) # Flip weight if requested. - if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). w = w.flip([2, 3]) # Workaround performance pitfall in cuDNN 8.0.5, triggered when using @@ -53,10 +55,14 @@ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_w op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d return op(x, w, stride=stride, padding=padding, groups=groups) + #---------------------------------------------------------------------------- + @misc.profiled_function -def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): +def conv2d_resample( + x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False +): r"""2D convolution with optional up/downsampling. Padding is performed only once at the beginning, not between the operations. @@ -83,7 +89,9 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight # Validate arguments. assert isinstance(x, torch.Tensor) and (x.ndim == 4) assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) - assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert f is None or ( + isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32 + ) assert isinstance(up, int) and (up >= 1) assert isinstance(down, int) and (down >= 1) assert isinstance(groups, int) and (groups >= 1) @@ -105,19 +113,23 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 
if kw == 1 and kh == 1 and (down > 1 and up == 1): - x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = upfirdn2d.upfirdn2d( + x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter + ) x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) return x # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. if kw == 1 and kh == 1 and (up > 1 and down == 1): x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + x = upfirdn2d.upfirdn2d( + x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter + ) return x # Fast path: downsampling only => use strided convolution. if down > 1 and up == 1: - x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter) x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) return x @@ -135,8 +147,22 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight py1 -= kh - up pxt = max(min(-px0, -px1), 0) pyt = max(min(-py0, -py1), 0) - x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) - x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) + x = _conv2d_wrapper( + x=x, + w=w, + stride=up, + padding=[pyt, pxt], + groups=groups, + transpose=True, + flip_weight=(not flip_weight) + ) + x = upfirdn2d.upfirdn2d( + x=x, + f=f, + padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], + gain=up**2, + flip_filter=flip_filter + ) if down > 1: x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) return x @@ -144,13 +170,23 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. if up == 1 and down == 1: if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: - return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) + return _conv2d_wrapper( + x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight + ) # Fallback: Generic reference implementation. - x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + x = upfirdn2d.upfirdn2d( + x=x, + f=(f if up > 1 else None), + up=up, + padding=[px0, px1, py0, py1], + gain=up**2, + flip_filter=flip_filter + ) x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) if down > 1: x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) return x + #---------------------------------------------------------------------------- diff --git a/lib/torch_utils/ops/fma.py b/lib/torch_utils/ops/fma.py index 2eeac58a626c49231e04122b93e321ada954c5d3..5c030932fb439b4dcc7b08ad55d0fa2aa9d8f82f 100644 --- a/lib/torch_utils/ops/fma.py +++ b/lib/torch_utils/ops/fma.py @@ -5,28 +5,30 @@ # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
- """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" import torch #---------------------------------------------------------------------------- -def fma(a, b, c): # => a * b + c + +def fma(a, b, c): # => a * b + c return _FusedMultiplyAdd.apply(a, b, c) + #---------------------------------------------------------------------------- -class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c @staticmethod - def forward(ctx, a, b, c): # pylint: disable=arguments-differ + def forward(ctx, a, b, c): # pylint: disable=arguments-differ out = torch.addcmul(c, a, b) ctx.save_for_backward(a, b) ctx.c_shape = c.shape return out @staticmethod - def backward(ctx, dout): # pylint: disable=arguments-differ + def backward(ctx, dout): # pylint: disable=arguments-differ a, b = ctx.saved_tensors c_shape = ctx.c_shape da = None @@ -44,17 +46,23 @@ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c return da, db, dc + #---------------------------------------------------------------------------- + def _unbroadcast(x, shape): extra_dims = x.ndim - len(shape) assert extra_dims >= 0 - dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + dim = [ + i + for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1) + ] if len(dim): x = x.sum(dim=dim, keepdim=True) if extra_dims: - x = x.reshape(-1, *x.shape[extra_dims+1:]) + x = x.reshape(-1, *x.shape[extra_dims + 1:]) assert x.shape == shape return x + #---------------------------------------------------------------------------- diff --git a/lib/torch_utils/ops/fused_act.py b/lib/torch_utils/ops/fused_act.py index 394a8c57229e47243ad645bc8be54674871650f6..c38a2aa0c94f033f7ebcd01eddf8da126fd7add8 100644 --- a/lib/torch_utils/ops/fused_act.py +++ b/lib/torch_utils/ops/fused_act.py @@ -36,8 +36,10 @@ class FusedLeakyReLUFunctionBackward(Function): @staticmethod def backward(ctx, gradgrad_input, gradgrad_bias): - (out,) = ctx.saved_tensors - gradgrad_out = fused.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale) + (out, ) = ctx.saved_tensors + gradgrad_out = fused.fused_bias_act( + gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale + ) return gradgrad_out, None, None, None, None @@ -65,7 +67,7 @@ class FusedLeakyReLUFunction(Function): @staticmethod def backward(ctx, grad_output): - (out,) = ctx.saved_tensors + (out, ) = ctx.saved_tensors grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale @@ -78,7 +80,7 @@ class FusedLeakyReLUFunction(Function): class FusedLeakyReLU(nn.Module): - def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): + def __init__(self, channel, bias=True, negative_slope=0.2, scale=2**0.5): super().__init__() if bias: @@ -93,11 +95,13 @@ class FusedLeakyReLU(nn.Module): return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) -def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): +def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5): if input.device.type == "cpu": if bias is not None: rest_dim = [1] * (input.ndim - bias.ndim - 1) - return F.leaky_relu(input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2) * scale + return F.leaky_relu( + input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 + ) * scale else: return F.leaky_relu(input, 
negative_slope=0.2) * scale diff --git a/lib/torch_utils/ops/grid_sample_gradfix.py b/lib/torch_utils/ops/grid_sample_gradfix.py index ca6b3413ea72a734703c34382c023b84523601fd..850feacd5a6300b85493cd7f713bffab1af70536 100644 --- a/lib/torch_utils/ops/grid_sample_gradfix.py +++ b/lib/torch_utils/ops/grid_sample_gradfix.py @@ -5,7 +5,6 @@ # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. - """Custom replacement for `torch.nn.functional.grid_sample` that supports arbitrarily high order gradients between the input and output. Only works on 2D images and assumes @@ -20,33 +19,44 @@ import torch #---------------------------------------------------------------------------- -enabled = False # Enable the custom op by setting this to true. +enabled = False # Enable the custom op by setting this to true. #---------------------------------------------------------------------------- + def grid_sample(input, grid): if _should_use_custom_op(): return _GridSample2dForward.apply(input, grid) - return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + return torch.nn.functional.grid_sample( + input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False + ) + #---------------------------------------------------------------------------- + def _should_use_custom_op(): if not enabled: return False if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): return True - warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') + warnings.warn( + f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().' 
+ ) return False + #---------------------------------------------------------------------------- + class _GridSample2dForward(torch.autograd.Function): @staticmethod def forward(ctx, input, grid): assert input.ndim == 4 assert grid.ndim == 4 - output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + output = torch.nn.functional.grid_sample( + input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False + ) ctx.save_for_backward(input, grid) return output @@ -56,8 +66,10 @@ class _GridSample2dForward(torch.autograd.Function): grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) return grad_input, grad_grid + #---------------------------------------------------------------------------- + class _GridSample2dBackward(torch.autograd.Function): @staticmethod def forward(ctx, grad_output, input, grid): @@ -68,7 +80,7 @@ class _GridSample2dBackward(torch.autograd.Function): @staticmethod def backward(ctx, grad2_grad_input, grad2_grad_grid): - _ = grad2_grad_grid # unused + _ = grad2_grad_grid # unused grid, = ctx.saved_tensors grad2_grad_output = None grad2_input = None @@ -80,4 +92,5 @@ class _GridSample2dBackward(torch.autograd.Function): assert not ctx.needs_input_grad[2] return grad2_grad_output, grad2_input, grad2_grid + #---------------------------------------------------------------------------- diff --git a/lib/torch_utils/ops/native_ops.py b/lib/torch_utils/ops/native_ops.py index 09cc5c3245113c690ae7f4891f512351cfdd5187..a21a1368c69aee0e802fa710d34a59ec63523fb6 100644 --- a/lib/torch_utils/ops/native_ops.py +++ b/lib/torch_utils/ops/native_ops.py @@ -4,7 +4,7 @@ from torch.nn import functional as F class FusedLeakyReLU(nn.Module): - def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): + def __init__(self, channel, bias=True, negative_slope=0.2, scale=2**0.5): super().__init__() if bias: @@ -20,13 +20,15 @@ class FusedLeakyReLU(nn.Module): return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) -def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): +def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5): if input.dtype == torch.float16: bias = bias.half() if bias is not None: rest_dim = [1] * (input.ndim - bias.ndim - 1) - return F.leaky_relu(input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2) * scale + return F.leaky_relu( + input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 + ) * scale else: return F.leaky_relu(input, negative_slope=0.2) * scale @@ -48,12 +50,9 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) - out = out[ - :, - max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), - max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), - :, - ] + out = out[:, + max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] out = out.permute(0, 3, 1, 2) out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) diff --git a/lib/torch_utils/ops/upfirdn2d.py b/lib/torch_utils/ops/upfirdn2d.py index ceeac2b9834e33b7c601c28bf27f32aa91c69256..86f6fb36eb83711db42aef6b05c003eceaeeaa69 100644 --- a/lib/torch_utils/ops/upfirdn2d.py +++ b/lib/torch_utils/ops/upfirdn2d.py @@ -5,7 +5,6 @@ # and any modifications thereto. 
Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. - """Custom PyTorch ops for efficient resampling of 2D images.""" import os @@ -23,17 +22,24 @@ from . import conv2d_gradfix _inited = False _plugin = None + def _init(): global _inited, _plugin if not _inited: sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] try: - _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + _plugin = custom_ops.get_plugin( + 'upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'] + ) except: - warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) + warnings.warn( + 'Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + + traceback.format_exc() + ) return _plugin is not None + def _parse_scaling(scaling): if isinstance(scaling, int): scaling = [scaling, scaling] @@ -43,6 +49,7 @@ def _parse_scaling(scaling): assert sx >= 1 and sy >= 1 return sx, sy + def _parse_padding(padding): if isinstance(padding, int): padding = [padding, padding] @@ -54,6 +61,7 @@ def _parse_padding(padding): padx0, padx1, pady0, pady1 = padding return padx0, padx1, pady0, pady1 + def _get_filter_size(f): if f is None: return 1, 1 @@ -67,9 +75,13 @@ def _get_filter_size(f): assert fw >= 1 and fh >= 1 return fw, fh + #---------------------------------------------------------------------------- -def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + +def setup_filter( + f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None +): r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. Args: @@ -111,12 +123,14 @@ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=Fals f /= f.sum() if flip_filter: f = f.flip(list(range(f.ndim))) - f = f * (gain ** (f.ndim / 2)) + f = f * (gain**(f.ndim / 2)) f = f.to(device=device) return f + #---------------------------------------------------------------------------- + def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Pad, upsample, filter, and downsample a batch of 2D images. @@ -160,11 +174,17 @@ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cu assert isinstance(x, torch.Tensor) assert impl in ['ref', 'cuda'] if impl == 'cuda' and x.device.type == 'cuda' and _init(): - return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) - return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + return _upfirdn2d_cuda( + up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain + ).apply(x, f) + return _upfirdn2d_ref( + x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain + ) + #---------------------------------------------------------------------------- + @misc.profiled_function def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. @@ -187,10 +207,12 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): # Pad or crop. 
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) - x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] + x = x[:, :, + max(-pady0, 0):x.shape[2] - max(-pady1, 0), + max(-padx0, 0):x.shape[3] - max(-padx1, 0)] # Setup filter. - f = f * (gain ** (f.ndim / 2)) + f = f * (gain**(f.ndim / 2)) f = f.to(x.dtype) if not flip_filter: f = f.flip(list(range(f.ndim))) @@ -207,10 +229,12 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): x = x[:, :, ::downy, ::downx] return x + #---------------------------------------------------------------------------- _upfirdn2d_cuda_cache = dict() + def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): """Fast CUDA implementation of `upfirdn2d()` using custom ops. """ @@ -227,23 +251,31 @@ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): # Forward op. class Upfirdn2dCuda(torch.autograd.Function): @staticmethod - def forward(ctx, x, f): # pylint: disable=arguments-differ + def forward(ctx, x, f): # pylint: disable=arguments-differ assert isinstance(x, torch.Tensor) and x.ndim == 4 if f is None: f = torch.ones([1, 1], dtype=torch.float32, device=x.device) assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] y = x if f.ndim == 2: - y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + y = _plugin.upfirdn2d( + y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain + ) else: - y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) - y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) + y = _plugin.upfirdn2d( + y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, + np.sqrt(gain) + ) + y = _plugin.upfirdn2d( + y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, + np.sqrt(gain) + ) ctx.save_for_backward(f) ctx.x_shape = x.shape return y @staticmethod - def backward(ctx, dy): # pylint: disable=arguments-differ + def backward(ctx, dy): # pylint: disable=arguments-differ f, = ctx.saved_tensors _, _, ih, iw = ctx.x_shape _, _, oh, ow = dy.shape @@ -258,7 +290,9 @@ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): df = None if ctx.needs_input_grad[0]: - dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + dx = _upfirdn2d_cuda( + up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain + ).apply(dy, f) assert not ctx.needs_input_grad[1] return dx, df @@ -267,8 +301,10 @@ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda return Upfirdn2dCuda + #---------------------------------------------------------------------------- + def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Filter a batch of 2D images using the given 2D FIR filter. @@ -303,8 +339,10 @@ def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): ] return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + #---------------------------------------------------------------------------- + def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Upsample a batch of 2D images using the given 2D FIR filter. 
@@ -340,10 +378,14 @@ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): pady0 + (fh + upy - 1) // 2, pady1 + (fh - upy) // 2, ] - return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) + return upfirdn2d( + x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl + ) + #---------------------------------------------------------------------------- + def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Downsample a batch of 2D images using the given 2D FIR filter. @@ -381,4 +423,5 @@ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda' ] return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + #---------------------------------------------------------------------------- diff --git a/lib/torch_utils/persistence.py b/lib/torch_utils/persistence.py index 0186cfd97bca0fcb397a7b73643520c1d1105a02..c3263dc0690ac12d5d2e74a6d9d8d2af2fed0f5b 100644 --- a/lib/torch_utils/persistence.py +++ b/lib/torch_utils/persistence.py @@ -5,7 +5,6 @@ # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. - """Facilities for pickling Python code alongside other data. The pickled code is automatically imported into a separate Python module @@ -24,14 +23,15 @@ import dnnlib #---------------------------------------------------------------------------- -_version = 6 # internal version number -_decorators = set() # {decorator_class, ...} -_import_hooks = [] # [hook_function, ...] +_version = 6 # internal version number +_decorators = set() # {decorator_class, ...} +_import_hooks = [] # [hook_function, ...] _module_to_src_dict = dict() # {module: src, ...} _src_to_module_dict = dict() # {src: module, ...} #---------------------------------------------------------------------------- + def persistent_class(orig_class): r"""Class decorator that extends a given class to save its source code when pickled. @@ -119,18 +119,26 @@ def persistent_class(orig_class): fields = list(super().__reduce__()) fields += [None] * max(3 - len(fields), 0) if fields[0] is not _reconstruct_persistent_obj: - meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) - fields[0] = _reconstruct_persistent_obj # reconstruct func - fields[1] = (meta,) # reconstruct args - fields[2] = None # state dict + meta = dict( + type='class', + version=_version, + module_src=self._orig_module_src, + class_name=self._orig_class_name, + state=fields[2] + ) + fields[0] = _reconstruct_persistent_obj # reconstruct func + fields[1] = (meta, ) # reconstruct args + fields[2] = None # state dict return tuple(fields) Decorator.__name__ = orig_class.__name__ _decorators.add(Decorator) return Decorator + #---------------------------------------------------------------------------- + def is_persistent(obj): r"""Test whether the given object or class is persistent, i.e., whether it will save its source code when pickled. 
@@ -140,10 +148,12 @@ def is_persistent(obj): return True except TypeError: pass - return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck + return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck + #---------------------------------------------------------------------------- + def import_hook(hook): r"""Register an import hook that is called whenever a persistent object is being unpickled. A typical use case is to patch the pickled source @@ -174,8 +184,10 @@ def import_hook(hook): assert callable(hook) _import_hooks.append(hook) + #---------------------------------------------------------------------------- + def _reconstruct_persistent_obj(meta): r"""Hook that is called internally by the `pickle` module to unpickle a persistent object. @@ -196,13 +208,15 @@ def _reconstruct_persistent_obj(meta): setstate = getattr(obj, '__setstate__', None) if callable(setstate): - setstate(meta.state) # pylint: disable=not-callable + setstate(meta.state) # pylint: disable=not-callable else: obj.__dict__.update(meta.state) return obj + #---------------------------------------------------------------------------- + def _module_to_src(module): r"""Query the source code of a given Python module. """ @@ -213,6 +227,7 @@ def _module_to_src(module): _src_to_module_dict[src] = module return src + def _src_to_module(src): r"""Get or create a Python module for the given source code. """ @@ -223,11 +238,13 @@ def _src_to_module(src): sys.modules[module_name] = module _module_to_src_dict[module] = src _src_to_module_dict[src] = module - exec(src, module.__dict__) # pylint: disable=exec-used + exec(src, module.__dict__) # pylint: disable=exec-used return module + #---------------------------------------------------------------------------- + def _check_pickleable(obj): r"""Check that the given object is pickleable, raising an exception if it is not. This function is expected to be considerably more efficient @@ -239,13 +256,15 @@ def _check_pickleable(obj): if isinstance(obj, dict): return [[recurse(x), recurse(y)] for x, y in obj.items()] if isinstance(obj, (str, int, float, bool, bytes, bytearray)): - return None # Python primitive types are pickleable. + return None # Python primitive types are pickleable. if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: - return None # NumPy arrays and PyTorch tensors are pickleable. + return None # NumPy arrays and PyTorch tensors are pickleable. if is_persistent(obj): - return None # Persistent objects are pickleable, by virtue of the constructor check. + return None # Persistent objects are pickleable, by virtue of the constructor check. return obj + with io.BytesIO() as f: pickle.dump(recurse(obj), f) + #---------------------------------------------------------------------------- diff --git a/lib/torch_utils/training_stats.py b/lib/torch_utils/training_stats.py index d2c265f5c8ab235156a4bb12de2df69d00074de5..11658fdbf55450f5f0d4679e247ff65a4b37151e 100644 --- a/lib/torch_utils/training_stats.py +++ b/lib/torch_utils/training_stats.py @@ -5,7 +5,6 @@ # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. - """Facilities for reporting and collecting training statistics across multiple processes and devices. The interface is designed to minimize synchronization overhead as well as the amount of boilerplate in user @@ -20,17 +19,19 @@ from . 
import misc #---------------------------------------------------------------------------- -_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] -_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. -_counter_dtype = torch.float64 # Data type to use for the internal counters. -_rank = 0 # Rank of the current process. -_sync_device = None # Device to use for multiprocess communication. None = single-process. -_sync_called = False # Has _sync() been called yet? -_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor -_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor +_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] +_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. +_counter_dtype = torch.float64 # Data type to use for the internal counters. +_rank = 0 # Rank of the current process. +_sync_device = None # Device to use for multiprocess communication. None = single-process. +_sync_called = False # Has _sync() been called yet? +_counters = dict( +) # Running counters on each device, updated by report(): name => device => torch.Tensor +_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor #---------------------------------------------------------------------------- + def init_multiprocessing(rank, sync_device): r"""Initializes `torch_utils.training_stats` for collecting statistics across multiple processes. @@ -50,8 +51,10 @@ def init_multiprocessing(rank, sync_device): _rank = rank _sync_device = sync_device + #---------------------------------------------------------------------------- + @misc.profiled_function def report(name, value): r"""Broadcasts the given set of scalars to all interested instances of @@ -98,8 +101,10 @@ def report(name, value): _counters[name][device].add_(moments) return value + #---------------------------------------------------------------------------- + def report0(name, value): r"""Broadcasts the given set of scalars by the first process (`rank = 0`), but ignores any scalars provided by the other processes. @@ -108,8 +113,10 @@ def report0(name, value): report(name, value if _rank == 0 else []) return value + #---------------------------------------------------------------------------- + class Collector: r"""Collects the scalars broadcasted by `report()` and `report0()` and computes their long-term averages (mean and standard deviation) over @@ -220,7 +227,9 @@ class Collector: """ stats = dnnlib.EasyDict() for name in self.names(): - stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) + stats[name] = dnnlib.EasyDict( + num=self.num(name), mean=self.mean(name), std=self.std(name) + ) return stats def __getitem__(self, name): @@ -229,8 +238,10 @@ class Collector: """ return self.mean(name) + #---------------------------------------------------------------------------- + def _sync(names): r"""Synchronize the global cumulative counters across devices and processes. Called internally by `Collector.update()`. @@ -265,4 +276,5 @@ def _sync(names): # Return name-value pairs. 
return [(name, _cumulative[name]) for name in names] + #---------------------------------------------------------------------------- diff --git a/requirements.txt b/requirements.txt index 4faea63e38f40eea0c1f3595a1485ba524c3f73f..f29715018cec6e3a8cea1b896e27bbb0eb05e496 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,19 +3,18 @@ scikit-image trimesh rtree pytorch_lightning -kornia +kornia>0.4.0 chumpy opencv-python opencv_contrib_python scikit-learn protobuf -pymeshlab dataclasses mediapipe einops boto3 +open3d tinyobjloader==2.0.0rc7 git+https://github.com/facebookresearch/pytorch3d.git git+https://github.com/YuliangXiu/neural_voxelization_layer.git git+https://github.com/YuliangXiu/rembg.git -git+https://github.com/mmolero/pypoisson.git
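
A minimal CPU-only sanity check for the reformatted `lib/torch_utils/ops` modules, written as a sketch under assumptions: the repository root is on `PYTHONPATH` and the `dnnlib` helper these modules import resolves. Passing `impl='ref'` selects the slow reference paths, so the CUDA plugins (`bias_act_plugin`, `upfirdn2d_plugin`) do not need to compile for this check.

```python
# CPU-only smoke test of the reformatted ops (sketch; import path assumed).
import torch

from lib.torch_utils.ops import bias_act, upfirdn2d

x = torch.randn(1, 8, 16, 16)   # NCHW feature map
b = torch.zeros(8)              # per-channel bias

# Fused bias + leaky ReLU through the reference implementation.
y = bias_act.bias_act(x, b, act='lrelu', impl='ref')

# 2x FIR upsampling with the usual [1, 3, 3, 1] filter.
f = upfirdn2d.setup_filter([1, 3, 3, 1])
z = upfirdn2d.upsample2d(x, f, up=2, impl='ref')

print(y.shape, z.shape)         # torch.Size([1, 8, 16, 16]) torch.Size([1, 8, 32, 32])
```

The same calls dispatch to the compiled kernels when `impl='cuda'` is used on a CUDA tensor and the plugin build succeeds; otherwise the `_init()` helpers in both modules warn and fall back to these reference implementations.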