import torch
import pytorch_lightning as pl
from typing import Any, Dict, Mapping, Tuple

from yacs.config import CfgNode

from ..utils import SkeletonRenderer, MeshRenderer
from ..utils.geometry import aa_to_rotmat, perspective_projection
from ..utils.pylogger import get_pylogger
from .backbones import create_backbone
from .heads import RefineNet
from .discriminator import Discriminator
from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss
from . import MANO

log = get_pylogger(__name__)


class WiLoR(pl.LightningModule):

    def __init__(self, cfg: CfgNode, init_renderer: bool = True):
        """
        Setup WiLoR model
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode
        """
        super().__init__()

        # Save hyperparameters
        self.save_hyperparameters(logger=False, ignore=['init_renderer'])

        self.cfg = cfg
        # Create backbone feature extractor
        self.backbone = create_backbone(cfg)
        if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
            log.info(f'Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
            self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict'], strict=False)

        # Create RefineNet head
        self.refine_net = RefineNet(cfg, feat_dim=1280, upscale=3)

        # Create discriminator
        if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
            self.discriminator = Discriminator()

        # Define loss functions
        self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
        self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
        self.mano_parameter_loss = ParameterLoss()

        # Instantiate MANO model
        mano_cfg = {k.lower(): v for k, v in dict(cfg.MANO).items()}
        self.mano = MANO(**mano_cfg)

        # Buffer that shows whether we need to initialize ActNorm layers
        self.register_buffer('initialized', torch.tensor(False))

        # Setup renderer for visualization
        if init_renderer:
            self.renderer = SkeletonRenderer(self.cfg)
            self.mesh_renderer = MeshRenderer(self.cfg, faces=self.mano.faces)
        else:
            self.renderer = None
            self.mesh_renderer = None

        # Disable automatic optimization since we use adversarial training
        self.automatic_optimization = False

    def on_after_backward(self):
        # Report any parameters that did not receive a gradient in the last backward pass
        for name, param in self.named_parameters():
            if param.grad is None:
                print(param.shape)
                print(name)

    def get_parameters(self):
        # all_params = list(self.mano_head.parameters())
        all_params = list(self.backbone.parameters())
        return all_params

    def configure_optimizers(self) -> Tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
        """
        Setup model and discriminator optimizers
        Returns:
            Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers
        """
        param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]

        optimizer = torch.optim.AdamW(params=param_groups,
                                      # lr=self.cfg.TRAIN.LR,
                                      weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
        optimizer_disc = torch.optim.AdamW(params=self.discriminator.parameters(),
                                           lr=self.cfg.TRAIN.LR,
                                           weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)

        return optimizer, optimizer_disc

    def forward_step(self, batch: Dict, train: bool = False) -> Dict:
        """
        Run a forward step of the network
        Args:
            batch (Dict): Dictionary containing batch data
            train (bool): Flag indicating whether it is training or validation mode
        Returns:
            Dict: Dictionary containing the regression output
        """
        # Use RGB image as input
        x = batch['img']
        batch_size = x.shape[0]

        # Compute conditioning features using the backbone
        # if using ViT backbone, we need to use a different aspect ratio
        temp_mano_params, pred_cam, pred_mano_feats, vit_out = self.backbone(x[:, :, :, 32:-32])  # B, 1280, 16, 12
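        # NOTE (assumption): the horizontal crop x[:, :, :, 32:-32] appears to trim the square input
        # to a 4:3 crop so that the ViT patch grid matches the 16x12 feature map noted above; the
        # actual input resolution is set by the config and is not asserted here.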
        # Compute camera translation
        device = temp_mano_params['hand_pose'].device
        dtype = temp_mano_params['hand_pose'].dtype
        focal_length = self.cfg.EXTRA.FOCAL_LENGTH * torch.ones(batch_size, 2, device=device, dtype=dtype)

        ## Temp MANO
        temp_mano_params['global_orient'] = temp_mano_params['global_orient'].reshape(batch_size, -1, 3, 3)
        temp_mano_params['hand_pose'] = temp_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3)
        temp_mano_params['betas'] = temp_mano_params['betas'].reshape(batch_size, -1)
        temp_mano_output = self.mano(**{k: v.float() for k, v in temp_mano_params.items()}, pose2rot=False)
        # temp_keypoints_3d = temp_mano_output.joints
        temp_vertices = temp_mano_output.vertices

        pred_mano_params, pred_cam = self.refine_net(vit_out, temp_vertices, pred_cam, pred_mano_feats, focal_length)

        # Store useful regression outputs to the output dict
        output = {}
        output['pred_cam'] = pred_cam
        output['pred_mano_params'] = {k: v.clone() for k, v in pred_mano_params.items()}

        pred_cam_t = torch.stack([pred_cam[:, 1],
                                  pred_cam[:, 2],
                                  2 * focal_length[:, 0] / (self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] + 1e-9)], dim=-1)
        output['pred_cam_t'] = pred_cam_t
        output['focal_length'] = focal_length

        # Compute model vertices, joints and the projected joints
        pred_mano_params['global_orient'] = pred_mano_params['global_orient'].reshape(batch_size, -1, 3, 3)
        pred_mano_params['hand_pose'] = pred_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3)
        pred_mano_params['betas'] = pred_mano_params['betas'].reshape(batch_size, -1)
        mano_output = self.mano(**{k: v.float() for k, v in pred_mano_params.items()}, pose2rot=False)
        pred_keypoints_3d = mano_output.joints
        pred_vertices = mano_output.vertices
        output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
        output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)
        pred_cam_t = pred_cam_t.reshape(-1, 3)
        focal_length = focal_length.reshape(-1, 2)
        pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
                                                   translation=pred_cam_t,
                                                   focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)

        output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)
        return output

    def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
        """
        Compute losses given the input batch and the regression output
        Args:
            batch (Dict): Dictionary containing batch data
            output (Dict): Dictionary containing the regression output
            train (bool): Flag indicating whether it is training or validation mode
        Returns:
            torch.Tensor : Total loss for current batch
        """
        pred_mano_params = output['pred_mano_params']
        pred_keypoints_2d = output['pred_keypoints_2d']
        pred_keypoints_3d = output['pred_keypoints_3d']

        batch_size = pred_mano_params['hand_pose'].shape[0]
        device = pred_mano_params['hand_pose'].device
        dtype = pred_mano_params['hand_pose'].dtype

        # Get annotations
        gt_keypoints_2d = batch['keypoints_2d']
        gt_keypoints_3d = batch['keypoints_3d']
        gt_mano_params = batch['mano_params']
        has_mano_params = batch['has_mano_params']
        is_axis_angle = batch['mano_params_is_axis_angle']

        # Compute 3D keypoint loss
        loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
        loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)

        # Compute loss on MANO parameters
        loss_mano_params = {}
        for k, pred in pred_mano_params.items():
            gt = gt_mano_params[k].view(batch_size, -1)
            if is_axis_angle[k].all():
                gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
            has_gt = has_mano_params[k]
            loss_mano_params[k] = self.mano_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1), has_gt)
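        # Total loss: weighted sum of the 3D keypoint, 2D keypoint, and per-parameter MANO losses,
        # with the weights taken from cfg.LOSS_WEIGHTS (parameter losses are keyed by upper-cased name).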
        loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d + \
               self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d + \
               sum([loss_mano_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_mano_params])

        losses = dict(loss=loss.detach(),
                      loss_keypoints_2d=loss_keypoints_2d.detach(),
                      loss_keypoints_3d=loss_keypoints_3d.detach())

        for k, v in loss_mano_params.items():
            losses['loss_' + k] = v.detach()

        output['losses'] = losses

        return loss

    # Tensorboard logging should run from first rank only
    @pl.utilities.rank_zero.rank_zero_only
    def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True) -> None:
        """
        Log results to Tensorboard
        Args:
            batch (Dict): Dictionary containing batch data
            output (Dict): Dictionary containing the regression output
            step_count (int): Global training step count
            train (bool): Flag indicating whether it is training or validation mode
        """
        mode = 'train' if train else 'val'
        batch_size = batch['keypoints_2d'].shape[0]
        images = batch['img']
        images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
        images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
        # images = 255*images.permute(0, 2, 3, 1).cpu().numpy()

        pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)
        pred_vertices = output['pred_vertices'].detach().reshape(batch_size, -1, 3)
        focal_length = output['focal_length'].detach().reshape(batch_size, 2)
        gt_keypoints_3d = batch['keypoints_3d']
        gt_keypoints_2d = batch['keypoints_2d']
        losses = output['losses']
        pred_cam_t = output['pred_cam_t'].detach().reshape(batch_size, 3)
        pred_keypoints_2d = output['pred_keypoints_2d'].detach().reshape(batch_size, -1, 2)

        if write_to_summary_writer:
            summary_writer = self.logger.experiment
            for loss_name, val in losses.items():
                summary_writer.add_scalar(mode + '/' + loss_name, val.detach().item(), step_count)
        num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES)

        gt_keypoints_3d = batch['keypoints_3d']
        pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)

        # We render the skeletons instead of the full mesh because rendering a lot of meshes will make the training slow.
        # predictions = self.renderer(pred_keypoints_3d[:num_images],
        #                             gt_keypoints_3d[:num_images],
        #                             2 * gt_keypoints_2d[:num_images],
        #                             images=images[:num_images],
        #                             camera_translation=pred_cam_t[:num_images])
        predictions = self.mesh_renderer.visualize_tensorboard(pred_vertices[:num_images].cpu().numpy(),
                                                               pred_cam_t[:num_images].cpu().numpy(),
                                                               images[:num_images].cpu().numpy(),
                                                               pred_keypoints_2d[:num_images].cpu().numpy(),
                                                               gt_keypoints_2d[:num_images].cpu().numpy(),
                                                               focal_length=focal_length[:num_images].cpu().numpy())
        if write_to_summary_writer:
            summary_writer.add_image('%s/predictions' % mode, predictions, step_count)

        return predictions

    def forward(self, batch: Dict) -> Dict:
        """
        Run a forward step of the network in val mode
        Args:
            batch (Dict): Dictionary containing batch data
        Returns:
            Dict: Dictionary containing the regression output
        """
        return self.forward_step(batch, train=False)

    def training_step_discriminator(self, batch: Dict,
                                    hand_pose: torch.Tensor,
                                    betas: torch.Tensor,
                                    optimizer: torch.optim.Optimizer) -> torch.Tensor:
        """
        Run a discriminator training step
        Args:
            batch (Dict): Dictionary containing mocap batch data
            hand_pose (torch.Tensor): Regressed hand pose from current step
            betas (torch.Tensor): Regressed betas from current step
            optimizer (torch.optim.Optimizer): Discriminator optimizer
        Returns:
            torch.Tensor: Discriminator loss
        """
        batch_size = hand_pose.shape[0]
        gt_hand_pose = batch['hand_pose']
        gt_betas = batch['betas']
        gt_rotmat = aa_to_rotmat(gt_hand_pose.view(-1, 3)).view(batch_size, -1, 3, 3)
        disc_fake_out = self.discriminator(hand_pose.detach(), betas.detach())
        loss_fake = ((disc_fake_out - 0.0) ** 2).sum() / batch_size
        disc_real_out = self.discriminator(gt_rotmat, gt_betas)
        loss_real = ((disc_real_out - 1.0) ** 2).sum() / batch_size
        loss_disc = loss_fake + loss_real
        loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc
        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        return loss_disc.detach()

    def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict:
        """
        Run a full training step
        Args:
            joint_batch (Dict): Dictionary containing image and mocap batch data
            batch_idx (int): Unused.
        Returns:
            Dict: Dictionary containing regression output.
        """
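        # Manual optimization: with automatic_optimization disabled, this step explicitly
        # zeroes gradients, backpropagates, optionally clips, and steps the model optimizer,
        # then runs a separate discriminator update when the adversarial loss weight is positive.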
""" batch = joint_batch['img'] mocap_batch = joint_batch['mocap'] optimizer = self.optimizers(use_pl_optimizer=True) if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0: optimizer, optimizer_disc = optimizer batch_size = batch['img'].shape[0] output = self.forward_step(batch, train=True) pred_mano_params = output['pred_mano_params'] if self.cfg.get('UPDATE_GT_SPIN', False): self.update_batch_gt_spin(batch, output) loss = self.compute_loss(batch, output, train=True) if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0: disc_out = self.discriminator(pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1)) loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv # Error if Nan if torch.isnan(loss): raise ValueError('Loss is NaN') optimizer.zero_grad() self.manual_backward(loss) # Clip gradient if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0: gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True) self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True) optimizer.step() if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0: loss_disc = self.training_step_discriminator(mocap_batch, pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1), optimizer_disc) output['losses']['loss_gen'] = loss_adv output['losses']['loss_disc'] = loss_disc if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0: self.tensorboard_logging(batch, output, self.global_step, train=True) self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False) return output def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict: """ Run a validation step and log to Tensorboard Args: batch (Dict): Dictionary containing batch data batch_idx (int): Unused. Returns: Dict: Dictionary containing regression output. """ # batch_size = batch['img'].shape[0] output = self.forward_step(batch, train=False) loss = self.compute_loss(batch, output, train=False) output['loss'] = loss self.tensorboard_logging(batch, output, self.global_step, train=False) return output