|
import math |
|
import torch |
|
import torch.optim as optim |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn import init, Module |
|
import functools |
|
from torch.optim import lr_scheduler |
|
from collections import OrderedDict |
|
import numpy as np |
|
|
|
''' |
|
# =================================== |
|
# Advanced nn.Sequential |
|
# reform nn.Sequentials and nn.Modules |
|
# to a single nn.Sequential |
|
# =================================== |
|
''' |
|
|
|
def seq(*args):
    """Flatten modules, OrderedDicts and lists/tuples into one module.

    A single nn.Module is returned unchanged; an OrderedDict becomes a
    named nn.Sequential; a list/tuple becomes a positional nn.Sequential.
    Nested containers are resolved recursively.
    """
    if len(args) == 1:
        args = args[0]
    if isinstance(args, nn.Module):
        return args
    if isinstance(args, OrderedDict):
        named = OrderedDict((name, seq(child)) for name, child in args.items())
        return nn.Sequential(named)
    assert isinstance(args, (list, tuple))
    return nn.Sequential(*(seq(item) for item in args))
|
|
|
''' |
|
# =================================== |
|
# Useful blocks |
|
# -------------------------------- |
|
# conv (+ normalization + relu)
|
# concat |
|
# sum |
|
# resblock (ResBlock) |
|
# resdenseblock (ResidualDenseBlock_5C) |
|
# resinresdenseblock (RRDB) |
|
# =================================== |
|
''' |
|
|
|
|
|
|
|
|
|
def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1,
         output_padding=0, dilation=1, groups=1, bias=True,
         padding_mode='zeros', mode='CBR'):
    """Build a layer pipeline from the `mode` string, one layer per character.

    C: Conv2d                 X: depthwise Conv2d (groups=in_channels)
    T: ConvTranspose2d        B: BatchNorm2d
    I/i: InstanceNorm2d (affine / plain)
    R/r: ReLU (inplace / not) P: PReLU
    L/l: LeakyReLU(0.1) (inplace / not)
    2/3/4: PixelShuffle       U/u: Upsample x2 / x3 (nearest)
    M: MaxPool2d              A: AvgPool2d

    Returns a single module assembled via `seq`.
    Raises NotImplementedError for an unknown mode character.
    """
    L = []
    for t in mode:
        if t == 'C':
            L.append(nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding,
                               dilation=dilation,
                               groups=groups,
                               bias=bias,
                               padding_mode=padding_mode))
        elif t == 'X':
            # Depthwise convolution: one filter group per channel.
            assert in_channels == out_channels
            L.append(nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding,
                               dilation=dilation,
                               groups=in_channels,
                               bias=bias,
                               padding_mode=padding_mode))
        elif t == 'T':
            L.append(nn.ConvTranspose2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding,
                                        output_padding=output_padding,
                                        groups=groups,
                                        bias=bias,
                                        dilation=dilation,
                                        padding_mode=padding_mode))
        elif t == 'B':
            L.append(nn.BatchNorm2d(out_channels))
        elif t == 'I':
            L.append(nn.InstanceNorm2d(out_channels, affine=True))
        elif t == 'i':
            L.append(nn.InstanceNorm2d(out_channels))
        elif t == 'R':
            L.append(nn.ReLU(inplace=True))
        elif t == 'r':
            L.append(nn.ReLU(inplace=False))
        elif t == 'P':
            L.append(nn.PReLU())
        elif t == 'L':
            L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=True))
        elif t == 'l':
            L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=False))
        elif t == '2':
            L.append(nn.PixelShuffle(upscale_factor=2))
        elif t == '3':
            L.append(nn.PixelShuffle(upscale_factor=3))
        elif t == '4':
            L.append(nn.PixelShuffle(upscale_factor=4))
        elif t == 'U':
            L.append(nn.Upsample(scale_factor=2, mode='nearest'))
        elif t == 'u':
            L.append(nn.Upsample(scale_factor=3, mode='nearest'))
        elif t == 'M':
            L.append(nn.MaxPool2d(kernel_size=kernel_size,
                                  stride=stride,
                                  padding=0))
        elif t == 'A':
            L.append(nn.AvgPool2d(kernel_size=kernel_size,
                                  stride=stride,
                                  padding=0))
        else:
            # Bug fix: the '{}' placeholder was missing, so the offending
            # character never appeared in the error message.
            raise NotImplementedError('Undefined type: {}'.format(t))
    return seq(*L)
