# Spaces:
# Running
# Running
import torch
import torch.nn as nn
class DenseLayer(nn.Sequential):
    """BatchNorm -> ReLU -> 3x3 convolution -> 2D dropout, applied in order.

    The convolution uses kernel size 3, stride 1 and padding 1 and maps
    ``in_channels`` channels to ``growth_rate`` channels.

    Args:
        in_channels: Number of channels of the incoming tensor.
        growth_rate: Number of channels produced by the convolution.
    """

    def __init__(self, in_channels: int, growth_rate: int) -> None:
        super().__init__()
        self.add_module('norm', nn.BatchNorm2d(in_channels))
        # ReLU runs in place (inplace=True) on the BatchNorm output.
        self.add_module('relu', nn.ReLU(True))
        self.add_module(
            'conv',
            nn.Conv2d(
                in_channels,
                growth_rate,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            ),
        )
        self.add_module('drop', nn.Dropout2d(0.2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the registered sub-modules in order."""
        return super().forward(x)
class DenseBlock(nn.Module):
    """Stack of ``DenseLayer`` modules with channel-wise concatenation.

    Layer ``i`` is built for ``in_channels + i * growth_rate`` input channels,
    because every layer receives the block input concatenated with the
    outputs of all preceding layers.

    Args:
        in_channels: Number of channels of the block input.
        growth_rate: Number of channels each layer adds.
        n_layers: Number of dense layers in the block.
        upsample: If true, ``forward`` returns only the concatenated layer
            outputs; otherwise it returns the input concatenated with all
            layer outputs.
    """

    def __init__(self, in_channels, growth_rate, n_layers, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.layers = nn.ModuleList([
            DenseLayer(in_channels + i * growth_rate, growth_rate)
            for i in range(n_layers)
        ])

    def forward(self, x):
        """Apply the layers, concatenating along dimension 1 after each one.

        Returns:
            With ``upsample`` set: the layer outputs concatenated along
            dimension 1 (the block input is not included). Otherwise: the
            block input followed by every layer output, concatenated along
            dimension 1.
        """
        collect_new = self.upsample
        new_features = []
        for layer in self.layers:
            out = layer(x)
            # Each layer sees everything produced so far (dim 1 = channels).
            x = torch.cat([x, out], 1)
            if collect_new:
                # Only the freshly produced features are kept for this mode.
                new_features.append(out)
        if collect_new:
            return torch.cat(new_features, 1)
        return x
class TransitionDown(nn.Sequential):
    """BatchNorm -> ReLU -> 1x1 convolution -> 2D dropout -> 2x2 max-pool.

    The 1x1 convolution keeps the channel count at ``in_channels``; the
    max-pool uses a kernel size of 2.

    Args:
        in_channels: Number of channels of the incoming tensor (also the
            number of output channels).
    """

    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.add_module('norm', nn.BatchNorm2d(num_features=in_channels))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module(
            'conv',
            nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
            ),
        )
        self.add_module('drop', nn.Dropout2d(0.2))
        self.add_module('maxpool', nn.MaxPool2d(2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the registered sub-modules in order."""
        return super().forward(x)
class TransitionUp(nn.Module):
    """Transposed convolution followed by crop-and-concatenate with a skip.

    Args:
        in_channels: Number of channels of the tensor to upsample.
        out_channels: Number of channels produced by the transposed
            convolution (before the skip tensor is concatenated).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.convTrans = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=0,
            bias=True,
        )

    def forward(self, x, skip):
        """Upsample ``x``, crop it to ``skip``'s size, and concatenate both.

        Args:
            x: Tensor fed to the transposed convolution.
            skip: Tensor whose dimensions 2 and 3 define the crop size and
                which is appended along dimension 1.

        Returns:
            The cropped transposed-convolution output followed by ``skip``,
            concatenated along dimension 1.
        """
        out = self.convTrans(x)
        # Trim the upsampled map so dims 2 and 3 match the skip connection.
        out = center_crop(out, skip.size(2), skip.size(3))
        out = torch.cat([out, skip], 1)
        return out
class Bottleneck(nn.Sequential):
    """Sequential wrapper around a single ``DenseBlock`` built with
    ``upsample=True``, registered under the name ``'bottleneck'``.

    Args:
        in_channels: Number of channels of the incoming tensor.
        growth_rate: Growth rate passed to the dense block.
        n_layers: Number of layers passed to the dense block.
    """

    def __init__(self, in_channels, growth_rate, n_layers):
        super().__init__()
        self.add_module(
            'bottleneck',
            DenseBlock(in_channels, growth_rate, n_layers, upsample=True),
        )

    def forward(self, x):
        """Run ``x`` through the wrapped dense block."""
        return super().forward(x)
def center_crop(layer, max_height, max_width):
    """Return the centred ``max_height`` x ``max_width`` window of ``layer``.

    ``layer`` is unpacked as a 4-dimensional tensor; the crop is taken over
    its last two dimensions (height, then width). When the size difference
    is odd, the extra row/column is dropped from the bottom/right.

    If ``layer`` is smaller than the requested size along a dimension, that
    dimension is returned uncropped. Without the clamp a negative offset
    would be interpreted as an index from the end and silently yield a
    wrong, off-centre slice.

    Args:
        layer: 4-dimensional tensor to crop.
        max_height: Target size of dimension 2.
        max_width: Target size of dimension 3.

    Returns:
        A view of ``layer`` restricted to the centred window.
    """
    _, _, h, w = layer.size()
    # Clamp at 0 so an undersized input never produces a negative start index.
    left = max((w - max_width) // 2, 0)
    top = max((h - max_height) // 2, 0)
    return layer[:, :, top:top + max_height, left:left + max_width]