ZhengPeng7's picture
For users to load in one key.
2a41a22
raw
history blame
10.2 kB
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp
from config import Config
class Discriminator(nn.Module):
def __init__(self, channels=1, img_size=256):
super(Discriminator, self).__init__()
def discriminator_block(in_filters, out_filters, bn=Config().batch_size > 1):
block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
if bn:
block.append(nn.BatchNorm2d(out_filters, 0.8))
return block
self.model = nn.Sequential(
*discriminator_block(channels, 16, bn=False),
*discriminator_block(16, 32),
*discriminator_block(32, 64),
*discriminator_block(64, 128),
)
# The height and width of downsampled image
ds_size = img_size // 2 ** 4
self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
def forward(self, img):
out = self.model(img)
out = out.view(out.shape[0], -1)
validity = self.adv_layer(out)
return validity
class ContourLoss(torch.nn.Module):
def __init__(self):
super(ContourLoss, self).__init__()
def forward(self, pred, target, weight=10):
'''
target, pred: tensor of shape (B, C, H, W), where target[:,:,region_in_contour] == 1,
target[:,:,region_out_contour] == 0.
weight: scalar, length term weight.
'''
# length term
delta_r = pred[:,:,1:,:] - pred[:,:,:-1,:] # horizontal gradient (B, C, H-1, W)
delta_c = pred[:,:,:,1:] - pred[:,:,:,:-1] # vertical gradient (B, C, H, W-1)
delta_r = delta_r[:,:,1:,:-2]**2 # (B, C, H-2, W-2)
delta_c = delta_c[:,:,:-2,1:]**2 # (B, C, H-2, W-2)
delta_pred = torch.abs(delta_r + delta_c)
epsilon = 1e-8 # where is a parameter to avoid square root is zero in practice.
length = torch.mean(torch.sqrt(delta_pred + epsilon)) # eq.(11) in the paper, mean is used instead of sum.
c_in = torch.ones_like(pred)
c_out = torch.zeros_like(pred)
region_in = torch.mean( pred * (target - c_in )**2 ) # equ.(12) in the paper, mean is used instead of sum.
region_out = torch.mean( (1-pred) * (target - c_out)**2 )
region = region_in + region_out
loss = weight * length + region
return loss
class IoULoss(torch.nn.Module):
def __init__(self):
super(IoULoss, self).__init__()
def forward(self, pred, target):
b = pred.shape[0]
IoU = 0.0
for i in range(0, b):
# compute the IoU of the foreground
Iand1 = torch.sum(target[i, :, :, :] * pred[i, :, :, :])
Ior1 = torch.sum(target[i, :, :, :]) + torch.sum(pred[i, :, :, :]) - Iand1
IoU1 = Iand1 / Ior1
# IoU loss is (1-IoU1)
IoU = IoU + (1-IoU1)
# return IoU/b
return IoU
class StructureLoss(torch.nn.Module):
def __init__(self):
super(StructureLoss, self).__init__()
def forward(self, pred, target):
weit = 1+5*torch.abs(F.avg_pool2d(target, kernel_size=31, stride=1, padding=15)-target)
wbce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))
pred = torch.sigmoid(pred)
inter = ((pred * target) * weit).sum(dim=(2, 3))
union = ((pred + target) * weit).sum(dim=(2, 3))
wiou = 1-(inter+1)/(union-inter+1)
return (wbce+wiou).mean()
class PatchIoULoss(torch.nn.Module):
def __init__(self):
super(PatchIoULoss, self).__init__()
self.iou_loss = IoULoss()
def forward(self, pred, target):
win_y, win_x = 64, 64
iou_loss = 0.
for anchor_y in range(0, target.shape[0], win_y):
for anchor_x in range(0, target.shape[1], win_y):
patch_pred = pred[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x]
patch_target = target[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x]
patch_iou_loss = self.iou_loss(patch_pred, patch_target)
iou_loss += patch_iou_loss
return iou_loss
class ThrReg_loss(torch.nn.Module):
def __init__(self):
super(ThrReg_loss, self).__init__()
def forward(self, pred, gt=None):
return torch.mean(1 - ((pred - 0) ** 2 + (pred - 1) ** 2))
class ClsLoss(nn.Module):
"""
Auxiliary classification loss for each refined class output.
"""
def __init__(self):
super(ClsLoss, self).__init__()
self.config = Config()
self.lambdas_cls = self.config.lambdas_cls
self.criterions_last = {
'ce': nn.CrossEntropyLoss()
}
def forward(self, preds, gt):
loss = 0.
for _, pred_lvl in enumerate(preds):
if pred_lvl is None:
continue
for criterion_name, criterion in self.criterions_last.items():
loss += criterion(pred_lvl, gt) * self.lambdas_cls[criterion_name]
return loss
class PixLoss(nn.Module):
"""
Pixel loss for each refined map output.
"""
def __init__(self):
super(PixLoss, self).__init__()
self.config = Config()
self.lambdas_pix_last = self.config.lambdas_pix_last
self.criterions_last = {}
if 'bce' in self.lambdas_pix_last and self.lambdas_pix_last['bce']:
self.criterions_last['bce'] = nn.BCELoss() if not self.config.use_fp16 else nn.BCEWithLogitsLoss()
if 'iou' in self.lambdas_pix_last and self.lambdas_pix_last['iou']:
self.criterions_last['iou'] = IoULoss()
if 'iou_patch' in self.lambdas_pix_last and self.lambdas_pix_last['iou_patch']:
self.criterions_last['iou_patch'] = PatchIoULoss()
if 'ssim' in self.lambdas_pix_last and self.lambdas_pix_last['ssim']:
self.criterions_last['ssim'] = SSIMLoss()
if 'mse' in self.lambdas_pix_last and self.lambdas_pix_last['mse']:
self.criterions_last['mse'] = nn.MSELoss()
if 'reg' in self.lambdas_pix_last and self.lambdas_pix_last['reg']:
self.criterions_last['reg'] = ThrReg_loss()
if 'cnt' in self.lambdas_pix_last and self.lambdas_pix_last['cnt']:
self.criterions_last['cnt'] = ContourLoss()
if 'structure' in self.lambdas_pix_last and self.lambdas_pix_last['structure']:
self.criterions_last['structure'] = StructureLoss()
def forward(self, scaled_preds, gt):
loss = 0.
criterions_embedded_with_sigmoid = ['structure', ] + ['bce'] if self.config.use_fp16 else []
for _, pred_lvl in enumerate(scaled_preds):
if pred_lvl.shape != gt.shape:
pred_lvl = nn.functional.interpolate(pred_lvl, size=gt.shape[2:], mode='bilinear', align_corners=True)
for criterion_name, criterion in self.criterions_last.items():
_loss = criterion(pred_lvl.sigmoid() if criterion_name not in criterions_embedded_with_sigmoid else pred_lvl, gt) * self.lambdas_pix_last[criterion_name]
loss += _loss
# print(criterion_name, _loss.item())
return loss
class SSIMLoss(torch.nn.Module):
def __init__(self, window_size=11, size_average=True):
super(SSIMLoss, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
self.window = create_window(window_size, self.channel)
def forward(self, img1, img2):
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.data.type() == img1.data.type():
window = self.window
else:
window = create_window(self.window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
self.window = window
self.channel = channel
return 1 - _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
return window
def _ssim(img1, img2, window, window_size, channel, size_average=True):
mu1 = F.conv2d(img1, window, padding = window_size//2, groups=channel)
mu2 = F.conv2d(img2, window, padding = window_size//2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
def SSIM(x, y):
C1 = 0.01 ** 2
C2 = 0.03 ** 2
mu_x = nn.AvgPool2d(3, 1, 1)(x)
mu_y = nn.AvgPool2d(3, 1, 1)(y)
mu_x_mu_y = mu_x * mu_y
mu_x_sq = mu_x.pow(2)
mu_y_sq = mu_y.pow(2)
sigma_x = nn.AvgPool2d(3, 1, 1)(x * x) - mu_x_sq
sigma_y = nn.AvgPool2d(3, 1, 1)(y * y) - mu_y_sq
sigma_xy = nn.AvgPool2d(3, 1, 1)(x * y) - mu_x_mu_y
SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2)
SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2)
SSIM = SSIM_n / SSIM_d
return torch.clamp((1 - SSIM) / 2, 0, 1)
def saliency_structure_consistency(x, y):
ssim = torch.mean(SSIM(x,y))
return ssim