NPRC24 / MiAlgo /network.py
Artyom
MiAlgo
82567db verified
raw
history blame
22.6 kB
import math
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init, Module
import functools
from torch.optim import lr_scheduler
from collections import OrderedDict
import numpy as np
'''
# ===================================
# Advanced nn.Sequential
# reform nn.Sequentials and nn.Modules
# to a single nn.Sequential
# ===================================
'''
def seq(*args):
if len(args) == 1:
args = args[0]
if isinstance(args, nn.Module):
return args
modules = OrderedDict()
if isinstance(args, OrderedDict):
for k, v in args.items():
modules[k] = seq(v)
return nn.Sequential(modules)
assert isinstance(args, (list, tuple))
return nn.Sequential(*[seq(i) for i in args])
'''
# ===================================
# Useful blocks
# --------------------------------
# conv (+ normaliation + relu)
# concat
# sum
# resblock (ResBlock)
# resdenseblock (ResidualDenseBlock_5C)
# resinresdenseblock (RRDB)
# ===================================
'''
# -------------------------------------------------------
# return nn.Sequantial of (Conv + BN + ReLU)
# -------------------------------------------------------
def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1,
output_padding=0, dilation=1, groups=1, bias=True,
padding_mode='zeros', mode='CBR'):
L = []
for t in mode:
if t == 'C':
L.append(nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode))
elif t == 'X':
assert in_channels == out_channels
L.append(nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=in_channels,
bias=bias,
padding_mode=padding_mode))
elif t == 'T':
L.append(nn.ConvTranspose2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
bias=bias,
dilation=dilation,
padding_mode=padding_mode))
elif t == 'B':
L.append(nn.BatchNorm2d(out_channels))
elif t == 'I':
L.append(nn.InstanceNorm2d(out_channels, affine=True))
elif t == 'i':
L.append(nn.InstanceNorm2d(out_channels))
elif t == 'R':
L.append(nn.ReLU(inplace=True))
elif t == 'r':
L.append(nn.ReLU(inplace=False))
elif t == 'P':
L.append(nn.PReLU())
elif t == 'L':
L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=True))
elif t == 'l':
L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=False))
elif t == '2':
L.append(nn.PixelShuffle(upscale_factor=2))
elif t == '3':
L.append(nn.PixelShuffle(upscale_factor=3))
elif t == '4':
L.append(nn.PixelShuffle(upscale_factor=4))
elif t == 'U':
L.append(nn.Upsample(scale_factor=2, mode='nearest'))
elif t == 'u':
L.append(nn.Upsample(scale_factor=3, mode='nearest'))
elif t == 'M':
L.append(nn.MaxPool2d(kernel_size=kernel_size,
stride=stride,
padding=0))
elif t == 'A':
L.append(nn.AvgPool2d(kernel_size=kernel_size,
stride=stride,
padding=0))
else:
raise NotImplementedError('Undefined type: '.format(t))
return seq(*L)
# -------------------------------------------------------
# Concat the output of a submodule to its input
# -------------------------------------------------------
class ConcatBlock(nn.Module):
def __init__(self, submodule):
super(ConcatBlock, self).__init__()
self.sub = submodule
def forward(self, x):
output = torch.cat((x, self.sub(x)), dim=1)
return output
def __repr__(self):
return self.sub.__repr__() + '_concat'
# -------------------------------------------------------
# Elementwise sum the output of a submodule to its input
# -------------------------------------------------------
class ShortcutBlock(nn.Module):
def __init__(self, submodule):
super(ShortcutBlock, self).__init__()
self.sub = submodule
def forward(self, x):
output = x + self.sub(x)
return output
def __repr__(self):
tmpstr = 'Identity + \n|'
modstr = self.sub.__repr__().replace('\n', '\n|')
tmpstr = tmpstr + modstr
return tmpstr
class DWTForward(nn.Module):
def __init__(self):
super(DWTForward, self).__init__()
ll = np.array([[0.5, 0.5], [0.5, 0.5]])
lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
filts = np.stack([ll[None,::-1,::-1], lh[None,::-1,::-1],
hl[None,::-1,::-1], hh[None,::-1,::-1]],
axis=0)
self.weight = nn.Parameter(
torch.tensor(filts).to(torch.get_default_dtype()),
requires_grad=False)
def forward(self, x):
C = x.shape[1]
filters = torch.cat([self.weight,] * C, dim=0)
y = F.conv2d(x, filters, groups=C, stride=2)
return y
class DWTInverse(nn.Module):
def __init__(self):
super(DWTInverse, self).__init__()
ll = np.array([[0.5, 0.5], [0.5, 0.5]])
lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1],
hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]],
axis=0)
self.weight = nn.Parameter(
torch.tensor(filts).to(torch.get_default_dtype()),
requires_grad=False)
def forward(self, x):
C = int(x.shape[1] / 4)
filters = torch.cat([self.weight, ] * C, dim=0)
y = F.conv_transpose2d(x, filters, groups=C, stride=2)
return y
# -------------------------------------------------------
# Channel Attention (CA) Layer
# -------------------------------------------------------
class CALayer(nn.Module):
def __init__(self, channel=64, reduction=16):
super(CALayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv_du = nn.Sequential(
nn.Conv2d(channel, channel//reduction, 1, padding=0, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(channel//reduction, channel, 1, padding=0, bias=True),
nn.Sigmoid()
)
def forward(self, x):
y = self.avg_pool(x)
y = self.conv_du(y)
return x * y
class ChannelPool(nn.Module):
def forward(self, x):
return torch.cat((torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1)
class spatial_attn_layer(nn.Module):
def __init__(self, kernel_size=3):
super(spatial_attn_layer, self).__init__()
self.compress = ChannelPool()
self.spatial = nn.Conv2d(2, 1, 3, stride=1, padding=1, bias=True)
def forward(self, x):
# import pdb;pdb.set_trace()
x_compress = self.compress(x)
x_out = self.spatial(x_compress)
scale = torch.sigmoid(x_out) # broadcasting
return x * scale
# -------------------------------------------------------
# Content Unrelated Channel Attention (CUCA) Layer
# -------------------------------------------------------
class CUCALayer(nn.Module):
def __init__(self, channel=64, min=0, max=None):
super(CUCALayer, self).__init__()
self.attention = nn.Conv2d(channel, channel, 1, padding=0,
groups=channel, bias=False)
self.min, self.max = min, max
nn.init.uniform_(self.attention.weight, 0, 1)
def forward(self, x):
self.attention.weight.data.clamp_(self.min, self.max)
return self.attention(x)
# -------------------------------------------------------
# Res Block: x + conv(relu(conv(x)))
# -------------------------------------------------------
class ResBlock(nn.Module):
def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
padding=1, bias=True, mode='CRC'):
super(ResBlock, self).__init__()
assert in_channels == out_channels
if mode[0] in ['R','L']:
mode = mode[0].lower() + mode[1:]
self.res = conv(in_channels, out_channels, kernel_size,
stride, padding, bias=bias, mode=mode)
def forward(self, x):
res = self.res(x)
return x + res
# -------------------------------------------------------
# Residual Channel Attention Block (RCAB)
# -------------------------------------------------------
class RCABlock(nn.Module):
def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
padding=1, bias=True, mode='CRC', reduction=16):
super(RCABlock, self).__init__()
assert in_channels == out_channels
if mode[0] in ['R','L']:
mode = mode[0].lower() + mode[1:]
self.res = conv(in_channels, out_channels, kernel_size,
stride, padding, bias=bias, mode=mode)
# self.CA = CALayer(out_channels, reduction)
#self.SA = spatial_attn_layer() ## Spatial Attention
#self.conv1x1 = nn.Conv2d(in_channels*2, in_channels, kernel_size=1)
def forward(self, x):
res = self.res(x)
#sa_branch = self.SA(res)
# ca_branch = self.CA(res)
#res = torch.cat([sa_branch, ca_branch], dim=1)
#res = self.conv1x1(res)
return res + x
# -------------------------------------------------------
# Residual Channel Attention Group (RG)
# -------------------------------------------------------
class RCAGroup(nn.Module):
def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
padding=1, bias=True, mode='CRC', reduction=16, nb=12, num_attention_block=4, use_attention=True):
super(RCAGroup, self).__init__()
assert in_channels == out_channels
if mode[0] in ['R','L']:
mode = mode[0].lower() + mode[1:]
RG = []
for _ in range(num_attention_block):
RG.extend([RCABlock(in_channels, out_channels, kernel_size, stride, padding,
bias, mode, reduction) for _ in range(nb//num_attention_block)])
if use_attention:
RG.append(AttentionResBlock(in_channels))
RG.append(conv(out_channels, out_channels, mode='C'))
# self.rg = ShortcutBlock(nn.Sequential(*RG))
self.rg = nn.Sequential(*RG)
def forward(self, x):
res = self.rg(x)
return res + x
# -------------------------------------------------------
# conv + subp + relu
# -------------------------------------------------------
def upsample_pixelshuffle(in_channels=64, out_channels=3, kernel_size=3,
stride=1, padding=1, bias=True, mode='2R'):
# mode examples: 2, 2R, 2BR, 3, ..., 4BR.
assert len(mode)<4 and mode[0] in ['2', '3', '4']
up1 = conv(in_channels, out_channels * (int(mode[0]) ** 2), kernel_size,
stride, padding, bias=bias, mode='C'+mode)
return up1
# -------------------------------------------------------
# nearest_upsample + conv + relu
# -------------------------------------------------------
def upsample_upconv(in_channels=64, out_channels=3, kernel_size=3, stride=1,
padding=1, bias=True, mode='2R'):
# mode examples: 2, 2R, 2BR, 3, ..., 3BR.
assert len(mode)<4 and mode[0] in ['2', '3']
if mode[0] == '2':
uc = 'UC'
elif mode[0] == '3':
uc = 'uC'
mode = mode.replace(mode[0], uc)
up1 = conv(in_channels, out_channels, kernel_size, stride,
padding, bias=bias, mode=mode)
return up1
# -------------------------------------------------------
# convTranspose + relu
# -------------------------------------------------------
def upsample_convtranspose(in_channels=64, out_channels=3, kernel_size=2,
stride=2, padding=0, bias=True, mode='2R'):
# mode examples: 2, 2R, 2BR, 3, ..., 4BR.
assert len(mode)<4 and mode[0] in ['2', '3', '4']
kernel_size = int(mode[0])
stride = int(mode[0])
mode = mode.replace(mode[0], 'T')
up1 = conv(in_channels, out_channels, kernel_size, stride,
padding, bias=bias, mode=mode)
return up1
'''
# ======================
# Downsampler
# ======================
'''
# -------------------------------------------------------
# strideconv + relu
# -------------------------------------------------------
def downsample_strideconv(in_channels=64, out_channels=64, kernel_size=2,
stride=2, padding=0, bias=True, mode='2R'):
# mode examples: 2, 2R, 2BR, 3, ..., 4BR.
assert len(mode)<4 and mode[0] in ['2', '3', '4']
kernel_size = int(mode[0])
stride = int(mode[0])
mode = mode.replace(mode[0], 'C')
down1 = conv(in_channels, out_channels, kernel_size, stride,
padding, bias=bias, mode=mode)
return down1
# -------------------------------------------------------
# maxpooling + conv + relu
# -------------------------------------------------------
def downsample_maxpool(in_channels=64, out_channels=64, kernel_size=3,
stride=1, padding=0, bias=True, mode='2R'):
# mode examples: 2, 2R, 2BR, 3, ..., 3BR.
assert len(mode)<4 and mode[0] in ['2', '3']
kernel_size_pool = int(mode[0])
stride_pool = int(mode[0])
mode = mode.replace(mode[0], 'MC')
pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0])
pool_tail = conv(in_channels, out_channels, kernel_size, stride,
padding, bias=bias, mode=mode[1:])
return sequential(pool, pool_tail)
# -------------------------------------------------------
# averagepooling + conv + relu
# -------------------------------------------------------
def downsample_avgpool(in_channels=64, out_channels=64, kernel_size=3,
stride=1, padding=1, bias=True, mode='2R'):
# mode examples: 2, 2R, 2BR, 3, ..., 3BR.
assert len(mode)<4 and mode[0] in ['2', '3']
kernel_size_pool = int(mode[0])
stride_pool = int(mode[0])
mode = mode.replace(mode[0], 'AC')
pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0])
pool_tail = conv(in_channels, out_channels, kernel_size, stride,
padding, bias=bias, mode=mode[1:])
return sequential(pool, pool_tail)
class AttentionResBlock(nn.Module):
def __init__(self, dim: int):
super(AttentionResBlock, self).__init__()
self._spatial_attention_conv = nn.Conv2d(2, dim, kernel_size=3, padding=1)
# Channel attention MLP
self._channel_attention_conv0 = nn.Conv2d(1, dim, kernel_size=1, padding=0)
self._channel_attention_conv1 = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
self._out_conv = nn.Conv2d(2 * dim, dim, kernel_size=1, padding=0)
def forward(self, x: torch.Tensor):
# Spatial attention
mean = torch.mean(x, dim=1, keepdim=True) # Mean/Max on C axis
max, _ = torch.max(x, dim=1, keepdim=True)
spatial_attention = torch.cat([mean, max], dim=1) # [B, 2, H, W]
spatial_attention = self._spatial_attention_conv(spatial_attention)
spatial_attention = torch.sigmoid(spatial_attention) * x
channel_attention = torch.relu(self._channel_attention_conv0(mean))
channel_attention = self._channel_attention_conv1(channel_attention)
channel_attention = torch.sigmoid(channel_attention) * x
attention = torch.cat([spatial_attention, channel_attention], dim=1) # [B, 2*dim, H, W]
attention = self._out_conv(attention)
return x + attention
class MWRCANv2(nn.Module):
def __init__(self):
super(MWRCANv2, self).__init__()
c1 = 64
c2 = 96
c3 = 128
n_b = 16
self.head = seq(
nn.PixelUnshuffle(2),
DWTForward()
)
self.down1 = seq(
nn.Conv2d(48, c1, 3, 1, 1),
nn.PReLU(),
RCAGroup(in_channels=c1, out_channels=c1, nb=n_b, num_attention_block=4)
)
self.down2 = seq(
DWTForward(),
nn.Conv2d(c1 * 4, c2, 3, 1, 1),
nn.PReLU(),
RCAGroup(in_channels=c2, out_channels=c2, nb=n_b, num_attention_block=4)
)
self.down3 = seq(
DWTForward(),
nn.Conv2d(c2 * 4, c3, 3, 1, 1),
nn.PReLU()
)
self.middle = seq(
RCAGroup(in_channels=c3, out_channels=c3, nb=n_b, num_attention_block=4),
RCAGroup(in_channels=c3, out_channels=c3, nb=n_b, num_attention_block=4)
)
self.up1 = seq(
nn.Conv2d(c3, c2 * 4, 3, 1, 1),
nn.PReLU(),
DWTInverse()
)
self.up2 = seq(
RCAGroup(in_channels=c2, out_channels=c2, nb=n_b, num_attention_block=4),
nn.Conv2d(c2, c1 * 4, 3, 1, 1),
nn.PReLU(),
DWTInverse()
)
self.up3 = seq(
RCAGroup(in_channels=c1, out_channels=c1, nb=n_b, num_attention_block=4),
nn.Conv2d(c1, 48, 3, 1, 1)
)
self.tail = seq(
DWTInverse(),
nn.PixelShuffle(2)
)
def forward(self, x, c=None):
c1 = self.head(x)
c2 = self.down1(c1)
c3 = self.down2(c2)
c4 = self.down3(c3)
m = self.middle(c4)
c5 = self.up1(m) + c3
c6 = self.up2(c5) + c2
c7 = self.up3(c6) + c1
out = self.tail(c7)
return out
class MWRCANv3(nn.Module):
def __init__(self):
super(MWRCANv3, self).__init__()
c1 = 64
c2 = 96
c3 = 128
n_b = 16
self.head = seq(
DWTForward()
)
self.down1 = seq(
nn.Conv2d(12, c1, 3, 1, 1),
nn.PReLU(),
RCAGroup(in_channels=c1, out_channels=c1, nb=n_b)
)
self.down2 = seq(
DWTForward(),
nn.Conv2d(c1 * 4, c2, 3, 1, 1),
nn.PReLU(),
RCAGroup(in_channels=c2, out_channels=c2, nb=n_b)
)
self.down3 = seq(
DWTForward(),
nn.Conv2d(c2 * 4, c3, 3, 1, 1),
nn.PReLU()
)
self.middle = seq(
RCAGroup(in_channels=c3, out_channels=c3, nb=n_b),
RCAGroup(in_channels=c3, out_channels=c3, nb=n_b)
)
self.up1 = seq(
nn.Conv2d(c3, c2 * 4, 3, 1, 1),
nn.PReLU(),
DWTInverse()
)
self.up2 = seq(
RCAGroup(in_channels=c2, out_channels=c2, nb=n_b),
nn.Conv2d(c2, c1 * 4, 3, 1, 1),
nn.PReLU(),
DWTInverse()
)
self.up3 = seq(
RCAGroup(in_channels=c1, out_channels=c1, nb=n_b),
nn.Conv2d(c1, 12, 3, 1, 1)
)
self.tail = seq(
DWTInverse()
)
def forward(self, x, c=None):
c1 = self.head(x)
c2 = self.down1(c1)
c3 = self.down2(c2)
c4 = self.down3(c3)
m = self.middle(c4)
c5 = self.up1(m) + c3
c6 = self.up2(c5) + c2
c7 = self.up3(c6) + c1
out = self.tail(c7)
return out
class MWRCANv4(nn.Module):
def __init__(self, c1 = 64, c2 = 96, c3 = 128, n_b = 16):
super(MWRCANv4, self).__init__()
self.head = seq(
DWTForward()
)
self.down1 = seq(
nn.Conv2d(12, c1, 3, 1, 1),
nn.PReLU(),
RCAGroup(in_channels=c1, out_channels=c1, nb=n_b, use_attention=False)
)
self.down2 = seq(
DWTForward(),
nn.Conv2d(c1 * 4, c2, 3, 1, 1),
nn.PReLU(),
RCAGroup(in_channels=c2, out_channels=c2, nb=n_b, use_attention=False)
)
self.down3 = seq(
DWTForward(),
nn.Conv2d(c2 * 4, c3, 3, 1, 1),
nn.PReLU()
)
self.middle = seq(
RCAGroup(in_channels=c3, out_channels=c3, nb=n_b, use_attention=False),
RCAGroup(in_channels=c3, out_channels=c3, nb=n_b, use_attention=False)
)
self.up1 = seq(
nn.Conv2d(c3, c2 * 4, 3, 1, 1),
nn.PReLU(),
DWTInverse()
)
self.up2 = seq(
RCAGroup(in_channels=c2, out_channels=c2, nb=n_b, use_attention=False),
nn.Conv2d(c2, c1 * 4, 3, 1, 1),
nn.PReLU(),
DWTInverse()
)
self.up3 = seq(
RCAGroup(in_channels=c1, out_channels=c1, nb=n_b, use_attention=False),
nn.Conv2d(c1, 12, 3, 1, 1)
)
self.tail = seq(
DWTInverse()
)
def forward(self, x, c=None):
c1 = self.head(x)
c2 = self.down1(c1)
c3 = self.down2(c2)
c4 = self.down3(c3)
m = self.middle(c4)
c5 = self.up1(m) + c3
c6 = self.up2(c5) + c2
c7 = self.up3(c6)
out = self.tail(c7)
return out