|
|
|
|
|
|
|
|
|
class ConcatBlock(nn.Module):
    """Run a submodule and concatenate its output with the input along
    the channel dimension: y = cat(x, sub(x))."""

    def __init__(self, submodule):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        # Identity branch first, submodule branch second.
        return torch.cat((x, self.sub(x)), dim=1)

    def __repr__(self):
        return '{}_concat'.format(self.sub.__repr__())
|
|
|
|
|
|
|
|
|
class ShortcutBlock(nn.Module):
    """Additive residual wrapper: y = x + sub(x)."""

    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        return x + self.sub(x)

    def __repr__(self):
        # Show the skip connection explicitly, indenting the inner module.
        inner = self.sub.__repr__().replace('\n', '\n|')
        return 'Identity + \n|' + inner
|
|
|
class DWTForward(nn.Module):
    """One-level 2D Haar wavelet transform as a strided grouped convolution.

    Maps (N, C, H, W) -> (N, 4*C, H/2, W/2); for each input channel the four
    outputs are the LL, LH, HL and HH subbands. Filters are fixed (not
    trainable).
    """

    def __init__(self):
        super(DWTForward, self).__init__()
        ll = np.array([[0.5, 0.5], [0.5, 0.5]])
        lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
        hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
        hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
        # Flip both spatial axes so conv2d (a cross-correlation) applies the
        # intended filters; result is a (4, 1, 2, 2) weight bank.
        bank = np.stack([f[None, ::-1, ::-1] for f in (ll, lh, hl, hh)],
                        axis=0)
        self.weight = nn.Parameter(
            torch.tensor(bank).to(torch.get_default_dtype()),
            requires_grad=False)

    def forward(self, x):
        channels = x.shape[1]
        # Replicate the 4-filter bank per channel and run depthwise.
        kernel = torch.cat([self.weight] * channels, dim=0)
        return F.conv2d(x, kernel, groups=channels, stride=2)
|
|
|
class DWTInverse(nn.Module):
    """Inverse of DWTForward: transposed grouped convolution with the same
    fixed Haar filter bank.

    Maps (N, 4*C, H, W) -> (N, C, 2*H, 2*W).
    """

    def __init__(self):
        super(DWTInverse, self).__init__()
        ll = np.array([[0.5, 0.5], [0.5, 0.5]])
        lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
        hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
        hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
        # Same flipped (4, 1, 2, 2) bank as the forward transform.
        bank = np.stack([f[None, ::-1, ::-1] for f in (ll, lh, hl, hh)],
                        axis=0)
        self.weight = nn.Parameter(
            torch.tensor(bank).to(torch.get_default_dtype()),
            requires_grad=False)

    def forward(self, x):
        # Every group of 4 subband channels reconstructs one output channel.
        channels = int(x.shape[1] / 4)
        kernel = torch.cat([self.weight] * channels, dim=0)
        return F.conv_transpose2d(x, kernel, groups=channels, stride=2)
