Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from .layers import * | |
class FCDenseNet(nn.Module): | |
def __init__(self, in_channels=3, down_blocks=(5, 5, 5, 5, 5), | |
up_blocks=(5, 5, 5, 5, 5), bottleneck_layers=5, | |
growth_rate=16, out_chans_first_conv=48, n_classes=12): | |
super().__init__() | |
self.down_blocks = down_blocks | |
self.up_blocks = up_blocks | |
cur_channels_count = 0 | |
skip_connection_channel_counts = [] | |
## First Convolution ## | |
self.add_module('firstconv', nn.Conv2d(in_channels=in_channels, | |
out_channels=out_chans_first_conv, kernel_size=3, | |
stride=1, padding=1, bias=True)) | |
cur_channels_count = out_chans_first_conv | |
##################### | |
# Downsampling path # | |
##################### | |
self.denseBlocksDown = nn.ModuleList([]) | |
self.transDownBlocks = nn.ModuleList([]) | |
for i in range(len(down_blocks)): | |
self.denseBlocksDown.append( | |
DenseBlock(cur_channels_count, growth_rate, down_blocks[i])) | |
cur_channels_count += (growth_rate*down_blocks[i]) | |
skip_connection_channel_counts.insert(0, cur_channels_count) | |
self.transDownBlocks.append(TransitionDown(cur_channels_count)) | |
##################### | |
# Bottleneck # | |
##################### | |
self.add_module('bottleneck', Bottleneck(cur_channels_count, | |
growth_rate, bottleneck_layers)) | |
prev_block_channels = growth_rate*bottleneck_layers | |
cur_channels_count += prev_block_channels | |
####################### | |
# Upsampling path # | |
####################### | |
self.transUpBlocks = nn.ModuleList([]) | |
self.denseBlocksUp = nn.ModuleList([]) | |
for i in range(len(up_blocks)-1): | |
self.transUpBlocks.append(TransitionUp( | |
prev_block_channels, prev_block_channels)) | |
cur_channels_count = prev_block_channels + \ | |
skip_connection_channel_counts[i] | |
self.denseBlocksUp.append(DenseBlock( | |
cur_channels_count, growth_rate, up_blocks[i], | |
upsample=True)) | |
prev_block_channels = growth_rate*up_blocks[i] | |
cur_channels_count += prev_block_channels | |
## Final DenseBlock ## | |
self.transUpBlocks.append(TransitionUp( | |
prev_block_channels, prev_block_channels)) | |
cur_channels_count = prev_block_channels + \ | |
skip_connection_channel_counts[-1] | |
self.denseBlocksUp.append(DenseBlock( | |
cur_channels_count, growth_rate, up_blocks[-1], | |
upsample=False)) | |
cur_channels_count += growth_rate*up_blocks[-1] | |
## Softmax ## | |
self.finalConv = nn.Conv2d(in_channels=cur_channels_count, | |
out_channels=n_classes, kernel_size=1, stride=1, | |
padding=0, bias=True) | |
self.softmax = nn.LogSoftmax(dim=1) | |
def forward(self, x): | |
out = self.firstconv(x) | |
skip_connections = [] | |
for i in range(len(self.down_blocks)): | |
out = self.denseBlocksDown[i](out) | |
skip_connections.append(out) | |
out = self.transDownBlocks[i](out) | |
out = self.bottleneck(out) | |
for i in range(len(self.up_blocks)): | |
skip = skip_connections.pop() | |
out = self.transUpBlocks[i](out, skip) | |
out = self.denseBlocksUp[i](out) | |
out = self.finalConv(out) | |
out = self.softmax(out) | |
return out | |
def FCDenseNet57(n_classes): | |
return FCDenseNet( | |
in_channels=3, down_blocks=(4, 4, 4, 4, 4), | |
up_blocks=(4, 4, 4, 4, 4), bottleneck_layers=4, | |
growth_rate=12, out_chans_first_conv=48, n_classes=n_classes) | |
def FCDenseNet67(n_classes): | |
return FCDenseNet( | |
in_channels=3, down_blocks=(5, 5, 5, 5, 5), | |
up_blocks=(5, 5, 5, 5, 5), bottleneck_layers=5, | |
growth_rate=16, out_chans_first_conv=48, n_classes=n_classes) | |
def FCDenseNet103(n_classes): | |
return FCDenseNet( | |
in_channels=3, down_blocks=(4, 5, 7, 10, 12), | |
up_blocks=(12, 10, 7, 5, 4), bottleneck_layers=15, | |
growth_rate=16, out_chans_first_conv=48, n_classes=n_classes) | |