# Dose_prediction_han_PBS / modular_hdunet.py
# (Hugging Face file-page metadata: Margerie's picture | Upload 2 files | 66427d3 | raw | history blame | 21.5 kB)
import numpy as np
import torch
import torch.nn as nn
from torch.nn import LazyConv3d, MaxPool3d, BatchNorm3d
from torch.nn.modules import Module
from torch.nn.modules import ReLU
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.instancenorm import InstanceNorm3d
from custom_modules import LazyConvDropoutNormNonlinCat, ModularConvLayers, LazyConvBottleneckLayer
class modular_hdunet_encoder(Module):
    """HDUnet encoder with modular parameters.

    Every stage runs a stack of convolutional blocks (``ModularConvLayers``) and then down-samples the result in
    two parallel branches: a pooling layer, and a strided convolution followed by the nonlinearity and a
    ``BatchNorm3d``. Both branches are concatenated along dim 1 to form the input of the next stage.
    """

    def __init__(self, base_num_filter, num_blocks_per_stage, num_stages, pool_kernel_sizes, conv_kernel_sizes,
                 padding='same', conv_type: Module = LazyConvDropoutNormNonlinCat, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, pooling_type: Module = MaxPool3d,
                 pooling_kernel_size=(2, 2, 2), nonlin: Module = ReLU):
        """Object creation.

        :param base_num_filter: base number of filters (output channels). Stage ``s`` uses
            ``round(expansion_rate ** s * base_num_filter)`` filters.
        :param num_blocks_per_stage: number of convolutional blocks per stage: either a single value used for every
            stage, or a list/tuple with one value per stage.
        :param num_stages: number of stages.
        :param pool_kernel_sizes: one kernel size per stage. It is used as both kernel size and stride of the strided
            convolution that ends the stage, and as the kernel size of the parallel pooling layer.
            Please note that this parameter is retrieved in our modular decoder and used as the scale factor
            (upsampling).
        :param conv_kernel_sizes: kernel size of the convolutional blocks (one per stage).
        :param padding: padding used, default is 'same'.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function;
            - concatenation.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
        :param pooling_type: type of pooling used, default is 3D max pooling. Must be a torch Module.
        :param pooling_kernel_size: only stored as ``self.pooling_kernel_size``; the pooling layers actually use
            ``pool_kernel_sizes[stage]``. Default is (2, 2, 2).
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        """
        super(modular_hdunet_encoder, self).__init__()
        self.base_num_filter = base_num_filter
        self.num_blocks_per_stage = num_blocks_per_stage
        self.num_stages = num_stages
        self.pool_kernel_sizes = pool_kernel_sizes
        self.conv_kernel_sizes = conv_kernel_sizes
        self.padding = padding
        self.conv_type = conv_type
        self.norm_type = norm_type
        self.dropout_type = dropout_type
        self.dropout_rate = dropout_rate
        self.nonlin = nonlin
        self.expansion_rate = expansion_rate
        self.pooling_type = pooling_type
        self.pooling_kernel_size = pooling_kernel_size
        self.stages = []
        self.pooling_stages = []
        self.end_stages = []
        # Per-stage architectural information, read back by modular_hdunet_decoder.
        self.stage_output_features = []
        self.stage_pool_kernel_size = []
        self.stage_conv_kernel_size = []
        assert len(pool_kernel_sizes) == len(conv_kernel_sizes) == num_stages
        if not isinstance(num_blocks_per_stage, (list, tuple)):
            num_blocks_per_stage = [num_blocks_per_stage] * num_stages
        else:
            assert len(num_blocks_per_stage) == num_stages
        self.num_blocks_per_stage = num_blocks_per_stage
        current_out_channels = 0
        # This is where we manage the number of steps
        for stage in range(num_stages):
            # Must be a Python int: LazyConv3d / BatchNorm3d cannot allocate their parameters from the
            # np.float64 that np.round returns when expansion_rate is a float (e.g. 2.0 or 1.5).
            current_out_channels = int(np.round((expansion_rate ** stage) * self.base_num_filter))
            current_num_blocks_per_stage = num_blocks_per_stage[stage]
            current_pool_kernel_size = pool_kernel_sizes[stage]
            current_kernel_size = conv_kernel_sizes[stage]
            current_stage = ModularConvLayers(output_channels=current_out_channels,
                                              num_conv_layers=current_num_blocks_per_stage,
                                              kernel_size=current_kernel_size, padding=padding, conv_type=conv_type,
                                              norm_type=norm_type, dropout_type=dropout_type,
                                              dropout_rate=dropout_rate, nonlin=self.nonlin)
            self.pooling_stages.append(pooling_type(kernel_size=current_pool_kernel_size))
            # BatchNorm3d added statically here (to be similar to the original model)
            current_end_stage = nn.Sequential(
                LazyConv3d(out_channels=current_out_channels, kernel_size=current_pool_kernel_size,
                           stride=current_pool_kernel_size, padding=0),
                nonlin(),
                BatchNorm3d(current_out_channels),
            )
            self.stages.append(current_stage)
            self.end_stages.append(current_end_stage)
            self.stage_output_features.append(current_out_channels)
            self.stage_pool_kernel_size.append(current_pool_kernel_size)
            self.stage_conv_kernel_size.append(current_kernel_size)
        self.stages = nn.ModuleList(self.stages)
        self.pooling_stages = nn.ModuleList(self.pooling_stages)
        self.end_stages = nn.ModuleList(self.end_stages)
        # Number of filters of the last stage (0 when num_stages == 0).
        self.output_features = current_out_channels

    def forward(self, x):
        """Forward inputs through the layer.

        :param x: the input to forward.
        :return: a list holding the output of each stage before down-sampling/concatenation (used later as skip
            connections by the decoder), followed by the very last down-sampled/concatenated value (used by the
            bottleneck). With ``num_stages`` stages the list has ``num_stages + 1`` entries.
        """
        skips = []
        for stage, pooling, end_stage in zip(self.stages, self.pooling_stages, self.end_stages):
            x = stage(x)
            buff = pooling(x)
            tmp = end_stage(x)
            skips.append(x)
            # Strided-convolution branch first, pooling branch second, stacked along dim 1.
            x = torch.cat([tmp, buff], dim=1)
        skips.append(x)
        return skips
