import numpy as np
import torch
import torch.nn as nn
from torch.nn import LazyConv3d, MaxPool3d, BatchNorm3d

from torch.nn.modules import Module
from torch.nn.modules import ReLU
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.instancenorm import InstanceNorm3d
from custom_modules import LazyConvDropoutNormNonlinCat, ModularConvLayers, LazyConvBottleneckLayer


class modular_hdunet_encoder(Module):
    """HDUnet encoder with modular parameters
    """

    def __init__(self, base_num_filter, num_blocks_per_stage, num_stages, pool_kernel_sizes, conv_kernel_sizes,
                 padding='same', conv_type: Module = LazyConvDropoutNormNonlinCat, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, pooling_type: Module = MaxPool3d,
                 pooling_kernel_size=(2, 2, 2), nonlin: Module = ReLU):
        """Object creation

        :param base_num_filter: base number of filters (output channels).
        :param num_blocks_per_stage: number of convolutional blocks per stage (can be different for each stage).
        :param num_stages: number of stages.
        :param pool_kernel_sizes: the last conv layer of each stage is strided, so this parameter sets its kernel size
            and stride (can be different for each stage).
            Please note that this parameter is also retrieved by the modular decoder and used as the upsampling scale
            factor.
        :param conv_kernel_sizes: kernel size (can be different for each stage).
        :param padding: padding used, default is 'same'.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function;
            - concatenation.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
        :param pooling_type: type of pooling used, default is 3D max pooling. Must be a torch Module.
        :param pooling_kernel_size: kernel size of the pooling layer, default is (2, 2, 2).
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        """
        super(modular_hdunet_encoder, self).__init__()
        self.base_num_filter = base_num_filter
        self.num_blocks_per_stage = num_blocks_per_stage
        self.num_stages = num_stages
        self.pool_kernel_sizes = pool_kernel_sizes
        self.conv_kernel_sizes = conv_kernel_sizes
        self.padding = padding
        self.conv_type = conv_type
        self.norm_type = norm_type
        self.dropout_type = dropout_type
        self.dropout_rate = dropout_rate
        self.nonlin = nonlin
        self.expansion_rate = expansion_rate
        self.pooling_type = pooling_type
        self.pooling_kernel_size = pooling_kernel_size

        self.stages = []
        self.pooling_stages = []
        self.end_stages = []
        self.stage_output_features = []
        self.stage_pool_kernel_size = []
        self.stage_conv_kernel_size = []

        assert len(pool_kernel_sizes) == len(conv_kernel_sizes) == num_stages

        # Broadcast a scalar num_blocks_per_stage to every stage, otherwise check that one value is given per stage.
        if not isinstance(num_blocks_per_stage, (list, tuple)):
            num_blocks_per_stage = [num_blocks_per_stage] * num_stages
        else:
            assert len(num_blocks_per_stage) == num_stages

        self.num_blocks_per_stage = num_blocks_per_stage

        current_out_channels = 0

        for stage in range(num_stages):
            # The number of filters grows by expansion_rate at each stage; cast to int because LazyConv3d and
            # BatchNorm3d expect an integer channel count, not the float returned by np.round.
            current_out_channels = int(np.round((expansion_rate ** stage) * self.base_num_filter))
            current_num_blocks_per_stage = num_blocks_per_stage[stage]
            current_pool_kernel_size = pool_kernel_sizes[stage]
            current_kernel_size = conv_kernel_sizes[stage]

            current_stage = ModularConvLayers(output_channels=current_out_channels,
                                              num_conv_layers=current_num_blocks_per_stage,
                                              kernel_size=current_kernel_size, padding=padding, conv_type=conv_type,
                                              norm_type=norm_type, dropout_type=dropout_type,
                                              dropout_rate=dropout_rate, nonlin=self.nonlin)

            self.pooling_stages.append(pooling_type(kernel_size=current_pool_kernel_size))

            # Strided-convolution branch that downsamples in parallel with the pooling branch.
            current_end_stage = nn.Sequential(
                LazyConv3d(out_channels=current_out_channels, kernel_size=current_pool_kernel_size,
                           stride=current_pool_kernel_size, padding=0), nonlin(), BatchNorm3d(current_out_channels)
            )

            self.stages.append(current_stage)
            self.end_stages.append(current_end_stage)
            self.stage_output_features.append(current_out_channels)
            self.stage_pool_kernel_size.append(current_pool_kernel_size)
            self.stage_conv_kernel_size.append(current_kernel_size)

        self.stages = nn.ModuleList(self.stages)
        self.pooling_stages = nn.ModuleList(self.pooling_stages)
        self.end_stages = nn.ModuleList(self.end_stages)
        self.output_features = current_out_channels

    def forward(self, x):
        """Forward inputs through the layer

        :param x: the input to forward.
        :return: a list containing the output of each stage before downsampling (used as skip connections in the
            decoder), followed by the final output of the encoder (after concatenation), which feeds the bottleneck.
            Therefore, for n stages the list holds n + 1 values.
        """
        skips = []

        for i, stage in enumerate(self.stages):
            x = stage(x)
            # Downsample through two parallel branches: pooling and strided convolution.
            buff = self.pooling_stages[i](x)
            tmp = self.end_stages[i](x)
            # Keep the pre-downsampling activation as a skip connection for the decoder.
            skips.append(x)
            x = torch.cat([tmp, buff], dim=1)
        skips.append(x)

        return skips
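
# Shape and channel sketch for the encoder (illustrative, assuming ModularConvLayers emits exactly `output_channels`
# feature maps): at stage s the stored skip has round(expansion_rate ** s * base_num_filter) channels at the current
# resolution, while the tensor passed to the next stage concatenates the pooled and strided-convolution branches,
# i.e. twice that many channels at the downsampled resolution.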
|


class modular_hdunet_bottleneck(Module):
    """HDUnet bottleneck with modular parameters
    """

    def __init__(self, base_num_filter, num_stages, conv_kernel_sizes, padding='same', num_steps_bottleneck=4,
                 conv_type: Module = LazyConvBottleneckLayer, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, nonlin: Module = ReLU):
        """Object creation

        :param base_num_filter: base number of filters (output channels).
        :param num_stages: number of stages of the encoder.
        :param conv_kernel_sizes: kernel size (can be different for each step).
        :param padding: padding used, default is 'same'.
        :param num_steps_bottleneck: number of steps in the bottleneck, default is 4.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        """
        super(modular_hdunet_bottleneck, self).__init__()
        self.base_num_filter = base_num_filter
        self.conv_kernel_sizes = conv_kernel_sizes
        self.padding = padding
        self.num_steps_bottleneck = num_steps_bottleneck
        self.conv_type = conv_type
        self.norm_type = norm_type
        self.dropout_type = dropout_type
        self.dropout_rate = dropout_rate
        self.expansion_rate = expansion_rate
        self.nonlin = nonlin

        # Number of filters produced by each bottleneck step.
        encoder_output_features = (expansion_rate ** num_stages * base_num_filter)

        self.stages = []
        self.step_conv_kernel_size = []

        assert len(conv_kernel_sizes) == num_steps_bottleneck

        for step in range(num_steps_bottleneck):
            current_kernel_size = conv_kernel_sizes[step]
            self.stages.append(
                conv_type(output_channels=encoder_output_features, kernel_size=current_kernel_size, padding=padding,
                          norm_type=norm_type, dropout_type=dropout_type,
                          dropout_rate=dropout_rate, nonlin=self.nonlin)
            )

        self.stages = nn.ModuleList(self.stages)

    def forward(self, x):
        """Forward inputs through the layer

        :param x: the input to forward. At each step the input is concatenated with its result in order to produce
            the input of the next bottleneck layer.
        :return: the input forwarded through the layer.
        """
        for stage in self.stages:
            buff = stage(x)
            # Dense connectivity: each step's output is stacked onto its input along the channel dimension.
            x = torch.cat([buff, x], dim=1)
        return x
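
# Channel growth in the bottleneck (illustrative, assuming each step emits exactly `encoder_output_features`
# feature maps): starting from an input with C channels, k "convolve then concatenate" steps yield
# C + k * encoder_output_features channels.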
|


class modular_hdunet_decoder(Module):
    """HDUnet decoder with modular parameters
    """

    def __init__(self, previous, base_num_filter, num_blocks_per_stage=None, padding='same',
                 conv_type: Module = LazyConvDropoutNormNonlinCat, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, nonlin: Module = ReLU):
        """Object creation

        :param previous: the encoder previously used in the model. It is used to retrieve information that does not
            change, such as the number of stages or the kernel sizes of each stage.
        :param base_num_filter: base number of filters (output channels).
        :param num_blocks_per_stage: number of convolutional blocks per stage (can be different for each stage).
            If set to None, it will be the same as the encoder's (reversed).
        :param padding: padding used, default is 'same'.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function;
            - concatenation.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        """
        super(modular_hdunet_decoder, self).__init__()
        self.base_num_filter = base_num_filter
        self.num_blocks_per_stage = num_blocks_per_stage
        self.padding = padding
        self.conv_type = conv_type
        self.norm_type = norm_type
        self.dropout_type = dropout_type
        self.dropout_rate = dropout_rate
        self.expansion_rate = expansion_rate
        self.nonlin = nonlin

        self.skips = []

        previous_stages = previous.stages
        previous_stage_output_features = previous.stage_output_features
        previous_stage_pool_kernel_size = previous.stage_pool_kernel_size
        previous_stage_conv_kernel_size = previous.stage_conv_kernel_size

        self.num_stages = len(previous_stages)

        # By default, mirror the encoder: same number of blocks per stage, in reverse order.
        if num_blocks_per_stage is None:
            self.num_blocks_per_stage = previous.num_blocks_per_stage[::-1]

        if not isinstance(self.num_blocks_per_stage, (list, tuple)):
            self.num_blocks_per_stage = [self.num_blocks_per_stage] * self.num_stages
        else:
            assert len(self.num_blocks_per_stage) == self.num_stages

        assert len(self.num_blocks_per_stage) == len(previous.num_blocks_per_stage)

        self.stage_output_features = previous_stage_output_features
        self.stage_pool_kernel_size = previous_stage_pool_kernel_size[::-1]
        self.stage_conv_kernel_size = previous_stage_conv_kernel_size[::-1]

        self.stages = []

        number_half_layer = self.num_stages + 1

        for stage in range(self.num_stages):
            # The exponent simplifies to num_stages - stage, so the filter count shrinks from
            # expansion_rate ** num_stages * base_num_filter down to expansion_rate * base_num_filter;
            # cast to int because np.round returns a float, not a valid channel count.
            current_out_channels = int(np.round(
                (expansion_rate ** (2 * number_half_layer - (stage + number_half_layer) - 1)) * self.base_num_filter))
            current_num_blocks_per_stage = self.num_blocks_per_stage[stage]
            current_pool_kernel_size = self.stage_pool_kernel_size[stage]
            current_kernel_size = self.stage_conv_kernel_size[stage]
            self.stages.append(
                ModularConvLayers(output_channels=current_out_channels, kernel_size=current_kernel_size,
                                  padding=padding, pool_size=current_pool_kernel_size, conv_type=conv_type,
                                  norm_type=norm_type, dropout_type=dropout_type, dropout_rate=dropout_rate,
                                  num_conv_layers=current_num_blocks_per_stage, nonlin=self.nonlin, upsampling=True))

        self.stages = nn.ModuleList(self.stages)

    def forward(self, x):
        """Forward inputs through the layer

        :param x: the input to forward.
        :return: the input forwarded through the layer.
        """
        for i, stage in enumerate(self.stages):
            # skips[0] is the bottleneck output (used as the initial x); skips[i + 1] is the matching encoder skip.
            x = stage(x, self.skips[i + 1])
        return x

    def set_skips(self, skips):
        self.skips = skips
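
# Wiring note (see modular_hdunet.forward below): the decoder does not receive the skip connections through its
# forward call. They are provided beforehand via set_skips() as the encoder output list with its last element
# replaced by the bottleneck output and the whole list reversed, so that skips[0] is the bottleneck output and
# skips[i + 1] is the encoder skip matching decoder stage i.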
|


class modular_hdunet(Module):
    """HDUnet model with modular parameters
    """

    def __init__(self, base_num_filter, num_blocks_per_stage_encoder, num_stages,
                 pool_kernel_sizes, conv_kernel_sizes, conv_bottleneck_kernel_sizes, num_blocks_per_stage_decoder=None,
                 padding='same', num_steps_bottleneck=4, conv_type: Module = LazyConvDropoutNormNonlinCat,
                 bottleneck_conv_type: Module = LazyConvBottleneckLayer, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, pooling_type: Module = MaxPool3d,
                 pooling_kernel_size=(2, 2, 2), nonlin: Module = ReLU):
        """Object creation

        :param base_num_filter: base number of filters (output channels).
        :param num_blocks_per_stage_encoder: number of convolutional blocks per stage for the encoder
            (can be different for each stage).
        :param num_stages: number of stages.
        :param pool_kernel_sizes: the last convolutional layer of each encoder stage is strided, so this parameter
            sets its kernel size and stride (can be different for each stage).
        :param conv_kernel_sizes: kernel size for the encoder and decoder (can be different for each stage).
        :param conv_bottleneck_kernel_sizes: kernel size for the bottleneck (can be different for each step).
        :param num_blocks_per_stage_decoder: number of convolutional blocks per stage for the decoder
            (can be different for each stage). Default is None (it will be the same as the encoder).
        :param padding: padding used, default is 'same'.
        :param num_steps_bottleneck: number of steps in the bottleneck, default is 4.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function;
            - concatenation.
            Must be a torch Module (should be a custom Module).
        :param bottleneck_conv_type: type of convolution used in the bottleneck, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
        :param pooling_type: type of pooling used, default is 3D max pooling. Must be a torch Module.
        :param pooling_kernel_size: kernel size of the pooling layer, default is (2, 2, 2).
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        """
        super(modular_hdunet, self).__init__()
        self.nonlin = nonlin
        self.encoder = modular_hdunet_encoder(base_num_filter=base_num_filter,
                                              num_blocks_per_stage=num_blocks_per_stage_encoder, num_stages=num_stages,
                                              pool_kernel_sizes=pool_kernel_sizes, conv_kernel_sizes=conv_kernel_sizes,
                                              padding=padding, conv_type=conv_type, norm_type=norm_type,
                                              dropout_type=dropout_type, dropout_rate=dropout_rate,
                                              expansion_rate=expansion_rate, pooling_type=pooling_type,
                                              pooling_kernel_size=pooling_kernel_size, nonlin=self.nonlin)

        self.bottleNeck = modular_hdunet_bottleneck(base_num_filter=base_num_filter, num_stages=num_stages,
                                                    conv_kernel_sizes=conv_bottleneck_kernel_sizes, padding=padding,
                                                    num_steps_bottleneck=num_steps_bottleneck,
                                                    conv_type=bottleneck_conv_type, norm_type=norm_type,
                                                    dropout_type=dropout_type, dropout_rate=dropout_rate,
                                                    expansion_rate=expansion_rate, nonlin=self.nonlin)

        self.decoder = modular_hdunet_decoder(previous=self.encoder, base_num_filter=base_num_filter,
                                              num_blocks_per_stage=num_blocks_per_stage_decoder, padding=padding,
                                              conv_type=conv_type, norm_type=norm_type, dropout_type=dropout_type,
                                              dropout_rate=dropout_rate, expansion_rate=expansion_rate,
                                              nonlin=self.nonlin)

        # Final projection of the decoder features to a single output channel.
        self.last_block = nn.Sequential(
            LazyConv3d(out_channels=1, kernel_size=(3, 3, 3), padding='same'),
            nonlin()
        )

    def forward(self, x):
        """Forward inputs through the layer
        (using the forward functions of the encoder/bottleneck/decoder)

        :param x: the input to forward.
        :return: the input forwarded through the layer.
        """
        skips = self.encoder(x)
        tmp = self.bottleNeck(skips[-1])

        # Replace the encoder's final output with the bottleneck output...
        skips = skips[:-1]
        skips.append(tmp)

        # ...and reverse the list so that skips[0] feeds the decoder while the remaining entries are the matching
        # skip connections, from the deepest stage to the shallowest.
        skips = skips[::-1]
        self.decoder.set_skips(skips)
        x = skips[0]
        x = self.decoder(x)
        return self.last_block(x)
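

if __name__ == '__main__':
    # Minimal smoke-test sketch with hypothetical hyperparameters; it assumes the layers imported from
    # custom_modules accept the arguments used above and support 'same' padding end to end.
    model = modular_hdunet(base_num_filter=8,
                           num_blocks_per_stage_encoder=2,
                           num_stages=3,
                           pool_kernel_sizes=[(2, 2, 2)] * 3,
                           conv_kernel_sizes=[(3, 3, 3)] * 3,
                           conv_bottleneck_kernel_sizes=[(3, 3, 3)] * 4,
                           expansion_rate=2)
    # Lazy modules infer their input channels on the first call, so a dummy forward pass materialises the model.
    dummy = torch.zeros(1, 1, 32, 32, 32)  # (batch, channels, depth, height, width)
    prediction = model(dummy)
    # With 'same' padding and matching down/upsampling factors the prediction should keep the input spatial size
    # and have a single channel, i.e. torch.Size([1, 1, 32, 32, 32]).
    print(prediction.shape)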