from typing import Optional, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from .vgg import VGG19, VGG16

# Default per-layer weights. A tuple, so the default itself can never be
# mutated; each instance gets its own list copy (see the __init__ methods).
_DEFAULT_WEIGHTS = (1.0, 1.0, 1.0, 1.0, 1.0)

# VGG19 feature keys compared by the content and style losses.
_VGG19_CONTENT_LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')
_VGG19_STYLE_LAYERS = ('relu2_2', 'relu3_4', 'relu4_4', 'relu5_2')


def _gram_matrix(feat: torch.Tensor) -> torch.Tensor:
    """Return the normalized Gram matrix of a 4-D feature map.

    Args:
        feat: tensor of shape ``(b, c, h, w)``.

    Returns:
        Tensor of shape ``(b, c, c)`` equal to ``F @ F^T / (c * h * w)``,
        where ``F`` is ``feat`` flattened to ``(b, c, h * w)``.
    """
    b, c, h, w = feat.size()
    # reshape (not view) so non-contiguous feature maps are accepted too.
    flat = feat.reshape(b, c, h * w)
    return flat.bmm(flat.transpose(1, 2)) / (h * w * c)


class Perceptual16Loss(nn.Module):
    """Perceptual (content) and style losses computed on VGG16 features.

    Both losses compare the first three feature maps produced by ``VGG16``
    using an L1 criterion. The module defines no ``forward``; call
    :meth:`calculate_pl` and :meth:`calc_style` directly.
    """

    def __init__(self, weights: Optional[Sequence[float]] = None):
        """
        Args:
            weights: per-layer multipliers for the content loss; only the
                first three entries are used. Defaults to all ones.
        """
        super().__init__()
        self.vgg = VGG16()
        self.criterion = torch.nn.L1Loss()
        # Fresh list per instance: avoids the shared-mutable-default pitfall.
        self.weights = list(_DEFAULT_WEIGHTS) if weights is None else weights

    def calculate_pl(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the weighted L1 content loss over the first 3 VGG16 features.

        NOTE(review): assumes ``VGG16()(t)`` returns an indexable sequence of
        at least three feature tensors -- confirm against ``.vgg``.

        Args:
            x: generated image batch.
            y: reference image batch.

        Returns:
            Scalar loss tensor moved to ``x.device``.
        """
        feat_output = self.vgg(x)
        feat_gt = self.vgg(y)
        content_loss = 0.0
        for i in range(3):
            content_loss += self.weights[i] * self.criterion(
                feat_output[i], feat_gt[i])
        return content_loss.to(device=x.device)

    def compute_gram(self, x: torch.Tensor) -> torch.Tensor:
        """Return the normalized ``(b, c, c)`` Gram matrix of ``x`` (b, c, h, w)."""
        return _gram_matrix(x)

    def calc_style(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the L1 distance between Gram matrices of the first 3 features.

        Args:
            x: generated image batch.
            y: reference image batch.

        Returns:
            Scalar style-loss tensor (unweighted sum over the three layers).
        """
        # Previously referenced a nonexistent ``self.extractor``; the feature
        # network of this module is ``self.vgg``.
        feat_output = self.vgg(x)
        feat_gt = self.vgg(y)
        style_loss = 0.0
        for i in range(3):
            style_loss += self.criterion(
                self.compute_gram(feat_output[i]),
                self.compute_gram(feat_gt[i]))
        return style_loss


class Perceptual19Loss(nn.Module):
    """Perceptual (content) and style losses computed on VGG19 features.

    ``VGG19()(t)`` is indexed by string keys such as ``'relu1_1'``. The
    content loss compares ``relu{1..5}_1``. The style loss compares Gram
    matrices of ``relu2_2``, ``relu3_4``, ``relu4_4`` and ``relu5_2``. The
    module defines no ``forward``; call the loss methods directly.
    """

    def __init__(self, weights: Optional[Sequence[float]] = None):
        """
        Args:
            weights: five per-layer multipliers for the content loss, one per
                ``relu{i}_1`` layer. Defaults to all ones.
        """
        super().__init__()
        self.vgg = VGG19()
        self.criterion = torch.nn.L1Loss()
        # Fresh list per instance: avoids the shared-mutable-default pitfall.
        self.weights = list(_DEFAULT_WEIGHTS) if weights is None else weights

    def calculate_pl(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the weighted L1 content loss over ``relu1_1`` .. ``relu5_1``.

        Args:
            x: generated image batch.
            y: reference image batch.

        Returns:
            Scalar loss tensor moved to ``x.device``.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        content_loss = 0.0
        # Index into self.weights (rather than zip) so a too-short weights
        # sequence still raises instead of silently dropping layers.
        for i, layer in enumerate(_VGG19_CONTENT_LAYERS):
            content_loss += self.weights[i] * self.criterion(
                x_vgg[layer], y_vgg[layer])
        return content_loss.to(device=x.device)

    def compute_gram(self, x: torch.Tensor) -> torch.Tensor:
        """Return the normalized ``(b, c, c)`` Gram matrix of ``x`` (b, c, h, w)."""
        return _gram_matrix(x)

    def calc_style(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the summed L1 distance between Gram matrices of style layers.

        Args:
            x: generated image batch.
            y: reference image batch.

        Returns:
            Scalar style-loss tensor (unweighted sum over the four layers).
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        style_loss = 0.0
        for layer in _VGG19_STYLE_LAYERS:
            style_loss += self.criterion(
                self.compute_gram(x_vgg[layer]),
                self.compute_gram(y_vgg[layer]))
        return style_loss