# wilor/models/wilor.py
import torch
import pytorch_lightning as pl
from typing import Any, Dict, Mapping, Tuple
from yacs.config import CfgNode
from ..utils import SkeletonRenderer, MeshRenderer
from ..utils.geometry import aa_to_rotmat, perspective_projection
from ..utils.pylogger import get_pylogger
from .backbones import create_backbone
from .heads import RefineNet
from .discriminator import Discriminator
from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss
from . import MANO
log = get_pylogger(__name__)
class WiLoR(pl.LightningModule):
def __init__(self, cfg: CfgNode, init_renderer: bool = True):
"""
Setup WiLoR model
Args:
cfg (CfgNode): Config file as a yacs CfgNode
"""
super().__init__()
# Save hyperparameters
self.save_hyperparameters(logger=False, ignore=['init_renderer'])
self.cfg = cfg
# Create backbone feature extractor
self.backbone = create_backbone(cfg)
if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
log.info(f'Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
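            # strict=False: the checkpoint may contain keys (e.g. prediction heads) that have no counterpart in this backbone.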
self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict'], strict = False)
# Create RefineNet head
self.refine_net = RefineNet(cfg, feat_dim=1280, upscale=3)
# Create discriminator
if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
self.discriminator = Discriminator()
# Define loss functions
self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
self.mano_parameter_loss = ParameterLoss()
# Instantiate MANO model
mano_cfg = {k.lower(): v for k,v in dict(cfg.MANO).items()}
self.mano = MANO(**mano_cfg)
        # Buffer that shows whether we need to initialize ActNorm layers
self.register_buffer('initialized', torch.tensor(False))
# Setup renderer for visualization
if init_renderer:
self.renderer = SkeletonRenderer(self.cfg)
self.mesh_renderer = MeshRenderer(self.cfg, faces=self.mano.faces)
else:
self.renderer = None
self.mesh_renderer = None
# Disable automatic optimization since we use adversarial training
self.automatic_optimization = False
def on_after_backward(self):
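        # Debugging aid: report any parameter that did not receive a gradient after backward.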
for name, param in self.named_parameters():
if param.grad is None:
print(param.shape)
print(name)
def get_parameters(self):
#all_params = list(self.mano_head.parameters())
all_params = list(self.backbone.parameters())
return all_params
def configure_optimizers(self) -> Tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
"""
        Setup model and discriminator optimizers
Returns:
Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers
"""
param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]
optimizer = torch.optim.AdamW(params=param_groups,
# lr=self.cfg.TRAIN.LR,
weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
optimizer_disc = torch.optim.AdamW(params=self.discriminator.parameters(),
lr=self.cfg.TRAIN.LR,
weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
return optimizer, optimizer_disc
def forward_step(self, batch: Dict, train: bool = False) -> Dict:
"""
Run a forward step of the network
Args:
batch (Dict): Dictionary containing batch data
train (bool): Flag indicating whether it is training or validation mode
Returns:
Dict: Dictionary containing the regression output
"""
# Use RGB image as input
x = batch['img']
batch_size = x.shape[0]
# Compute conditioning features using the backbone
# if using ViT backbone, we need to use a different aspect ratio
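        # The crop width is trimmed by 32 px on each side before the ViT backbone
        # (e.g. 256 -> 192 for the default crop), matching the 16 x 12 feature grid noted below.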
temp_mano_params, pred_cam, pred_mano_feats, vit_out = self.backbone(x[:,:,:,32:-32]) # B, 1280, 16, 12
# Compute camera translation
device = temp_mano_params['hand_pose'].device
dtype = temp_mano_params['hand_pose'].dtype
focal_length = self.cfg.EXTRA.FOCAL_LENGTH * torch.ones(batch_size, 2, device=device, dtype=dtype)
## Temp MANO
temp_mano_params['global_orient'] = temp_mano_params['global_orient'].reshape(batch_size, -1, 3, 3)
temp_mano_params['hand_pose'] = temp_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3)
temp_mano_params['betas'] = temp_mano_params['betas'].reshape(batch_size, -1)
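        # pose2rot=False: the predicted parameters are already rotation matrices,
        # so the MANO layer skips the axis-angle-to-rotation-matrix conversion.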
temp_mano_output = self.mano(**{k: v.float() for k,v in temp_mano_params.items()}, pose2rot=False)
#temp_keypoints_3d = temp_mano_output.joints
temp_vertices = temp_mano_output.vertices
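        # Refine the initial estimate: RefineNet consumes the ViT feature map, the coarse MANO
        # vertices and the initial camera / parameter features, and returns updated parameters.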
pred_mano_params, pred_cam = self.refine_net(vit_out, temp_vertices, pred_cam, pred_mano_feats, focal_length)
# Store useful regression outputs to the output dict
output = {}
output['pred_cam'] = pred_cam
output['pred_mano_params'] = {k: v.clone() for k,v in pred_mano_params.items()}
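        # Convert the weak-perspective camera (scale, tx, ty) to a full translation:
        # t_z = 2 * f / (IMAGE_SIZE * scale), with a small epsilon to avoid division by zero.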
pred_cam_t = torch.stack([pred_cam[:, 1],
pred_cam[:, 2],
2*focal_length[:, 0]/(self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] +1e-9)],dim=-1)
output['pred_cam_t'] = pred_cam_t
output['focal_length'] = focal_length
# Compute model vertices, joints and the projected joints
pred_mano_params['global_orient'] = pred_mano_params['global_orient'].reshape(batch_size, -1, 3, 3)
pred_mano_params['hand_pose'] = pred_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3)
pred_mano_params['betas'] = pred_mano_params['betas'].reshape(batch_size, -1)
mano_output = self.mano(**{k: v.float() for k,v in pred_mano_params.items()}, pose2rot=False)
pred_keypoints_3d = mano_output.joints
pred_vertices = mano_output.vertices
output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)
pred_cam_t = pred_cam_t.reshape(-1, 3)
focal_length = focal_length.reshape(-1, 2)
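        # Project the 3D joints with the focal length expressed in units of the crop size,
        # so the resulting 2D keypoints live in normalized crop coordinates
        # (the same space the 2D keypoint loss compares against).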
pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
translation=pred_cam_t,
focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)
output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)
return output
def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
"""
Compute losses given the input batch and the regression output
Args:
batch (Dict): Dictionary containing batch data
output (Dict): Dictionary containing the regression output
train (bool): Flag indicating whether it is training or validation mode
Returns:
torch.Tensor : Total loss for current batch
"""
pred_mano_params = output['pred_mano_params']
pred_keypoints_2d = output['pred_keypoints_2d']
pred_keypoints_3d = output['pred_keypoints_3d']
batch_size = pred_mano_params['hand_pose'].shape[0]
device = pred_mano_params['hand_pose'].device
dtype = pred_mano_params['hand_pose'].dtype
# Get annotations
gt_keypoints_2d = batch['keypoints_2d']
gt_keypoints_3d = batch['keypoints_3d']
gt_mano_params = batch['mano_params']
has_mano_params = batch['has_mano_params']
is_axis_angle = batch['mano_params_is_axis_angle']
        # Compute 2D and 3D keypoint losses
loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
# Compute loss on MANO parameters
loss_mano_params = {}
for k, pred in pred_mano_params.items():
gt = gt_mano_params[k].view(batch_size, -1)
if is_axis_angle[k].all():
gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
has_gt = has_mano_params[k]
loss_mano_params[k] = self.mano_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1), has_gt)
loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d+\
self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d+\
sum([loss_mano_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_mano_params])
losses = dict(loss=loss.detach(),
loss_keypoints_2d=loss_keypoints_2d.detach(),
loss_keypoints_3d=loss_keypoints_3d.detach())
for k, v in loss_mano_params.items():
losses['loss_' + k] = v.detach()
output['losses'] = losses
return loss
    # Tensorboard logging should run from first rank only
@pl.utilities.rank_zero.rank_zero_only
def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True) -> None:
"""
Log results to Tensorboard
Args:
batch (Dict): Dictionary containing batch data
output (Dict): Dictionary containing the regression output
step_count (int): Global training step count
train (bool): Flag indicating whether it is training or validation mode
"""
mode = 'train' if train else 'val'
batch_size = batch['keypoints_2d'].shape[0]
images = batch['img']
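        # Undo the ImageNet mean/std normalization so the crops can be visualized.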
images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1,3,1,1)
images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1,3,1,1)
#images = 255*images.permute(0, 2, 3, 1).cpu().numpy()
pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)
pred_vertices = output['pred_vertices'].detach().reshape(batch_size, -1, 3)
focal_length = output['focal_length'].detach().reshape(batch_size, 2)
gt_keypoints_3d = batch['keypoints_3d']
gt_keypoints_2d = batch['keypoints_2d']
losses = output['losses']
pred_cam_t = output['pred_cam_t'].detach().reshape(batch_size, 3)
pred_keypoints_2d = output['pred_keypoints_2d'].detach().reshape(batch_size, -1, 2)
if write_to_summary_writer:
summary_writer = self.logger.experiment
for loss_name, val in losses.items():
summary_writer.add_scalar(mode +'/' + loss_name, val.detach().item(), step_count)
num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES)
gt_keypoints_3d = batch['keypoints_3d']
pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)
        # Render mesh overlays for a small number of examples; the commented-out SkeletonRenderer call below is a cheaper alternative when mesh rendering slows training down.
#predictions = self.renderer(pred_keypoints_3d[:num_images],
# gt_keypoints_3d[:num_images],
# 2 * gt_keypoints_2d[:num_images],
# images=images[:num_images],
# camera_translation=pred_cam_t[:num_images])
predictions = self.mesh_renderer.visualize_tensorboard(pred_vertices[:num_images].cpu().numpy(),
pred_cam_t[:num_images].cpu().numpy(),
images[:num_images].cpu().numpy(),
pred_keypoints_2d[:num_images].cpu().numpy(),
gt_keypoints_2d[:num_images].cpu().numpy(),
focal_length=focal_length[:num_images].cpu().numpy())
if write_to_summary_writer:
summary_writer.add_image('%s/predictions' % mode, predictions, step_count)
return predictions
def forward(self, batch: Dict) -> Dict:
"""
Run a forward step of the network in val mode
Args:
batch (Dict): Dictionary containing batch data
Returns:
Dict: Dictionary containing the regression output
"""
return self.forward_step(batch, train=False)
def training_step_discriminator(self, batch: Dict,
hand_pose: torch.Tensor,
betas: torch.Tensor,
optimizer: torch.optim.Optimizer) -> torch.Tensor:
"""
Run a discriminator training step
Args:
batch (Dict): Dictionary containing mocap batch data
hand_pose (torch.Tensor): Regressed hand pose from current step
betas (torch.Tensor): Regressed betas from current step
optimizer (torch.optim.Optimizer): Discriminator optimizer
Returns:
torch.Tensor: Discriminator loss
"""
batch_size = hand_pose.shape[0]
gt_hand_pose = batch['hand_pose']
gt_betas = batch['betas']
gt_rotmat = aa_to_rotmat(gt_hand_pose.view(-1,3)).view(batch_size, -1, 3, 3)
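        # Least-squares GAN objective: fake (predicted) samples are pushed towards 0,
        # real (mocap) samples towards 1.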
disc_fake_out = self.discriminator(hand_pose.detach(), betas.detach())
loss_fake = ((disc_fake_out - 0.0) ** 2).sum() / batch_size
disc_real_out = self.discriminator(gt_rotmat, gt_betas)
loss_real = ((disc_real_out - 1.0) ** 2).sum() / batch_size
loss_disc = loss_fake + loss_real
loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc
optimizer.zero_grad()
self.manual_backward(loss)
optimizer.step()
return loss_disc.detach()
def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict:
"""
Run a full training step
Args:
joint_batch (Dict): Dictionary containing image and mocap batch data
            batch_idx (int): Unused.
Returns:
Dict: Dictionary containing regression output.
"""
batch = joint_batch['img']
mocap_batch = joint_batch['mocap']
optimizer = self.optimizers(use_pl_optimizer=True)
if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
optimizer, optimizer_disc = optimizer
batch_size = batch['img'].shape[0]
output = self.forward_step(batch, train=True)
pred_mano_params = output['pred_mano_params']
if self.cfg.get('UPDATE_GT_SPIN', False):
self.update_batch_gt_spin(batch, output)
loss = self.compute_loss(batch, output, train=True)
if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
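            # Generator-side adversarial term: predictions should be scored as real (1.0) by the discriminator.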
disc_out = self.discriminator(pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1))
loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size
loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv
        # Error if NaN
if torch.isnan(loss):
raise ValueError('Loss is NaN')
optimizer.zero_grad()
self.manual_backward(loss)
# Clip gradient
if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True)
self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True)
optimizer.step()
if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
loss_disc = self.training_step_discriminator(mocap_batch, pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1), optimizer_disc)
output['losses']['loss_gen'] = loss_adv
output['losses']['loss_disc'] = loss_disc
if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
self.tensorboard_logging(batch, output, self.global_step, train=True)
self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False)
return output
def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
"""
Run a validation step and log to Tensorboard
Args:
batch (Dict): Dictionary containing batch data
            batch_idx (int): Unused.
            dataloader_idx (int): Unused.
Returns:
Dict: Dictionary containing regression output.
"""
# batch_size = batch['img'].shape[0]
output = self.forward_step(batch, train=False)
loss = self.compute_loss(batch, output, train=False)
output['loss'] = loss
self.tensorboard_logging(batch, output, self.global_step, train=False)
return output
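# ---------------------------------------------------------------------------
# Minimal inference sketch (illustrative only, not part of the original module).
# It assumes a yacs config compatible with this model (EXTRA.FOCAL_LENGTH,
# MODEL.IMAGE_SIZE, MANO model paths, ...) and a preprocessed, ImageNet-normalized
# crop in batch['img']; the config path and the 256 x 256 crop size below are
# placeholders.
#
#     from yacs.config import CfgNode
#
#     cfg = CfgNode(new_allowed=True)
#     cfg.merge_from_file('path/to/wilor_config.yaml')   # placeholder path
#     model = WiLoR(cfg, init_renderer=False)
#     model.eval()
#     with torch.no_grad():
#         out = model({'img': torch.zeros(1, 3, 256, 256)})
#     # out['pred_vertices']:     (1, 778, 3) MANO mesh vertices
#     # out['pred_keypoints_2d']: (1, n_joints, 2) in normalized crop coordinates
# ---------------------------------------------------------------------------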