|
|
|
|
|
|
|
|
|
class CALayer(nn.Module): |
|
def __init__(self, channel=64, reduction=16): |
|
super(CALayer, self).__init__() |
|
|
|
self.avg_pool = nn.AdaptiveAvgPool2d(1) |
|
self.conv_du = nn.Sequential( |
|
nn.Conv2d(channel, channel//reduction, 1, padding=0, bias=True), |
|
nn.ReLU(inplace=True), |
|
nn.Conv2d(channel//reduction, channel, 1, padding=0, bias=True), |
|
nn.Sigmoid() |
|
) |
|
|
|
def forward(self, x): |
|
y = self.avg_pool(x) |
|
y = self.conv_du(y) |
|
return x * y |
|
|
|
class ChannelPool(nn.Module):
    """Collapse channels into a 2-channel map: per-pixel max and mean."""

    def forward(self, x):
        max_map = torch.max(x, 1)[0].unsqueeze(1)
        mean_map = torch.mean(x, 1).unsqueeze(1)
        return torch.cat((max_map, mean_map), dim=1)
|
|
|
class spatial_attn_layer(nn.Module):
    """Spatial attention: gate the input by a sigmoid map computed from
    per-pixel channel statistics (max + mean via ChannelPool)."""

    def __init__(self, kernel_size=3):
        super(spatial_attn_layer, self).__init__()
        self.compress = ChannelPool()
        # Bug fix: `kernel_size` was previously ignored (the conv was
        # hard-coded to 3). 'Same' padding keeps the spatial size for odd
        # kernel sizes; the default (3) behaves exactly as before.
        self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1,
                                 padding=(kernel_size - 1) // 2, bias=True)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)
        return x * scale
|
|
|
|
|
|
|
|
|
class CUCALayer(nn.Module): |
|
def __init__(self, channel=64, min=0, max=None): |
|
super(CUCALayer, self).__init__() |
|
|
|
self.attention = nn.Conv2d(channel, channel, 1, padding=0, |
|
groups=channel, bias=False) |
|
self.min, self.max = min, max |
|
nn.init.uniform_(self.attention.weight, 0, 1) |
|
|
|
def forward(self, x): |
|
self.attention.weight.data.clamp_(self.min, self.max) |
|
return self.attention(x) |
|
|
|
|
|
|
|
|
|
|
|
class ResBlock(nn.Module):
    """Residual block: y = x + body(x), with body built by `conv` from `mode`."""

    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
                 padding=1, bias=True, mode='CRC'):
        super(ResBlock, self).__init__()
        assert in_channels == out_channels
        # A leading activation on the residual branch is forced non-inplace
        # ('R'/'L' -> 'r'/'l') so the skip input is not overwritten.
        if mode[0] in ['R', 'L']:
            mode = mode[0].lower() + mode[1:]
        self.res = conv(in_channels, out_channels, kernel_size,
                        stride, padding, bias=bias, mode=mode)

    def forward(self, x):
        return x + self.res(x)
|
|
|
|
|
|
|
|
|
class RCABlock(nn.Module):
    """Residual block used inside RCAGroup: y = body(x) + x.

    NOTE(review): `reduction` is accepted but unused here — a channel
    attention stage was presumably removed; confirm before relying on it.
    """

    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
                 padding=1, bias=True, mode='CRC', reduction=16):
        super(RCABlock, self).__init__()
        assert in_channels == out_channels
        # Leading activations on the residual branch must be non-inplace.
        if mode[0] in ['R', 'L']:
            mode = mode[0].lower() + mode[1:]
        self.res = conv(in_channels, out_channels, kernel_size,
                        stride, padding, bias=bias, mode=mode)

    def forward(self, x):
        return self.res(x) + x
