Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.models as models | |
from torch import Tensor | |
class ContentLoss(nn.Module):
    """Perceptual (content) loss computed on VGG19 feature maps.

    Both images are normalised with the ImageNet channel statistics, pushed
    through the first 36 modules of ``vgg19.features`` and compared with an
    L1 distance. Using feature maps from the later layers focuses the loss
    more on the texture content of the image.

    Paper reference list:
        - `Photo-Realistic Single Image Super-Resolution Using a Generative
          Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>`
        - `ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
          <https://arxiv.org/pdf/1809.00219.pdf>`
        - `Perceptual Extreme Super Resolution Network with Receptive Field
          Block <https://arxiv.org/pdf/2005.12597.pdf>`
    """

    def __init__(self) -> None:
        super().__init__()

        # ImageNet-pretrained VGG19, switched to evaluation mode.
        backbone = models.vgg19(pretrained=True).eval()

        # Keep only the first 36 modules of the convolutional trunk.
        trunk_layers = list(backbone.features.children())
        self.feature_extractor = nn.Sequential(*trunk_layers[:36])

        # The extractor is a fixed feature function: its weights are frozen.
        for weight in self.feature_extractor.parameters():
            weight.requires_grad = False

        # ImageNet per-channel statistics, shaped (1, 3, 1, 1) so they
        # broadcast over batch and spatial dimensions. Registered as buffers
        # so they follow the module across devices.
        imagenet_mean = torch.Tensor([0.485, 0.456, 0.406])
        imagenet_std = torch.Tensor([0.229, 0.224, 0.225])
        self.register_buffer("mean", imagenet_mean.view(1, 3, 1, 1))
        self.register_buffer("std", imagenet_std.view(1, 3, 1, 1))

    def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
        """Return the L1 distance between the VGG19 features of ``sr`` and ``hr``.

        Args:
            sr: Super-resolved image batch.
                # assumes 3-channel input (the statistics are viewed as
                # (1, 3, 1, 1)) - TODO confirm expected value range.
            hr: Reference image batch, same layout as ``sr``.

        Returns:
            Scalar tensor (default mean reduction of ``F.l1_loss``).
        """
        # Apply the same ImageNet normalisation to both inputs.
        sr_normalised = (sr - self.mean) / self.std
        hr_normalised = (hr - self.mean) / self.std

        # Compare the two images in feature space.
        sr_features = self.feature_extractor(sr_normalised)
        hr_features = self.feature_extractor(hr_normalised)
        return F.l1_loss(sr_features, hr_features)