class modular_hdunet_bottleneck(Module):
    """HDUnet bottleneck with modular parameters.

    Applies ``num_steps_bottleneck`` convolutional layers in sequence. Each layer's output is concatenated, in
    front of its own input, along dim 1, and the result becomes the input of the next layer.
    """

    def __init__(self, base_num_filter, num_stages, conv_kernel_sizes, padding='same', num_steps_bottleneck=4,
                 conv_type: Module = LazyConvBottleneckLayer, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, nonlin: Module = ReLU):
        """Object creation.

        :param base_num_filter: base number of filters (output channels). Every bottleneck layer produces
            ``round(expansion_rate ** num_stages * base_num_filter)`` channels.
        :param num_stages: number of stages of the encoder.
        :param conv_kernel_sizes: kernel size of each bottleneck step (one per step).
        :param padding: padding used, default is 'same'.
        :param num_steps_bottleneck: number of steps in the bottleneck, default is 4.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        """
        super(modular_hdunet_bottleneck, self).__init__()
        self.base_num_filter = base_num_filter
        self.conv_kernel_sizes = conv_kernel_sizes
        self.padding = padding
        self.num_steps_bottleneck = num_steps_bottleneck
        self.conv_type = conv_type
        self.norm_type = norm_type
        self.dropout_type = dropout_type
        self.dropout_rate = dropout_rate
        self.expansion_rate = expansion_rate
        self.nonlin = nonlin
        # Rounded and cast like the encoder's channel counts: convolution layers cannot be built with a float
        # number of output channels (which the raw product is whenever expansion_rate is a float).
        encoder_output_features = int(np.round(expansion_rate ** num_stages * base_num_filter))
        self.stages = []
        self.step_conv_kernel_size = []
        assert len(conv_kernel_sizes) == num_steps_bottleneck
        # This is where we manage the number of steps
        for step in range(num_steps_bottleneck):
            current_kernel_size = conv_kernel_sizes[step]
            self.stages.append(
                conv_type(output_channels=encoder_output_features, kernel_size=current_kernel_size, padding=padding,
                          norm_type=norm_type, dropout_type=dropout_type,
                          dropout_rate=dropout_rate, nonlin=self.nonlin)
            )
            # Record the kernel size of each step, mirroring the encoder's stage_conv_kernel_size.
            self.step_conv_kernel_size.append(current_kernel_size)
        self.stages = nn.ModuleList(self.stages)

    def forward(self, x):
        """Forward inputs through the layer.

        :param x: the input to forward. At each step the input is concatenated with its result in order to produce
            the input of the next bottleneck layer.
        :return: the input forwarded through the layer.
        """
        for stage in self.stages:
            x = torch.cat([stage(x), x], dim=1)
        return x