|
|
|
|
|
|
|
|
|
|
|
class RCAGroup(nn.Module):
    """Residual group: `nb` RCABlocks split into `num_attention_block`
    segments (each optionally followed by an AttentionResBlock), a tail
    conv, and a long additive skip connection."""

    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
                 padding=1, bias=True, mode='CRC', reduction=16, nb=12,
                 num_attention_block=4, use_attention=True):
        super(RCAGroup, self).__init__()
        assert in_channels == out_channels
        if mode[0] in ['R', 'L']:
            mode = mode[0].lower() + mode[1:]

        blocks = []
        for _ in range(num_attention_block):
            # nb blocks total, divided evenly across the attention segments.
            blocks.extend([RCABlock(in_channels, out_channels, kernel_size,
                                    stride, padding, bias, mode, reduction)
                           for _ in range(nb // num_attention_block)])
            if use_attention:
                blocks.append(AttentionResBlock(in_channels))
        blocks.append(conv(out_channels, out_channels, mode='C'))

        self.rg = nn.Sequential(*blocks)

    def forward(self, x):
        return self.rg(x) + x
|
|
|
|
|
|
|
|
|
def upsample_pixelshuffle(in_channels=64, out_channels=3, kernel_size=3,
                          stride=1, padding=1, bias=True, mode='2R'):
    """Conv expanding channels by scale**2, then PixelShuffle.

    mode[0] in {'2','3','4'} selects the upscale factor; trailing mode
    characters (e.g. 'R') add layers as in `conv`.
    """
    assert len(mode) < 4 and mode[0] in ['2', '3', '4']
    scale = int(mode[0])
    return conv(in_channels, out_channels * scale ** 2, kernel_size,
                stride, padding, bias=bias, mode='C' + mode)
|
|
|
|
|
|
|
|
|
|
|
def upsample_upconv(in_channels=64, out_channels=3, kernel_size=3, stride=1,
                    padding=1, bias=True, mode='2R'):
    """Nearest-neighbour upsample ('U' for x2, 'u' for x3) followed by a conv."""
    assert len(mode) < 4 and mode[0] in ['2', '3']
    uc = 'UC' if mode[0] == '2' else 'uC'
    # Rewrite the scale digit into the upsample+conv pair understood by `conv`.
    mode = mode.replace(mode[0], uc)
    return conv(in_channels, out_channels, kernel_size, stride,
                padding, bias=bias, mode=mode)
|
|
|
|
|
|
|
|
|
|
|
def upsample_convtranspose(in_channels=64, out_channels=3, kernel_size=2,
                           stride=2, padding=0, bias=True, mode='2R'):
    """Transposed-conv upsampler.

    The scale factor comes from mode[0]; it overrides both `kernel_size`
    and `stride` (the arguments are kept only for signature compatibility).
    """
    assert len(mode) < 4 and mode[0] in ['2', '3', '4']
    scale = int(mode[0])
    return conv(in_channels, out_channels, scale, scale,
                padding, bias=bias, mode=mode.replace(mode[0], 'T'))
|
|
|
|
|
''' |
|
# ====================== |
|
# Downsampler |
|
# ====================== |
|
''' |
|
|
|
|
|
|
|
|
|
|
|
def downsample_strideconv(in_channels=64, out_channels=64, kernel_size=2,
                          stride=2, padding=0, bias=True, mode='2R'):
    """Strided-conv downsampler.

    The scale factor comes from mode[0]; it overrides both `kernel_size`
    and `stride` (the arguments are kept only for signature compatibility).
    """
    assert len(mode) < 4 and mode[0] in ['2', '3', '4']
    scale = int(mode[0])
    return conv(in_channels, out_channels, scale, scale,
                padding, bias=bias, mode=mode.replace(mode[0], 'C'))
|
|
|
|
|
|
|
|
|
|
|
def downsample_maxpool(in_channels=64, out_channels=64, kernel_size=3,
                       stride=1, padding=0, bias=True, mode='2R'):
    """MaxPool downsampler (scale from mode[0]) followed by a conv block.

    Bug fix: the final composition called `sequential`, which is not
    defined in this module — the helper here is named `seq`.
    """
    assert len(mode) < 4 and mode[0] in ['2', '3']
    kernel_size_pool = int(mode[0])
    stride_pool = int(mode[0])
    # mode[0] becomes 'M' (the pool), mode[1:] the conv tail.
    mode = mode.replace(mode[0], 'MC')
    pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0])
    pool_tail = conv(in_channels, out_channels, kernel_size, stride,
                     padding, bias=bias, mode=mode[1:])
    return seq(pool, pool_tail)
|
|
|
|
|
|
|
|
|
|
|
def downsample_avgpool(in_channels=64, out_channels=64, kernel_size=3,
                       stride=1, padding=1, bias=True, mode='2R'):
    """AvgPool downsampler (scale from mode[0]) followed by a conv block.

    Bug fix: the final composition called `sequential`, which is not
    defined in this module — the helper here is named `seq`.
    """
    assert len(mode) < 4 and mode[0] in ['2', '3']
    kernel_size_pool = int(mode[0])
    stride_pool = int(mode[0])
    # mode[0] becomes 'A' (the pool), mode[1:] the conv tail.
    mode = mode.replace(mode[0], 'AC')
    pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0])
    pool_tail = conv(in_channels, out_channels, kernel_size, stride,
                     padding, bias=bias, mode=mode[1:])
    return seq(pool, pool_tail)
