# WiLoR: wilor/models/losses.py
# Uploaded by rolpotamias ("Upload 124 files"), commit aea26c8 (verified).
import torch
import torch.nn as nn
class Keypoint2DLoss(nn.Module):

    def __init__(self, loss_type: str = 'l1'):
        """
        2D keypoint loss module.

        Args:
            loss_type (str): Choose between 'l1' (nn.L1Loss) and 'l2' (nn.MSELoss).
                Both use reduction='none' so each element can be weighted by the
                keypoint confidence before summation.

        Raises:
            NotImplementedError: If ``loss_type`` is neither 'l1' nor 'l2'.
        """
        super(Keypoint2DLoss, self).__init__()
        if loss_type == 'l1':
            self.loss_fn = nn.L1Loss(reduction='none')
        elif loss_type == 'l2':
            self.loss_fn = nn.MSELoss(reduction='none')
        else:
            raise NotImplementedError('Unsupported loss function')

    def forward(self, pred_keypoints_2d: torch.Tensor, gt_keypoints_2d: torch.Tensor) -> torch.Tensor:
        """
        Compute the confidence-weighted 2D reprojection loss on the keypoints.

        Args:
            pred_keypoints_2d (torch.Tensor): Tensor of shape [..., N, 2] containing projected
                2D keypoints, e.g. [B, N, 2] or [B, S, N, 2]
                (B: batch_size, S: num_samples, N: num_keypoints).
            gt_keypoints_2d (torch.Tensor): Tensor of shape [..., N, 3] containing the ground
                truth 2D keypoints in the first two channels and a per-keypoint confidence
                in the last channel.
        Returns:
            torch.Tensor: Scalar loss: the per-coordinate loss multiplied by the keypoint
            confidence and summed over all elements (not averaged).
        """
        # Index the channel axis from the end (with an ellipsis) so any number of leading
        # dims works; keeping the channel dim (-1:) lets conf broadcast over x/y.
        conf = gt_keypoints_2d[..., -1:].clone()
        per_coord_loss = self.loss_fn(pred_keypoints_2d, gt_keypoints_2d[..., :-1])
        return (conf * per_coord_loss).sum()
class Keypoint3DLoss(nn.Module):

    def __init__(self, loss_type: str = 'l1'):
        """
        3D keypoint loss module.

        Args:
            loss_type (str): Choose between 'l1' (nn.L1Loss) and 'l2' (nn.MSELoss).
                Both use reduction='none' so each element can be weighted by the
                keypoint confidence before summation.

        Raises:
            NotImplementedError: If ``loss_type`` is neither 'l1' nor 'l2'.
        """
        super(Keypoint3DLoss, self).__init__()
        if loss_type == 'l1':
            self.loss_fn = nn.L1Loss(reduction='none')
        elif loss_type == 'l2':
            self.loss_fn = nn.MSELoss(reduction='none')
        else:
            raise NotImplementedError('Unsupported loss function')

    def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor, pelvis_id: int = 0) -> torch.Tensor:
        """
        Compute the root-relative, confidence-weighted 3D keypoint loss.

        Both keypoint sets are first expressed relative to the joint at index
        ``pelvis_id``, so a global translation offset does not contribute to the loss.

        Args:
            pred_keypoints_3d (torch.Tensor): Tensor of shape [..., N, 3] containing the predicted
                3D keypoints, e.g. [B, N, 3] or [B, S, N, 3]
                (B: batch_size, S: num_samples, N: num_keypoints).
            gt_keypoints_3d (torch.Tensor): Tensor of shape [..., N, 4] containing the ground truth
                3D keypoints in the first three channels and a confidence in the last channel.
                It is not modified.
            pelvis_id (int): Index of the reference (root) joint along the keypoint axis.
        Returns:
            torch.Tensor: Scalar loss: the per-coordinate loss multiplied by the keypoint
            confidence and summed over all elements (not averaged).
        """
        # unsqueeze(-2) restores the keypoint axis so the root broadcasts over all N joints.
        pred_rel = pred_keypoints_3d - pred_keypoints_3d[..., pelvis_id, :].unsqueeze(-2)
        gt_xyz = gt_keypoints_3d[..., :-1]
        # Out-of-place subtraction: the caller's ground-truth tensor is never mutated.
        gt_rel = gt_xyz - gt_xyz[..., pelvis_id, :].unsqueeze(-2)
        # Keep the channel dim so conf broadcasts over x/y/z.
        conf = gt_keypoints_3d[..., -1:].clone()
        return (conf * self.loss_fn(pred_rel, gt_rel)).sum()
class ParameterLoss(nn.Module):

    def __init__(self):
        """
        MANO parameter loss module (element-wise squared error, nn.MSELoss with reduction='none').
        """
        super(ParameterLoss, self).__init__()
        self.loss_fn = nn.MSELoss(reduction='none')

    def forward(self, pred_param: torch.Tensor, gt_param: torch.Tensor, has_param: torch.Tensor) -> torch.Tensor:
        """
        Compute the masked MANO parameter loss.

        Args:
            pred_param (torch.Tensor): Tensor of shape [B, ...] containing the predicted parameters
                (body pose / global orientation / betas).
            gt_param (torch.Tensor): Tensor of shape [B, ...] containing the ground truth MANO parameters.
            has_param (torch.Tensor): Tensor with B elements (any dtype, e.g. bool or int) flagging
                which batch items have ground-truth parameters. Items with a 0/False flag contribute
                nothing to the loss.
        Returns:
            torch.Tensor: Scalar L2 parameter loss, summed (not averaged) over all elements.
        """
        batch_size = pred_param.shape[0]
        # Shape [B, 1, 1, ...] so the per-item flag broadcasts over all parameter dims.
        mask_shape = [batch_size] + [1] * (pred_param.dim() - 1)
        # .to(pred_param) matches both dtype and device of the predictions.
        mask = has_param.to(pred_param).reshape(*mask_shape)
        return (mask * self.loss_fn(pred_param, gt_param)).sum()