class modular_hdunet_decoder(Module):
    """HDUnet decoder with modular parameters.

    Mirrors a ``modular_hdunet_encoder``. It has the same number of stages and uses the encoder's pool and
    convolution kernel sizes in reverse order; by default it also reuses the encoder's blocks per stage, reversed.
    Decoder stage ``i`` receives the running features and the skip connection ``skips[i + 1]``.
    """

    def __init__(self, previous, base_num_filter, num_blocks_per_stage=None, padding='same',
                 conv_type: Module = LazyConvDropoutNormNonlinCat, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, nonlin: Module = ReLU):
        """Object creation.

        :param previous: the encoder which was previously used in the model. Fixed architectural information is
            read from it, for example the number of stages or the kernel sizes of each stage.
        :param base_num_filter: base number of filters (output channels). Decoder stage ``i`` uses
            ``round(expansion_rate ** (num_stages - i) * base_num_filter)`` filters.
        :param num_blocks_per_stage: number of convolutional blocks per stage: a single value, or a list/tuple with
            one value per stage. If set to None, it will be the same as the encoder's (reversed).
        :param padding: padding used, default is 'same'.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function;
            - concatenation.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        """
        super(modular_hdunet_decoder, self).__init__()
        self.base_num_filter = base_num_filter
        self.num_blocks_per_stage = num_blocks_per_stage
        self.padding = padding
        self.conv_type = conv_type
        self.norm_type = norm_type
        self.dropout_type = dropout_type
        self.dropout_rate = dropout_rate
        self.expansion_rate = expansion_rate
        self.nonlin = nonlin
        # The skips can be provided through set_skips() because we use Lazy layers, and torchsummary does not let
        # us pass a list as a parameter of forward(). Note: tensors stored here stay referenced by the module
        # until replaced, and non-leaf tensors make copy.deepcopy of the module fail.
        self.skips = []
        # We retrieve the 'architectural' information that was provided to the encoder
        # in order to have a consistent decoder
        previous_stages = previous.stages
        previous_stage_output_features = previous.stage_output_features
        previous_stage_pool_kernel_size = previous.stage_pool_kernel_size
        previous_stage_conv_kernel_size = previous.stage_conv_kernel_size
        # Same number of stages as the encoder, since the bottleneck is handled separately.
        self.num_stages = len(previous_stages)
        # If num_blocks_per_stage is None, it will be the same as the encoder's (reversed).
        if num_blocks_per_stage is None:
            self.num_blocks_per_stage = previous.num_blocks_per_stage[::-1]
        if not isinstance(self.num_blocks_per_stage, (list, tuple)):
            self.num_blocks_per_stage = [self.num_blocks_per_stage] * self.num_stages
        else:
            assert len(self.num_blocks_per_stage) == self.num_stages
        # There should be the same number of stages since we are doing the bottleneck and the encoder parts separately
        assert len(self.num_blocks_per_stage) == len(previous.num_blocks_per_stage)
        self.stage_output_features = previous_stage_output_features
        self.stage_pool_kernel_size = previous_stage_pool_kernel_size[::-1]
        self.stage_conv_kernel_size = previous_stage_conv_kernel_size[::-1]
        self.stages = []
        # This is where we manage the number of steps
        for stage in range(self.num_stages):
            # Exponent num_stages - stage: the first decoder stage matches the bottleneck's expansion level.
            # (Same value as the former 2 * (num_stages + 1) - (stage + num_stages + 1) - 1.)
            # Cast to a Python int so the convolution layers can be built when expansion_rate is a float.
            current_out_channels = int(np.round(
                (expansion_rate ** (self.num_stages - stage)) * self.base_num_filter))
            current_num_blocks_per_stage = self.num_blocks_per_stage[stage]
            current_pool_kernel_size = self.stage_pool_kernel_size[stage]
            current_kernel_size = self.stage_conv_kernel_size[stage]
            self.stages.append(
                ModularConvLayers(output_channels=current_out_channels, kernel_size=current_kernel_size,
                                  padding=padding, pool_size=current_pool_kernel_size, conv_type=conv_type,
                                  norm_type=norm_type, dropout_type=dropout_type, dropout_rate=dropout_rate,
                                  num_conv_layers=current_num_blocks_per_stage, nonlin=self.nonlin, upsampling=True))
        self.stages = nn.ModuleList(self.stages)

    def forward(self, x, skips=None):
        """Forward inputs through the layer.

        :param x: the input to forward.
        :param skips: optional list of skip connections. Entry ``i + 1`` is fed to decoder stage ``i``; entry 0 is
            not read here. When None (default), the list given to ``set_skips`` is used.
        :return: the input forwarded through the layer.
        """
        if skips is None:
            skips = self.skips
        for i, stage in enumerate(self.stages):
            x = stage(x, skips[i + 1])
        return x

    def set_skips(self, skips):
        """Store the skip connections used by ``forward`` when it is called without ``skips``."""
        self.skips = skips
