import torch
import pytorch_lightning as pl
from typing import Any, Dict, Mapping, Tuple

from yacs.config import CfgNode

from ..utils import SkeletonRenderer, MeshRenderer
from ..utils.geometry import aa_to_rotmat, perspective_projection
from ..utils.pylogger import get_pylogger
from .backbones import create_backbone
from .heads import RefineNet
from .discriminator import Discriminator
from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss
from . import MANO

log = get_pylogger(__name__)

class WiLoR(pl.LightningModule):

    def __init__(self, cfg: CfgNode, init_renderer: bool = True):
        """
        Setup WiLoR model
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode
        """
        super().__init__()

        # Save hyperparameters
        self.save_hyperparameters(logger=False, ignore=['init_renderer'])

        self.cfg = cfg
        # Create backbone feature extractor
        self.backbone = create_backbone(cfg)
        if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
            log.info(f'Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
            self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict'], strict=False)

        # Create RefineNet head
        self.refine_net = RefineNet(cfg, feat_dim=1280, upscale=3)

        # Create discriminator
        if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
            self.discriminator = Discriminator()

        # Define loss functions
        self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
        self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
        self.mano_parameter_loss = ParameterLoss()

        # Instantiate MANO model
        mano_cfg = {k.lower(): v for k, v in dict(cfg.MANO).items()}
        self.mano = MANO(**mano_cfg)

        # Buffer that shows whether we need to initialize ActNorm layers
        self.register_buffer('initialized', torch.tensor(False))

        # Setup renderer for visualization
        if init_renderer:
            self.renderer = SkeletonRenderer(self.cfg)
            self.mesh_renderer = MeshRenderer(self.cfg, faces=self.mano.faces)
        else:
            self.renderer = None
            self.mesh_renderer = None

        # Disable automatic optimization since we use adversarial training
        self.automatic_optimization = False

    def on_after_backward(self):
        for name, param in self.named_parameters():
            if param.grad is None:
                print(param.shape)
                print(name)

    def get_parameters(self):
        #all_params = list(self.mano_head.parameters())
        all_params = list(self.backbone.parameters())
        return all_params

    def configure_optimizers(self) -> Tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
        """
        Setup model and discriminator optimizers
        Returns:
            Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers
        """
        param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]

        optimizer = torch.optim.AdamW(params=param_groups,
                                      # lr=self.cfg.TRAIN.LR,
                                      weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
        optimizer_disc = torch.optim.AdamW(params=self.discriminator.parameters(),
                                           lr=self.cfg.TRAIN.LR,
                                           weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)

        return optimizer, optimizer_disc

    def forward_step(self, batch: Dict, train: bool = False) -> Dict:
        """
        Run a forward step of the network
        Args:
            batch (Dict): Dictionary containing batch data
            train (bool): Flag indicating whether it is training or validation mode
        Returns:
            Dict: Dictionary containing the regression output
        """
        # Use RGB image as input
        x = batch['img']
        batch_size = x.shape[0]

        # Compute conditioning features using the backbone
        # if using ViT backbone, we need to use a different aspect ratio
        temp_mano_params, pred_cam, pred_mano_feats, vit_out = self.backbone(x[:, :, :, 32:-32])  # B, 1280, 16, 12
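        # The horizontal crop above drops 32 px on each side; assuming a 256x256 input
        # crop, this yields a 256x192 image, matching the 16x12 ViT token grid noted above.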

        # Compute camera translation
        device = temp_mano_params['hand_pose'].device
        dtype = temp_mano_params['hand_pose'].dtype
        focal_length = self.cfg.EXTRA.FOCAL_LENGTH * torch.ones(batch_size, 2, device=device, dtype=dtype)

        # Run MANO on the initial (pre-refinement) parameters to obtain coarse vertices for RefineNet
        temp_mano_params['global_orient'] = temp_mano_params['global_orient'].reshape(batch_size, -1, 3, 3)
        temp_mano_params['hand_pose'] = temp_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3)
        temp_mano_params['betas'] = temp_mano_params['betas'].reshape(batch_size, -1)
        temp_mano_output = self.mano(**{k: v.float() for k, v in temp_mano_params.items()}, pose2rot=False)
        #temp_keypoints_3d = temp_mano_output.joints
        temp_vertices = temp_mano_output.vertices

        pred_mano_params, pred_cam = self.refine_net(vit_out, temp_vertices, pred_cam, pred_mano_feats, focal_length)

        # Store useful regression outputs to the output dict
        output = {}
        output['pred_cam'] = pred_cam
        output['pred_mano_params'] = {k: v.clone() for k, v in pred_mano_params.items()}
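
        # Convert the predicted weak-perspective camera (scale, tx, ty) into a full
        # translation: tz = 2 * f / (IMAGE_SIZE * scale), so smaller scales place the
        # hand further from the camera (the 1e-9 term guards against division by zero).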
        pred_cam_t = torch.stack([pred_cam[:, 1],
                                  pred_cam[:, 2],
                                  2 * focal_length[:, 0] / (self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] + 1e-9)], dim=-1)
        output['pred_cam_t'] = pred_cam_t
        output['focal_length'] = focal_length

        # Compute model vertices, joints and the projected joints
        pred_mano_params['global_orient'] = pred_mano_params['global_orient'].reshape(batch_size, -1, 3, 3)
        pred_mano_params['hand_pose'] = pred_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3)
        pred_mano_params['betas'] = pred_mano_params['betas'].reshape(batch_size, -1)
        mano_output = self.mano(**{k: v.float() for k, v in pred_mano_params.items()}, pose2rot=False)
        pred_keypoints_3d = mano_output.joints
        pred_vertices = mano_output.vertices
        output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
        output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)

        pred_cam_t = pred_cam_t.reshape(-1, 3)
        focal_length = focal_length.reshape(-1, 2)
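        # Project the 3D joints onto the image plane; the focal length is divided by
        # IMAGE_SIZE because the 2D keypoints are expressed in crop-normalized coordinates.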
        pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
                                                   translation=pred_cam_t,
                                                   focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)
        output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)

        return output

    def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
        """
        Compute losses given the input batch and the regression output
        Args:
            batch (Dict): Dictionary containing batch data
            output (Dict): Dictionary containing the regression output
            train (bool): Flag indicating whether it is training or validation mode
        Returns:
            torch.Tensor: Total loss for current batch
        """
        pred_mano_params = output['pred_mano_params']
        pred_keypoints_2d = output['pred_keypoints_2d']
        pred_keypoints_3d = output['pred_keypoints_3d']

        batch_size = pred_mano_params['hand_pose'].shape[0]
        device = pred_mano_params['hand_pose'].device
        dtype = pred_mano_params['hand_pose'].dtype

        # Get annotations
        gt_keypoints_2d = batch['keypoints_2d']
        gt_keypoints_3d = batch['keypoints_3d']
        gt_mano_params = batch['mano_params']
        has_mano_params = batch['has_mano_params']
        is_axis_angle = batch['mano_params_is_axis_angle']

        # Compute 2D and 3D keypoint losses
        loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
        loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)

        # Compute loss on MANO parameters; axis-angle annotations are converted to
        # rotation matrices so they match the predicted representation
        loss_mano_params = {}
        for k, pred in pred_mano_params.items():
            gt = gt_mano_params[k].view(batch_size, -1)
            if is_axis_angle[k].all():
                gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
            has_gt = has_mano_params[k]
            loss_mano_params[k] = self.mano_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1), has_gt)

        loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d + \
               self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d + \
               sum([loss_mano_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_mano_params])

        losses = dict(loss=loss.detach(),
                      loss_keypoints_2d=loss_keypoints_2d.detach(),
                      loss_keypoints_3d=loss_keypoints_3d.detach())
        for k, v in loss_mano_params.items():
            losses['loss_' + k] = v.detach()

        output['losses'] = losses

        return loss

    # Tensorboard logging should run from first rank only
    def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True) -> None:
        """
        Log results to Tensorboard
        Args:
            batch (Dict): Dictionary containing batch data
            output (Dict): Dictionary containing the regression output
            step_count (int): Global training step count
            train (bool): Flag indicating whether it is training or validation mode
        """
        mode = 'train' if train else 'val'
        batch_size = batch['keypoints_2d'].shape[0]
        images = batch['img']
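        # Undo the ImageNet normalization (multiply by std, add mean) so the logged images are in [0, 1]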
        images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
        images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
        #images = 255*images.permute(0, 2, 3, 1).cpu().numpy()

        pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)
        pred_vertices = output['pred_vertices'].detach().reshape(batch_size, -1, 3)
        focal_length = output['focal_length'].detach().reshape(batch_size, 2)
        gt_keypoints_3d = batch['keypoints_3d']
        gt_keypoints_2d = batch['keypoints_2d']
        losses = output['losses']
        pred_cam_t = output['pred_cam_t'].detach().reshape(batch_size, 3)
        pred_keypoints_2d = output['pred_keypoints_2d'].detach().reshape(batch_size, -1, 2)

        if write_to_summary_writer:
            summary_writer = self.logger.experiment
            for loss_name, val in losses.items():
                summary_writer.add_scalar(mode + '/' + loss_name, val.detach().item(), step_count)
        num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES)

        gt_keypoints_3d = batch['keypoints_3d']
        pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)

        # We render the skeletons instead of the full mesh because rendering a lot of meshes will make the training slow.
        #predictions = self.renderer(pred_keypoints_3d[:num_images],
        #                            gt_keypoints_3d[:num_images],
        #                            2 * gt_keypoints_2d[:num_images],
        #                            images=images[:num_images],
        #                            camera_translation=pred_cam_t[:num_images])
        predictions = self.mesh_renderer.visualize_tensorboard(pred_vertices[:num_images].cpu().numpy(),
                                                               pred_cam_t[:num_images].cpu().numpy(),
                                                               images[:num_images].cpu().numpy(),
                                                               pred_keypoints_2d[:num_images].cpu().numpy(),
                                                               gt_keypoints_2d[:num_images].cpu().numpy(),
                                                               focal_length=focal_length[:num_images].cpu().numpy())
        if write_to_summary_writer:
            summary_writer.add_image('%s/predictions' % mode, predictions, step_count)

        return predictions

    def forward(self, batch: Dict) -> Dict:
        """
        Run a forward step of the network in val mode
        Args:
            batch (Dict): Dictionary containing batch data
        Returns:
            Dict: Dictionary containing the regression output
        """
        return self.forward_step(batch, train=False)

    def training_step_discriminator(self, batch: Dict,
                                    hand_pose: torch.Tensor,
                                    betas: torch.Tensor,
                                    optimizer: torch.optim.Optimizer) -> torch.Tensor:
        """
        Run a discriminator training step
        Args:
            batch (Dict): Dictionary containing mocap batch data
            hand_pose (torch.Tensor): Regressed hand pose from current step
            betas (torch.Tensor): Regressed betas from current step
            optimizer (torch.optim.Optimizer): Discriminator optimizer
        Returns:
            torch.Tensor: Discriminator loss
        """
        batch_size = hand_pose.shape[0]
        gt_hand_pose = batch['hand_pose']
        gt_betas = batch['betas']
        gt_rotmat = aa_to_rotmat(gt_hand_pose.view(-1, 3)).view(batch_size, -1, 3, 3)
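
        # Least-squares GAN objective: the discriminator is pushed towards 0 on the
        # regressed (fake) parameters and towards 1 on the mocap (real) parameters.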
        disc_fake_out = self.discriminator(hand_pose.detach(), betas.detach())
        loss_fake = ((disc_fake_out - 0.0) ** 2).sum() / batch_size
        disc_real_out = self.discriminator(gt_rotmat, gt_betas)
        loss_real = ((disc_real_out - 1.0) ** 2).sum() / batch_size
        loss_disc = loss_fake + loss_real
        loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc
        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        return loss_disc.detach()

    def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict:
        """
        Run a full training step
        Args:
            joint_batch (Dict): Dictionary containing image and mocap batch data
            batch_idx (int): Unused.
        Returns:
            Dict: Dictionary containing regression output.
        """
        batch = joint_batch['img']
        mocap_batch = joint_batch['mocap']
        optimizer = self.optimizers(use_pl_optimizer=True)
        if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
            optimizer, optimizer_disc = optimizer

        batch_size = batch['img'].shape[0]
        output = self.forward_step(batch, train=True)
        pred_mano_params = output['pred_mano_params']
        if self.cfg.get('UPDATE_GT_SPIN', False):
            self.update_batch_gt_spin(batch, output)
        loss = self.compute_loss(batch, output, train=True)

        if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
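            # Generator-side adversarial loss: encourage the discriminator to output 1
            # for the regressed MANO parameters.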
            disc_out = self.discriminator(pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1))
            loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size
            loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv

        # Error if NaN
        if torch.isnan(loss):
            raise ValueError('Loss is NaN')

        optimizer.zero_grad()
        self.manual_backward(loss)
        # Clip gradient
        if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
            gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True)
            self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        optimizer.step()

        if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
            loss_disc = self.training_step_discriminator(mocap_batch, pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1), optimizer_disc)
            output['losses']['loss_gen'] = loss_adv
            output['losses']['loss_disc'] = loss_disc

        if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
            self.tensorboard_logging(batch, output, self.global_step, train=True)

        self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False)

        return output

    def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
        """
        Run a validation step and log to Tensorboard
        Args:
            batch (Dict): Dictionary containing batch data
            batch_idx (int): Unused.
        Returns:
            Dict: Dictionary containing regression output.
        """
        # batch_size = batch['img'].shape[0]
        output = self.forward_step(batch, train=False)
        loss = self.compute_loss(batch, output, train=False)
        output['loss'] = loss
        self.tensorboard_logging(batch, output, self.global_step, train=False)

        return output
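

# Minimal usage sketch (an illustration only, not part of the training pipeline):
# assuming `cfg` is a yacs CfgNode with the fields referenced above and the model
# expects a 256x256 input crop, an inference-only forward pass could look roughly like:
#
#   model = WiLoR(cfg, init_renderer=False)
#   model.eval()
#   batch = {'img': torch.randn(1, 3, 256, 256)}
#   with torch.no_grad():
#       out = model(batch)
#   # out['pred_vertices'], out['pred_keypoints_3d'], out['pred_cam_t'], ...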