# File size: 8,276 Bytes
# 9842c28
from contextlib import contextmanager
from math import sqrt, log
import torch
import torch.nn as nn
# import warnings
# warnings.simplefilter('ignore')
class BaseModule(nn.Module):
    """Base class providing weight-initialisation and bookkeeping helpers.

    Subclasses are expected to override :meth:`forward` and may assign an
    activation module to ``self.act_fn`` (it defaults to ``None``).
    """

    def __init__(self):
        # Set up nn.Module's internal registries before assigning attributes.
        super(BaseModule, self).__init__()
        self.act_fn = None

    def selu_init_params(self):
        """Initialise Conv2d / Linear / BatchNorm2d sub-modules in place.

        Trainable Conv2d and Linear weights are drawn from a normal
        distribution with mean 0 and std ``1 / sqrt(weight.numel())`` and
        their biases (when present) are zeroed.  Trainable BatchNorm2d layers
        get weight 1 and bias 0.  Layers whose weight is missing (e.g.
        ``BatchNorm2d(affine=False)``) or frozen are left untouched.
        """
        for m in self.modules():
            if not isinstance(m, (nn.Conv2d, nn.BatchNorm2d, nn.Linear)):
                continue
            # A missing weight (affine=False) or a frozen weight means "skip".
            if m.weight is None or not m.weight.requires_grad:
                continue
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
            else:
                m.weight.data.normal_(0.0, 1.0 / sqrt(m.weight.numel()))
            # Layers built with bias=False have no bias tensor to reset.
            if m.bias is not None:
                m.bias.data.zero_()

    def initialize_weights_xavier_uniform(self):
        """Apply Xavier-uniform init to trainable Conv2d weights.

        Conv2d biases (when present) are zeroed; trainable BatchNorm2d layers
        get weight 1 and bias 0.  Other layer types are not modified.
        """
        for m in self.modules():
            if not isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                continue
            if m.weight is None or not m.weight.requires_grad:
                continue
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
            else:
                nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.zero_()

    def load_state_dict(self, state_dict, strict=True, self_state=False):
        """Copy matching entries of ``state_dict`` into this module, best effort.

        Entries whose name is unknown, or whose copy fails (for instance on a
        shape mismatch), are reported with ``print`` and skipped rather than
        raising.

        Args:
            state_dict: mapping from parameter/buffer name to tensor.
            strict: accepted for signature compatibility with
                ``nn.Module.load_state_dict``; it is not consulted here.
            self_state: optional mapping to copy into instead of
                ``self.state_dict()``; any falsy value selects the latter.
        """
        own_state = self_state if self_state else self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                print("Parameter {} is not in the model. ".format(name))
                continue
            try:
                own_state[name].copy_(param.data)
            except Exception as e:
                # Deliberately tolerant: report the failure and keep loading.
                print("Parameter {} fails to load.".format(name))
                print("-----------------------------------------")
                print(e)

    @contextmanager
    def set_activation_inplace(self):
        """Temporarily switch ``self.act_fn`` to in-place mode (saves memory).

        Does nothing when ``act_fn`` is unset or has no ``inplace`` attribute.
        The previous ``inplace`` value is restored on exit, even when the
        ``with`` body raises.
        """
        act_fn = getattr(self, "act_fn", None)
        if not hasattr(act_fn, "inplace"):
            yield
            return
        previous = act_fn.inplace
        act_fn.inplace = True
        try:
            yield
        finally:
            act_fn.inplace = previous

    def total_parameters(self):
        """Print total and trainable parameter counts; return the total."""
        total = 0
        trainable = 0
        for p in self.parameters():
            count = p.numel()
            total += count
            if p.requires_grad:
                trainable += count
        print(
            "Total parameters : {}. Trainable parameters : {}".format(total, trainable)
        )
        return total

    def forward(self, *x):
        """Subclasses must implement the forward pass."""
        raise NotImplementedError