# We did our best to allow as much modularity as possible while keeping the configurable parameters meaningful.
# Nevertheless, we cannot guarantee that the model will work with every combination of parameters.
# If you change some parameters and the result is not what you expected, make sure you understand how the model
# works. If you want to change the type of convolutional layer used, we advise you to check how the existing ones
# have been implemented.
# "With great power comes great responsibility." - Uncle Ben.
class modular_hdunet(Module):
    """HDUnet model with modular parameters.

    Chains a ``modular_hdunet_encoder``, a ``modular_hdunet_bottleneck``, a ``modular_hdunet_decoder`` and a
    final single-output-channel 3x3x3 convolution followed by ``nonlin``.
    """

    def __init__(self, base_num_filter, num_blocks_per_stage_encoder, num_stages,
                 pool_kernel_sizes, conv_kernel_sizes, conv_bottleneck_kernel_sizes, num_blocks_per_stage_decoder=None,
                 padding='same', num_steps_bottleneck=4, conv_type: Module = LazyConvDropoutNormNonlinCat,
                 bottleneck_conv_type: Module = LazyConvBottleneckLayer, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, pooling_type: Module = MaxPool3d,
                 pooling_kernel_size=(2, 2, 2), nonlin: Module = ReLU):
        """Object creation.

        :param base_num_filter: base number of filters (output channels).
        :param num_blocks_per_stage_encoder: number of convolutional blocks per stage for the encoder
            (can be different for each stage).
        :param num_stages: number of stages.
        :param pool_kernel_sizes: the last convolutional layer of each encoder stage is strided; this parameter sets
            its kernel size and stride (can be different for each stage).
        :param conv_kernel_sizes: kernel size for the encoder and decoder (can be different for each stage).
        :param conv_bottleneck_kernel_sizes: kernel size for the bottleneck (one per bottleneck step).
        :param num_blocks_per_stage_decoder: number of convolutional blocks per stage for the decoder
            (can be different for each stage). Default is None (the encoder's values, reversed).
        :param padding: padding used, default is 'same'.
        :param num_steps_bottleneck: number of steps in the bottleneck, default is 4.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function;
            - concatenation.
            Must be a torch Module (should be a custom Module).
        :param bottleneck_conv_type: type of convolution used in the bottleneck, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
        :param pooling_type: type of pooling used, default is 3D max pooling. Must be a torch Module.
        :param pooling_kernel_size: forwarded to the encoder, default is (2, 2, 2).
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        """
        super(modular_hdunet, self).__init__()
        self.nonlin = nonlin
        self.encoder = modular_hdunet_encoder(base_num_filter=base_num_filter,
                                              num_blocks_per_stage=num_blocks_per_stage_encoder, num_stages=num_stages,
                                              pool_kernel_sizes=pool_kernel_sizes, conv_kernel_sizes=conv_kernel_sizes,
                                              padding=padding, conv_type=conv_type, norm_type=norm_type,
                                              dropout_type=dropout_type, dropout_rate=dropout_rate,
                                              expansion_rate=expansion_rate, pooling_type=pooling_type,
                                              pooling_kernel_size=pooling_kernel_size, nonlin=self.nonlin)
        self.bottleNeck = modular_hdunet_bottleneck(base_num_filter=base_num_filter, num_stages=num_stages,
                                                    conv_kernel_sizes=conv_bottleneck_kernel_sizes, padding=padding,
                                                    num_steps_bottleneck=num_steps_bottleneck,
                                                    conv_type=bottleneck_conv_type, norm_type=norm_type,
                                                    dropout_type=dropout_type, dropout_rate=dropout_rate,
                                                    expansion_rate=expansion_rate, nonlin=self.nonlin)
        self.decoder = modular_hdunet_decoder(previous=self.encoder, base_num_filter=base_num_filter,
                                              num_blocks_per_stage=num_blocks_per_stage_decoder, padding=padding,
                                              conv_type=conv_type, norm_type=norm_type, dropout_type=dropout_type,
                                              dropout_rate=dropout_rate, expansion_rate=expansion_rate,
                                              nonlin=self.nonlin)
        self.last_block = nn.Sequential(
            LazyConv3d(out_channels=1, kernel_size=(3, 3, 3), padding='same'),
            nonlin()
        )

    def forward(self, x):
        """Forward inputs through the encoder, the bottleneck, the decoder and the final block.

        :param x: the input to forward.
        :return: the output of ``last_block`` (single-output-channel convolution followed by ``nonlin``).
        """
        skips = self.encoder(x)
        # The last encoder output feeds the bottleneck, and the bottleneck result takes its place in the list
        # (the encoder builds a fresh list on every call, so updating it in place is safe).
        skips[-1] = self.bottleNeck(skips[-1])
        # The decoder consumes the deepest features first, so reverse the list: index 0 is the bottleneck output.
        skips = skips[::-1]
        self.decoder.set_skips(skips)
        try:
            x = self.decoder(skips[0])
        finally:
            # Release the module's references to this pass's activations. Otherwise they stay alive until the
            # next forward call, and copy.deepcopy(model) fails on the stored non-leaf tensors.
            self.decoder.set_skips([])
        return self.last_block(x)