|
|
|
|
|
|
|
class AttentionResBlock(nn.Module): |
|
def __init__(self, dim: int): |
|
super(AttentionResBlock, self).__init__() |
|
self._spatial_attention_conv = nn.Conv2d(2, dim, kernel_size=3, padding=1) |
|
|
|
|
|
self._channel_attention_conv0 = nn.Conv2d(1, dim, kernel_size=1, padding=0) |
|
self._channel_attention_conv1 = nn.Conv2d(dim, dim, kernel_size=1, padding=0) |
|
|
|
self._out_conv = nn.Conv2d(2 * dim, dim, kernel_size=1, padding=0) |
|
|
|
def forward(self, x: torch.Tensor): |
|
|
|
mean = torch.mean(x, dim=1, keepdim=True) |
|
max, _ = torch.max(x, dim=1, keepdim=True) |
|
spatial_attention = torch.cat([mean, max], dim=1) |
|
spatial_attention = self._spatial_attention_conv(spatial_attention) |
|
spatial_attention = torch.sigmoid(spatial_attention) * x |
|
|
|
channel_attention = torch.relu(self._channel_attention_conv0(mean)) |
|
channel_attention = self._channel_attention_conv1(channel_attention) |
|
channel_attention = torch.sigmoid(channel_attention) * x |
|
|
|
attention = torch.cat([spatial_attention, channel_attention], dim=1) |
|
attention = self._out_conv(attention) |
|
return x + attention |
|
|
|
|
|
class MWRCANv2(nn.Module):
    """U-shaped wavelet encoder/decoder with RCAGroup stages.

    The input is pixel-unshuffled (x2) and Haar-decomposed, processed at
    three scales (64/96/128 channels) with additive skip connections, then
    reassembled by the inverse transforms.
    """

    def __init__(self):
        super(MWRCANv2, self).__init__()
        c1, c2, c3, n_b = 64, 96, 128, 16

        self.head = seq(
            nn.PixelUnshuffle(2),
            DWTForward()
        )

        self.down1 = seq(
            nn.Conv2d(48, c1, 3, 1, 1),
            nn.PReLU(),
            RCAGroup(in_channels=c1, out_channels=c1, nb=n_b, num_attention_block=4)
        )

        self.down2 = seq(
            DWTForward(),
            nn.Conv2d(c1 * 4, c2, 3, 1, 1),
            nn.PReLU(),
            RCAGroup(in_channels=c2, out_channels=c2, nb=n_b, num_attention_block=4)
        )

        self.down3 = seq(
            DWTForward(),
            nn.Conv2d(c2 * 4, c3, 3, 1, 1),
            nn.PReLU()
        )

        self.middle = seq(
            RCAGroup(in_channels=c3, out_channels=c3, nb=n_b, num_attention_block=4),
            RCAGroup(in_channels=c3, out_channels=c3, nb=n_b, num_attention_block=4)
        )

        self.up1 = seq(
            nn.Conv2d(c3, c2 * 4, 3, 1, 1),
            nn.PReLU(),
            DWTInverse()
        )

        self.up2 = seq(
            RCAGroup(in_channels=c2, out_channels=c2, nb=n_b, num_attention_block=4),
            nn.Conv2d(c2, c1 * 4, 3, 1, 1),
            nn.PReLU(),
            DWTInverse()
        )

        self.up3 = seq(
            RCAGroup(in_channels=c1, out_channels=c1, nb=n_b, num_attention_block=4),
            nn.Conv2d(c1, 48, 3, 1, 1)
        )

        self.tail = seq(
            DWTInverse(),
            nn.PixelShuffle(2)
        )

    def forward(self, x, c=None):
        # Encoder: keep every stage output for its decoder skip connection.
        e0 = self.head(x)
        e1 = self.down1(e0)
        e2 = self.down2(e1)
        e3 = self.down3(e2)
        bottleneck = self.middle(e3)
        # Decoder with additive skips at every scale.
        d1 = self.up1(bottleneck) + e2
        d2 = self.up2(d1) + e1
        d3 = self.up3(d2) + e0
        return self.tail(d3)