class ResidualFixBlock(BaseModule):
    """Residual block: conv -> activation -> conv, plus an identity skip.

    ``forward`` returns ``activation(m(x) + x)``, so the output of the two
    convolutions must be addable to the input ``x``.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count produced by both convolutions.
        kernel_size, padding, dilation, groups: forwarded to both ``conv`` calls.
        activation: module applied between the convolutions and after the
            skip addition.  NOTE: the default instance is created once when
            the class is defined, so blocks relying on the default share it.
        conv: callable used to build the two convolution layers.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        groups=1,
        activation=nn.SELU(),
        conv=nn.Conv2d,
    ):
        super(ResidualFixBlock, self).__init__()
        self.act_fn = activation
        self.m = nn.Sequential(
            conv(
                in_channels,
                out_channels,
                kernel_size,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ),
            activation,
            # The second convolution consumes the first one's output, so its
            # input width is out_channels (identical when in == out).
            conv(
                out_channels,
                out_channels,
                kernel_size,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ),
        )

    def forward(self, x):
        """Return ``act_fn(m(x) + x)``."""
        out = self.m(x)
        return self.act_fn(out + x)
class ConvBlock(BaseModule):
    """A single convolution immediately followed by an activation.

    Args:
        in_channels, out_channels, kernel_size, padding, dilation, groups:
            forwarded to ``conv`` when building the layer.
        activation: module applied to the convolution output.
        conv: callable used to build the convolution layer.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        groups=1,
        activation=nn.SELU(),
        conv=nn.Conv2d,
    ):
        super().__init__()
        conv_layer = conv(
            in_channels,
            out_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        self.m = nn.Sequential(conv_layer, activation)

    def forward(self, x):
        """Run ``x`` through the convolution and the activation."""
        result = self.m(x)
        return result
class UpSampleBlock(BaseModule):
    """Upsample by 2, 4 or 8 using stacked conv -> activation -> PixelShuffle(2).

    Each x2 stage expands ``channels`` to ``4 * channels`` with a 3x3
    convolution and lets ``nn.PixelShuffle(2)`` fold them back to ``channels``.
    Every stage owns its own convolution; the ``activation`` instance passed
    in is reused by all stages.

    Args:
        channels: channel count of the input (and of the output).
        scale: total upscaling factor; must be 2, 4 or 8.
        activation: module applied after each convolution.
        atrous_rate: used as both padding and dilation of each convolution.
        conv: callable used to build the convolution layers.
    """

    def __init__(self, channels, scale, activation, atrous_rate=1, conv=nn.Conv2d):
        assert scale in [2, 4, 8], "Currently UpSampleBlock supports 2, 4, 8 scaling"
        super(UpSampleBlock, self).__init__()
        # round() guards against a float log landing just below the integer.
        num_stages = int(round(log(scale, 2)))
        # Build a fresh stage per iteration: repeating one module instance
        # would make every stage share the same convolution weights.
        stages = [
            nn.Sequential(
                conv(
                    channels,
                    4 * channels,
                    kernel_size=3,
                    padding=atrous_rate,
                    dilation=atrous_rate,
                ),
                activation,
                nn.PixelShuffle(2),
            )
            for _ in range(num_stages)
        ]
        self.m = nn.Sequential(*stages)

    def forward(self, x):
        """Apply all upsampling stages to ``x``."""
        return self.m(x)
class SpatialChannelSqueezeExcitation(BaseModule):
    """Sum of a channel-gated and a spatially-gated copy of the input.

    References:
        https://arxiv.org/abs/1709.01507
        https://arxiv.org/pdf/1803.02579v1.pdf

    Args:
        in_channel: channel count of the input.
        reduction: divisor for the hidden width of the channel branch
            (the referenced paper picked 16 by experiment).
        activation: module placed between the two linear layers.
    """

    def __init__(self, in_channel, reduction=16, activation=nn.ReLU()):
        super().__init__()
        # Hidden width of the channel branch, never below 4 units.
        hidden = max(in_channel // reduction, 4)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        squeeze = nn.Linear(in_channel, hidden)
        expand = nn.Linear(hidden, in_channel)
        self.channel_excite = nn.Sequential(squeeze, activation, expand, nn.Sigmoid())

        to_single_map = nn.Conv2d(
            in_channel, 1, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.spatial_excite = nn.Sequential(to_single_map, nn.Sigmoid())

    def forward(self, x):
        """Return ``x * channel_gate + x * spatial_gate`` for a 4-D input."""
        batch, channels, _, _ = x.size()

        # Channel branch: global average pool -> MLP -> one gate per channel.
        pooled = self.avg_pool(x).view(batch, channels)
        channel_gate = self.channel_excite(pooled).view(batch, channels, 1, 1)

        # Spatial branch: 1x1 conv -> one gate per spatial location.
        spatial_gate = self.spatial_excite(x)

        return x * channel_gate + x * spatial_gate
class PartialConv(nn.Module):
    """Convolution whose output is re-weighted near the padded border.

    A fixed all-ones, single-channel ``mask_conv`` is run over an all-ones map
    of the input's spatial size.  Its output is the sum of ones each kernel
    window actually covers.  The bias-free feature response is multiplied by
    ``window_size / that sum`` and the bias is added back afterwards.

    References:
        Image Inpainting for Irregular Holes Using Partial Convolutions
        http://masc.cs.gmu.edu/wiki/partialconv/show?time=2018-05-24+21%3A41%3A10
        https://github.com/naoto0804/pytorch-inpainting-with-partial-conv/blob/master/net.py
        https://github.com/SeitaroShinagawa/chainer-partial_convolution_image_inpainting/blob/master/common/net.py
        Partial-convolution based padding:
        https://github.com/NVIDIA/partialconv/blob/master/models/pd_resnet.py

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded to the feature ``nn.Conv2d``.  The mask
            convolution reuses the same geometry but is always single-group
            and bias-free.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super(PartialConv, self).__init__()
        self.feature_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        # The mask has exactly one channel, so it must stay single-group even
        # when the feature convolution is grouped.
        self.mask_conv = nn.Conv2d(
            1, 1, kernel_size, stride, padding, dilation, groups=1, bias=False
        )
        self.window_size = self.mask_conv.kernel_size[0] * self.mask_conv.kernel_size[1]
        torch.nn.init.constant_(self.mask_conv.weight, 1.0)
        # The mask kernel is a constant, not something to train.
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the border-compensated convolution of the 4-D input ``x``."""
        output = self.feature_conv(x)

        with torch.no_grad():
            # Match x's dtype/device so the module also works after
            # .double()/.half()/.to(device).
            ones = x.new_ones((1, 1, x.size(2), x.size(3)))
            ratio = self.window_size / self.mask_conv(ones)

        bias = self.feature_conv.bias
        if bias is None:
            return output * ratio

        # Re-weight only the bias-free response, then add the bias back.
        bias = bias.view(1, -1, 1, 1)
        return (output - bias) * ratio + bias
|