import torch | |
import torch.nn as nn | |
from torch import Tensor | |
class MCRMSELoss(nn.Module):
    """Mean column-wise root mean squared error (MCRMSE) loss.

    The squared error is averaged over the first (batch) dimension for each
    column, the square root is taken per column, and the per-column RMSE
    values are then averaged into a single scalar.

    :param eps: non-negative constant added to each column's MSE before the
        square root. The derivative of ``sqrt`` is unbounded at zero, so a
        small positive value keeps gradients finite when a column is
        predicted perfectly. Defaults to ``1e-7``.
    :raises ValueError: if ``eps`` is negative.
    """

    def __init__(self, *, eps: float = 1e-7):
        super().__init__()
        if eps < 0:
            raise ValueError(f"eps must be non-negative, got {eps}")
        self.eps = eps
        # reduction='none' keeps the element-wise squared errors so they can
        # be averaged per column rather than over the whole tensor.
        self.mse = nn.MSELoss(reduction='none')

    def _column_rmse(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
        """Return the RMSE of each column, reducing over dimension 0."""
        column_mse = self.mse(y_pred, y_true).mean(0)
        return torch.sqrt(column_mse + self.eps)

    def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
        """Calculate the mean of the column-wise RMSE values.

        :param y_pred: predictions, tensor of shape (bs, n_columns)
        :param y_true: targets, tensor of the same shape as ``y_pred``
        :return: 0-dimensional tensor (scalar that carries grad)
        """
        return self._column_rmse(y_pred, y_true).mean()

    def class_mcrmse(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
        """Calculate the RMSE of every column separately.

        :param y_pred: predictions, tensor of shape (bs, n_columns)
        :param y_true: targets, tensor of the same shape as ``y_pred``
        :return: tensor of per-column RMSE values with size-1 dimensions
            squeezed out, i.e. shape (n_columns,), or a 0-dimensional
            tensor when there is a single column
        """
        return self._column_rmse(y_pred, y_true).squeeze()