|
|
|
|
|
|
|
class MWRCANv3(nn.Module):
    """Wavelet U-Net variant operating directly on the DWT of the input
    (no pixel unshuffle); three scales with additive skip connections."""

    def __init__(self):
        super(MWRCANv3, self).__init__()
        c1, c2, c3, n_b = 64, 96, 128, 16

        self.head = seq(
            DWTForward()
        )

        self.down1 = seq(
            nn.Conv2d(12, c1, 3, 1, 1),
            nn.PReLU(),
            RCAGroup(in_channels=c1, out_channels=c1, nb=n_b)
        )

        self.down2 = seq(
            DWTForward(),
            nn.Conv2d(c1 * 4, c2, 3, 1, 1),
            nn.PReLU(),
            RCAGroup(in_channels=c2, out_channels=c2, nb=n_b)
        )

        self.down3 = seq(
            DWTForward(),
            nn.Conv2d(c2 * 4, c3, 3, 1, 1),
            nn.PReLU()
        )

        self.middle = seq(
            RCAGroup(in_channels=c3, out_channels=c3, nb=n_b),
            RCAGroup(in_channels=c3, out_channels=c3, nb=n_b)
        )

        self.up1 = seq(
            nn.Conv2d(c3, c2 * 4, 3, 1, 1),
            nn.PReLU(),
            DWTInverse()
        )

        self.up2 = seq(
            RCAGroup(in_channels=c2, out_channels=c2, nb=n_b),
            nn.Conv2d(c2, c1 * 4, 3, 1, 1),
            nn.PReLU(),
            DWTInverse()
        )

        self.up3 = seq(
            RCAGroup(in_channels=c1, out_channels=c1, nb=n_b),
            nn.Conv2d(c1, 12, 3, 1, 1)
        )

        self.tail = seq(
            DWTInverse()
        )

    def forward(self, x, c=None):
        # Encoder: keep every stage output for its decoder skip connection.
        e0 = self.head(x)
        e1 = self.down1(e0)
        e2 = self.down2(e1)
        e3 = self.down3(e2)
        bottleneck = self.middle(e3)
        # Decoder with additive skips at every scale.
        d1 = self.up1(bottleneck) + e2
        d2 = self.up2(d1) + e1
        d3 = self.up3(d2) + e0
        return self.tail(d3)
|
|
|
|
|
class MWRCANv4(nn.Module):
    """Wavelet U-Net variant with configurable widths and attention disabled
    in every RCAGroup (use_attention=False)."""

    def __init__(self, c1 = 64, c2 = 96, c3 = 128, n_b = 16):
        super(MWRCANv4, self).__init__()

        self.head = seq(
            DWTForward()
        )

        self.down1 = seq(
            nn.Conv2d(12, c1, 3, 1, 1),
            nn.PReLU(),
            RCAGroup(in_channels=c1, out_channels=c1, nb=n_b, use_attention=False)
        )

        self.down2 = seq(
            DWTForward(),
            nn.Conv2d(c1 * 4, c2, 3, 1, 1),
            nn.PReLU(),
            RCAGroup(in_channels=c2, out_channels=c2, nb=n_b, use_attention=False)
        )

        self.down3 = seq(
            DWTForward(),
            nn.Conv2d(c2 * 4, c3, 3, 1, 1),
            nn.PReLU()
        )

        self.middle = seq(
            RCAGroup(in_channels=c3, out_channels=c3, nb=n_b, use_attention=False),
            RCAGroup(in_channels=c3, out_channels=c3, nb=n_b, use_attention=False)
        )

        self.up1 = seq(
            nn.Conv2d(c3, c2 * 4, 3, 1, 1),
            nn.PReLU(),
            DWTInverse()
        )

        self.up2 = seq(
            RCAGroup(in_channels=c2, out_channels=c2, nb=n_b, use_attention=False),
            nn.Conv2d(c2, c1 * 4, 3, 1, 1),
            nn.PReLU(),
            DWTInverse()
        )

        self.up3 = seq(
            RCAGroup(in_channels=c1, out_channels=c1, nb=n_b, use_attention=False),
            nn.Conv2d(c1, 12, 3, 1, 1)
        )

        self.tail = seq(
            DWTInverse()
        )

    def forward(self, x, c=None):
        # Encoder: stage outputs retained for decoder skips.
        e0 = self.head(x)
        e1 = self.down1(e0)
        e2 = self.down2(e1)
        e3 = self.down3(e2)
        bottleneck = self.middle(e3)
        d1 = self.up1(bottleneck) + e2
        d2 = self.up2(d1) + e1
        # NOTE(review): unlike MWRCANv2/v3 there is no additive skip on the
        # last decoder stage here — confirm this asymmetry is intentional.
        d3 = self.up3(d2)
        return self.tail(d3)