class GenGaussLoss(nn.Module):
    """Generalized-Gaussian style loss on a predicted (mean, 1/alpha, beta) triple.

    The per-element loss is::

        clamp(|mean - target| * (1/alpha + alpha_eps) * (beta + beta_eps),
              resi_min, resi_max)
        - log(1/alpha + alpha_eps)
        + lgamma(1 / (beta + beta_eps))
        - log(beta + beta_eps)

    Args:
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (per-element loss,
            matching the convention of ``nn.L1Loss``).
        alpha_eps: Added to ``one_over_alpha`` before it is used.
        beta_eps: Added to ``beta`` before it is used.
        resi_min: Lower clamp bound for the scaled residual.
        resi_max: Upper clamp bound for the scaled residual.
    """

    def __init__(
        self, reduction='mean',
        alpha_eps=1e-4, beta_eps=1e-4,
        resi_min=1e-4, resi_max=1e3
    ) -> None:
        super(GenGaussLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor
    ):
        """Compute the loss.

        Args:
            mean: Predicted mean.
            one_over_alpha: Predicted inverse scale.
            beta: Predicted shape parameter.
            target: Ground truth the mean is compared against.
                # all four tensors are presumably the same (or broadcastable)
                # shape - TODO confirm against callers.

        Returns:
            The reduced loss tensor (or the unreduced tensor for
            ``reduction='none'``). ``None`` is returned, after printing a
            message, when the scaled residual contains NaNs or when
            ``reduction`` is not one of the supported values.
        """
        # Offsetting by the eps values before log / reciprocal keeps a
        # prediction of exactly 0 from producing log(0) or a division by zero.
        one_over_alpha1 = one_over_alpha + self.alpha_eps
        beta1 = beta + self.beta_eps

        # NOTE(review): an earlier version of this line, left commented out
        # in the file, raised the scaled residual to the power beta:
        #   torch.pow(resi * one_over_alpha1, beta1)
        # The active code multiplies by beta instead - confirm this is the
        # intended formulation.
        resi = torch.abs(mean - target)
        resi = (resi * one_over_alpha1 * beta1).clamp(
            min=self.resi_min, max=self.resi_max
        )

        # A NaN residual makes the whole loss meaningless: report and bail out.
        if torch.isnan(resi).any():
            print('resi has nans!!')
            return None

        log_one_over_alpha = torch.log(one_over_alpha1)
        log_beta = torch.log(beta1)
        lgamma_beta = torch.lgamma(torch.pow(beta1, -1))

        # Diagnostics only: these do not abort the computation.
        if torch.isnan(log_one_over_alpha).any():
            print('log_one_over_alpha has nan')
        if torch.isnan(lgamma_beta).any():
            print('lgamma_beta has nan')
        if torch.isnan(log_beta).any():
            print('log_beta has nan')

        l = resi - log_one_over_alpha + lgamma_beta - log_beta

        if self.reduction == 'mean':
            return l.mean()
        if self.reduction == 'sum':
            return l.sum()
        if self.reduction == 'none':
            # Per-element loss, consistent with nn.L1Loss(reduction='none').
            return l

        print('Reduction not supported')
        return None
class TempCombLoss(nn.Module):
    """Weighted combination of an L1 loss and a ``GenGaussLoss``.

    ``forward`` returns ``T1 * L1(mean, target) + T2 * GenGauss(mean,
    one_over_alpha, beta, target)``. The weights are supplied per call.

    Args:
        reduction: Forwarded to both ``nn.L1Loss`` and ``GenGaussLoss``.
        alpha_eps: Forwarded to ``GenGaussLoss``.
        beta_eps: Forwarded to ``GenGaussLoss``.
        resi_min: Forwarded to ``GenGaussLoss``.
        resi_max: Forwarded to ``GenGaussLoss``.
    """

    def __init__(
        self,
        reduction='mean',
        alpha_eps=1e-4,
        beta_eps=1e-4,
        resi_min=1e-4,
        resi_max=1e3,
    ) -> None:
        super().__init__()

        # Keep the configuration on the instance for inspection.
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

        # The two component losses share the same reduction mode.
        self.L_GenGauss = GenGaussLoss(
            reduction=reduction,
            alpha_eps=alpha_eps,
            beta_eps=beta_eps,
            resi_min=resi_min,
            resi_max=resi_max,
        )
        self.L_l1 = nn.L1Loss(reduction=reduction)

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
        T1: float, T2: float
    ):
        """Return ``T1 * l1_term + T2 * gen_gauss_term``.

        Args:
            mean: Predicted mean, compared against ``target`` by both terms.
            one_over_alpha: Passed through to ``GenGaussLoss``.
            beta: Passed through to ``GenGaussLoss``.
            target: Ground truth.
            T1: Weight of the L1 term.
            T2: Weight of the ``GenGaussLoss`` term.

        Note:
            ``GenGaussLoss`` can return ``None`` (it does so when it detects
            NaNs in its residual or an unsupported reduction); in that case
            the weighted sum below raises ``TypeError``.
        """
        l1_term = self.L_l1(mean, target)
        gen_gauss_term = self.L_GenGauss(mean, one_over_alpha, beta, target)
        return T1 * l1_term + T2 * gen_gauss_term
# Example usage (manual smoke test, kept commented out):
# x1 = torch.randn(4,3,32,32)
# x2 = torch.rand(4,3,32,32)
# x3 = torch.rand(4,3,32,32)
# x4 = torch.randn(4,3,32,32)
# L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
# L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
# print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))