# --- Imports --- #
import torch
import torch.nn.functional as F


# --- Perceptual loss network --- #
class LossNetwork(torch.nn.Module):
    """Perceptual loss: mean of per-layer MSEs between intermediate features.

    Both inputs are pushed through the child modules of ``vgg_model`` in
    registration order. The activations produced by the children whose names
    are keys of ``layer_name_mapping`` are compared with an MSE each, and the
    per-layer losses are averaged.
    """

    def __init__(self, vgg_model: torch.nn.Module):
        """Store the feature extractor and the table of tapped layers.

        Args:
            vgg_model: Module whose direct children are applied one after the
                other. The children are looked up by name, so for a
                ``torch.nn.Sequential`` the names are the string indices.
        """
        super().__init__()
        self.vgg_layers = vgg_model
        # Child-module name -> feature label.
        # NOTE(review): the indices presumably refer to torchvision's VGG16
        # ``features`` stack -- confirm against the caller.
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3"
        }

    def output_features(self, x: torch.Tensor) -> list:
        """Run ``x`` through every child module and collect tapped outputs.

        Args:
            x: Input passed to the first child module.

        Returns:
            The activations of the children named in ``layer_name_mapping``,
            in the order those children are visited. Empty when no child
            name matches.
        """
        tapped = {}
        # ``_modules`` is iterated directly on purpose: ``named_children()``
        # skips repeated module instances, which would silently drop layers
        # from a model that reuses one module in several positions.
        # Every child is executed, including those after the last tapped one;
        # pass a truncated model to avoid that extra compute.
        for name, module in self.vgg_layers._modules.items():
            x = module(x)
            label = self.layer_name_mapping.get(name)
            if label is not None:
                tapped[label] = x
        return list(tapped.values())

    def forward(self, dehaze: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
        """Return the mean feature-space MSE between ``dehaze`` and ``gt``.

        Args:
            dehaze: Tensor to be scored (first argument of each MSE).
            gt: Reference tensor (second argument of each MSE).

        Returns:
            Scalar tensor: the average of the per-layer ``F.mse_loss`` values.

        Raises:
            ValueError: If no child module name of the wrapped model is a key
                of ``layer_name_mapping``, so there is nothing to compare.
        """
        dehaze_features = self.output_features(dehaze)
        gt_features = self.output_features(gt)

        layer_losses = [
            F.mse_loss(dehaze_feature, gt_feature)
            for dehaze_feature, gt_feature in zip(dehaze_features, gt_features)
        ]
        if not layer_losses:
            # Previously this surfaced as an opaque ZeroDivisionError.
            raise ValueError(
                "LossNetwork found no feature layers: none of the wrapped "
                "model's child module names match "
                f"{sorted(self.layer_name_mapping)}"
            )
        return sum(layer_losses) / len(layer_losses)