Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import functools | |
### components | |
class ResConv(nn.Module): | |
""" | |
Residual convolutional block, where | |
convolutional block consists: (convolution => [BN] => ReLU) * 3 | |
residual connection adds the input to the output | |
""" | |
def __init__(self, in_channels, out_channels, mid_channels=None): | |
super().__init__() | |
if not mid_channels: | |
mid_channels = out_channels | |
self.double_conv = nn.Sequential( | |
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), | |
nn.BatchNorm2d(mid_channels), | |
nn.ReLU(inplace=True), | |
nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1), | |
nn.BatchNorm2d(mid_channels), | |
nn.ReLU(inplace=True), | |
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), | |
nn.BatchNorm2d(out_channels), | |
nn.ReLU(inplace=True) | |
) | |
self.double_conv1 = nn.Sequential( | |
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), | |
nn.BatchNorm2d(out_channels), | |
nn.ReLU(inplace=True), | |
) | |
def forward(self, x): | |
x_in = self.double_conv1(x) | |
x1 = self.double_conv(x) | |
return self.double_conv(x) + x_in | |
class Down(nn.Module): | |
"""Downscaling with maxpool then Resconv""" | |
def __init__(self, in_channels, out_channels): | |
super().__init__() | |
self.maxpool_conv = nn.Sequential( | |
nn.MaxPool2d(2), | |
ResConv(in_channels, out_channels) | |
) | |
def forward(self, x): | |
return self.maxpool_conv(x) | |
class Up(nn.Module): | |
"""Upscaling then double conv""" | |
def __init__(self, in_channels, out_channels, bilinear=True): | |
super().__init__() | |
# if bilinear, use the normal convolutions to reduce the number of channels | |
if bilinear: | |
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) | |
self.conv = ResConv(in_channels, out_channels, in_channels // 2) | |
else: | |
self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) | |
self.conv = ResConv(in_channels, out_channels) | |
def forward(self, x1, x2): | |
x1 = self.up(x1) | |
# input is CHW | |
diffY = x2.size()[2] - x1.size()[2] | |
diffX = x2.size()[3] - x1.size()[3] | |
x1 = F.pad( | |
x1, | |
[ | |
diffX // 2, diffX - diffX // 2, | |
diffY // 2, diffY - diffY // 2 | |
] | |
) | |
# if you have padding issues, see | |
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a | |
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd | |
x = torch.cat([x2, x1], dim=1) | |
return self.conv(x) | |
class OutConv(nn.Module): | |
def __init__(self, in_channels, out_channels): | |
super(OutConv, self).__init__() | |
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) | |
def forward(self, x): | |
# return F.relu(self.conv(x)) | |
return self.conv(x) | |
##### The composite networks | |
class UNet(nn.Module): | |
def __init__(self, n_channels, out_channels, bilinear=True): | |
super(UNet, self).__init__() | |
self.n_channels = n_channels | |
self.out_channels = out_channels | |
self.bilinear = bilinear | |
#### | |
self.inc = ResConv(n_channels, 64) | |
self.down1 = Down(64, 128) | |
self.down2 = Down(128, 256) | |
self.down3 = Down(256, 512) | |
factor = 2 if bilinear else 1 | |
self.down4 = Down(512, 1024 // factor) | |
self.up1 = Up(1024, 512 // factor, bilinear) | |
self.up2 = Up(512, 256 // factor, bilinear) | |
self.up3 = Up(256, 128 // factor, bilinear) | |
self.up4 = Up(128, 64, bilinear) | |
self.outc = OutConv(64, out_channels) | |
def forward(self, x): | |
x1 = self.inc(x) | |
x2 = self.down1(x1) | |
x3 = self.down2(x2) | |
x4 = self.down3(x3) | |
x5 = self.down4(x4) | |
x = self.up1(x5, x4) | |
x = self.up2(x, x3) | |
x = self.up3(x, x2) | |
x = self.up4(x, x1) | |
y = self.outc(x) | |
return y | |
class CasUNet(nn.Module): | |
def __init__(self, n_unet, io_channels, bilinear=True): | |
super(CasUNet, self).__init__() | |
self.n_unet = n_unet | |
self.io_channels = io_channels | |
self.bilinear = bilinear | |
#### | |
self.unet_list = nn.ModuleList() | |
for i in range(self.n_unet): | |
self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear)) | |
def forward(self, x, dop=None): | |
y = x | |
for i in range(self.n_unet): | |
if i==0: | |
if dop is not None: | |
y = F.dropout2d(self.unet_list[i](y), p=dop) | |
else: | |
y = self.unet_list[i](y) | |
else: | |
y = self.unet_list[i](y+x) | |
return y | |
class CasUNet_2head(nn.Module): | |
def __init__(self, n_unet, io_channels, bilinear=True): | |
super(CasUNet_2head, self).__init__() | |
self.n_unet = n_unet | |
self.io_channels = io_channels | |
self.bilinear = bilinear | |
#### | |
self.unet_list = nn.ModuleList() | |
for i in range(self.n_unet): | |
if i != self.n_unet-1: | |
self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear)) | |
else: | |
self.unet_list.append(UNet_2head(self.io_channels, self.io_channels, self.bilinear)) | |
def forward(self, x): | |
y = x | |
for i in range(self.n_unet): | |
if i==0: | |
y = self.unet_list[i](y) | |
else: | |
y = self.unet_list[i](y+x) | |
y_mean, y_sigma = y[0], y[1] | |
return y_mean, y_sigma | |
class CasUNet_3head(nn.Module): | |
def __init__(self, n_unet, io_channels, bilinear=True): | |
super(CasUNet_3head, self).__init__() | |
self.n_unet = n_unet | |
self.io_channels = io_channels | |
self.bilinear = bilinear | |
#### | |
self.unet_list = nn.ModuleList() | |
for i in range(self.n_unet): | |
if i != self.n_unet-1: | |
self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear)) | |
else: | |
self.unet_list.append(UNet_3head(self.io_channels, self.io_channels, self.bilinear)) | |
def forward(self, x): | |
y = x | |
for i in range(self.n_unet): | |
if i==0: | |
y = self.unet_list[i](y) | |
else: | |
y = self.unet_list[i](y+x) | |
y_mean, y_alpha, y_beta = y[0], y[1], y[2] | |
return y_mean, y_alpha, y_beta | |
class UNet_2head(nn.Module): | |
def __init__(self, n_channels, out_channels, bilinear=True): | |
super(UNet_2head, self).__init__() | |
self.n_channels = n_channels | |
self.out_channels = out_channels | |
self.bilinear = bilinear | |
#### | |
self.inc = ResConv(n_channels, 64) | |
self.down1 = Down(64, 128) | |
self.down2 = Down(128, 256) | |
self.down3 = Down(256, 512) | |
factor = 2 if bilinear else 1 | |
self.down4 = Down(512, 1024 // factor) | |
self.up1 = Up(1024, 512 // factor, bilinear) | |
self.up2 = Up(512, 256 // factor, bilinear) | |
self.up3 = Up(256, 128 // factor, bilinear) | |
self.up4 = Up(128, 64, bilinear) | |
#per pixel multiple channels may exist | |
self.out_mean = OutConv(64, out_channels) | |
#variance will always be a single number for a pixel | |
self.out_var = nn.Sequential( | |
OutConv(64, 128), | |
OutConv(128, 1), | |
) | |
def forward(self, x): | |
x1 = self.inc(x) | |
x2 = self.down1(x1) | |
x3 = self.down2(x2) | |
x4 = self.down3(x3) | |
x5 = self.down4(x4) | |
x = self.up1(x5, x4) | |
x = self.up2(x, x3) | |
x = self.up3(x, x2) | |
x = self.up4(x, x1) | |
y_mean, y_var = self.out_mean(x), self.out_var(x) | |
return y_mean, y_var | |
class UNet_3head(nn.Module): | |
def __init__(self, n_channels, out_channels, bilinear=True): | |
super(UNet_3head, self).__init__() | |
self.n_channels = n_channels | |
self.out_channels = out_channels | |
self.bilinear = bilinear | |
#### | |
self.inc = ResConv(n_channels, 64) | |
self.down1 = Down(64, 128) | |
self.down2 = Down(128, 256) | |
self.down3 = Down(256, 512) | |
factor = 2 if bilinear else 1 | |
self.down4 = Down(512, 1024 // factor) | |
self.up1 = Up(1024, 512 // factor, bilinear) | |
self.up2 = Up(512, 256 // factor, bilinear) | |
self.up3 = Up(256, 128 // factor, bilinear) | |
self.up4 = Up(128, 64, bilinear) | |
#per pixel multiple channels may exist | |
self.out_mean = OutConv(64, out_channels) | |
#variance will always be a single number for a pixel | |
self.out_alpha = nn.Sequential( | |
OutConv(64, 128), | |
OutConv(128, 1), | |
nn.ReLU() | |
) | |
self.out_beta = nn.Sequential( | |
OutConv(64, 128), | |
OutConv(128, 1), | |
nn.ReLU() | |
) | |
def forward(self, x): | |
x1 = self.inc(x) | |
x2 = self.down1(x1) | |
x3 = self.down2(x2) | |
x4 = self.down3(x3) | |
x5 = self.down4(x4) | |
x = self.up1(x5, x4) | |
x = self.up2(x, x3) | |
x = self.up3(x, x2) | |
x = self.up4(x, x1) | |
y_mean, y_alpha, y_beta = self.out_mean(x), \ | |
self.out_alpha(x), self.out_beta(x) | |
return y_mean, y_alpha, y_beta | |
class ResidualBlock(nn.Module): | |
def __init__(self, in_features): | |
super(ResidualBlock, self).__init__() | |
conv_block = [ | |
nn.ReflectionPad2d(1), | |
nn.Conv2d(in_features, in_features, 3), | |
nn.InstanceNorm2d(in_features), | |
nn.ReLU(inplace=True), | |
nn.ReflectionPad2d(1), | |
nn.Conv2d(in_features, in_features, 3), | |
nn.InstanceNorm2d(in_features) | |
] | |
self.conv_block = nn.Sequential(*conv_block) | |
def forward(self, x): | |
return x + self.conv_block(x) | |
class Generator(nn.Module): | |
def __init__(self, input_nc, output_nc, n_residual_blocks=9): | |
super(Generator, self).__init__() | |
# Initial convolution block | |
model = [ | |
nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7), | |
nn.InstanceNorm2d(64), nn.ReLU(inplace=True) | |
] | |
# Downsampling | |
in_features = 64 | |
out_features = in_features*2 | |
for _ in range(2): | |
model += [ | |
nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), | |
nn.InstanceNorm2d(out_features), | |
nn.ReLU(inplace=True) | |
] | |
in_features = out_features | |
out_features = in_features*2 | |
# Residual blocks | |
for _ in range(n_residual_blocks): | |
model += [ResidualBlock(in_features)] | |
# Upsampling | |
out_features = in_features//2 | |
for _ in range(2): | |
model += [ | |
nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), | |
nn.InstanceNorm2d(out_features), | |
nn.ReLU(inplace=True) | |
] | |
in_features = out_features | |
out_features = in_features//2 | |
# Output layer | |
model += [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh()] | |
self.model = nn.Sequential(*model) | |
def forward(self, x): | |
return self.model(x) | |
class ResnetGenerator(nn.Module): | |
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. | |
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) | |
""" | |
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): | |
"""Construct a Resnet-based generator | |
Parameters: | |
input_nc (int) -- the number of channels in input images | |
output_nc (int) -- the number of channels in output images | |
ngf (int) -- the number of filters in the last conv layer | |
norm_layer -- normalization layer | |
use_dropout (bool) -- if use dropout layers | |
n_blocks (int) -- the number of ResNet blocks | |
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero | |
""" | |
assert(n_blocks >= 0) | |
super(ResnetGenerator, self).__init__() | |
if type(norm_layer) == functools.partial: | |
use_bias = norm_layer.func == nn.InstanceNorm2d | |
else: | |
use_bias = norm_layer == nn.InstanceNorm2d | |
model = [nn.ReflectionPad2d(3), | |
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), | |
norm_layer(ngf), | |
nn.ReLU(True)] | |
n_downsampling = 2 | |
for i in range(n_downsampling): # add downsampling layers | |
mult = 2 ** i | |
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), | |
norm_layer(ngf * mult * 2), | |
nn.ReLU(True)] | |
mult = 2 ** n_downsampling | |
for i in range(n_blocks): # add ResNet blocks | |
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] | |
for i in range(n_downsampling): # add upsampling layers | |
mult = 2 ** (n_downsampling - i) | |
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), | |
kernel_size=3, stride=2, | |
padding=1, output_padding=1, | |
bias=use_bias), | |
norm_layer(int(ngf * mult / 2)), | |
nn.ReLU(True)] | |
model += [nn.ReflectionPad2d(3)] | |
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] | |
model += [nn.Tanh()] | |
self.model = nn.Sequential(*model) | |
def forward(self, input): | |
"""Standard forward""" | |
return self.model(input) | |
class ResnetBlock(nn.Module): | |
"""Define a Resnet block""" | |
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): | |
"""Initialize the Resnet block | |
A resnet block is a conv block with skip connections | |
We construct a conv block with build_conv_block function, | |
and implement skip connections in <forward> function. | |
Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf | |
""" | |
super(ResnetBlock, self).__init__() | |
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) | |
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): | |
"""Construct a convolutional block. | |
Parameters: | |
dim (int) -- the number of channels in the conv layer. | |
padding_type (str) -- the name of padding layer: reflect | replicate | zero | |
norm_layer -- normalization layer | |
use_dropout (bool) -- if use dropout layers. | |
use_bias (bool) -- if the conv layer uses bias or not | |
Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) | |
""" | |
conv_block = [] | |
p = 0 | |
if padding_type == 'reflect': | |
conv_block += [nn.ReflectionPad2d(1)] | |
elif padding_type == 'replicate': | |
conv_block += [nn.ReplicationPad2d(1)] | |
elif padding_type == 'zero': | |
p = 1 | |
else: | |
raise NotImplementedError('padding [%s] is not implemented' % padding_type) | |
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] | |
if use_dropout: | |
conv_block += [nn.Dropout(0.5)] | |
p = 0 | |
if padding_type == 'reflect': | |
conv_block += [nn.ReflectionPad2d(1)] | |
elif padding_type == 'replicate': | |
conv_block += [nn.ReplicationPad2d(1)] | |
elif padding_type == 'zero': | |
p = 1 | |
else: | |
raise NotImplementedError('padding [%s] is not implemented' % padding_type) | |
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] | |
return nn.Sequential(*conv_block) | |
def forward(self, x): | |
"""Forward function (with skip connections)""" | |
out = x + self.conv_block(x) # add skip connections | |
return out | |
### discriminator | |
class NLayerDiscriminator(nn.Module): | |
"""Defines a PatchGAN discriminator""" | |
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): | |
"""Construct a PatchGAN discriminator | |
Parameters: | |
input_nc (int) -- the number of channels in input images | |
ndf (int) -- the number of filters in the last conv layer | |
n_layers (int) -- the number of conv layers in the discriminator | |
norm_layer -- normalization layer | |
""" | |
super(NLayerDiscriminator, self).__init__() | |
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters | |
use_bias = norm_layer.func == nn.InstanceNorm2d | |
else: | |
use_bias = norm_layer == nn.InstanceNorm2d | |
kw = 4 | |
padw = 1 | |
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] | |
nf_mult = 1 | |
nf_mult_prev = 1 | |
for n in range(1, n_layers): # gradually increase the number of filters | |
nf_mult_prev = nf_mult | |
nf_mult = min(2 ** n, 8) | |
sequence += [ | |
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), | |
norm_layer(ndf * nf_mult), | |
nn.LeakyReLU(0.2, True) | |
] | |
nf_mult_prev = nf_mult | |
nf_mult = min(2 ** n_layers, 8) | |
sequence += [ | |
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), | |
norm_layer(ndf * nf_mult), | |
nn.LeakyReLU(0.2, True) | |
] | |
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map | |
self.model = nn.Sequential(*sequence) | |
def forward(self, input): | |
"""Standard forward.""" | |
return self.model(input) |