import torch
from torch import nn
import torch.nn.functional as F

from criteria.model_irse import Backbone
from criteria.backbones import get_model


class IDLoss(nn.Module):
    """
    Identity loss that compares the people in two images via the cosine
    similarity of their face-recognition embeddings.

    Taken from TreB1eN's [1] implementation of InsightFace [2, 3], as used in
    pixel2style2pixel [4].

    [1] https://github.com/TreB1eN/InsightFace_Pytorch
    [2] https://github.com/deepinsight/insightface
    [3] Deng, Jiankang and Guo, Jia and Niannan, Xue and Zafeiriou, Stefanos.
        ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
        In CVPR, 2019
    [4] https://github.com/eladrich/pixel2style2pixel
    """

    def __init__(self, model_path, official=False):
        """
        Arguments:
            model_path (str): Path to the pretrained face-recognition weights.
            official (bool): If True, use the official InsightFace ResNet-100
                backbone; otherwise use the IR-SE50 backbone.
        """
        super(IDLoss, self).__init__()
        print("Loading ResNet ArcFace")
        self.official = official
        if official:
            self.facenet = get_model("r100", fp16=False)
        else:
            self.facenet = Backbone(
                input_size=112, num_layers=50, drop_ratio=0.6, mode="ir_se"
            )
        self.facenet.load_state_dict(torch.load(model_path))
        # The recognition backbone expects 112x112 inputs.
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

    def extract_feats(self, x):
        x = x[:, :, 35:223, 32:220]  # Crop the face region
        x = self.face_pool(x)  # Resize to 112x112
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, x, y):
        """
        Arguments:
            x (Tensor): The batch of original images
            y (Tensor): The batch of generated images

        Returns:
            loss (Tensor): One minus the mean cosine similarity between the
                features of the original and generated images.
        """
        x_feats = self.extract_feats(x)
        y_feats = self.extract_feats(y)
        if self.official:
            # The official backbone returns unnormalized embeddings, so
            # normalize them before taking the dot product.
            x_feats = F.normalize(x_feats)
            y_feats = F.normalize(y_feats)
        loss = (1 - (x_feats * y_feats).sum(dim=1)).mean()
        return loss
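

# A minimal usage sketch, not part of the original module. The weight path and
# image size below are assumptions: images are fed as float tensors of shape
# (N, 3, 256, 256) so that the [35:223, 32:220] crop isolates the face region.
if __name__ == "__main__":
    # Assumed location of the IR-SE50 weights; adjust to your setup.
    id_loss = IDLoss(model_path="pretrained_models/model_ir_se50.pth")

    x = torch.randn(2, 3, 256, 256)  # stand-in for a batch of original images
    y = torch.randn(2, 3, 256, 256)  # stand-in for a batch of generated images

    with torch.no_grad():
        loss = id_loss(x, y)
    print("identity loss:", loss.item())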