import functools

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers

from .blocks.attentions import SAM
from .blocks.bottleneck import BottleneckBlock
from .blocks.misc_gating import CrossGatingBlock
from .blocks.others import UpSampleRatio
from .blocks.unet import UNetDecoderBlock, UNetEncoderBlock
from .layers import Resizing

Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
ConvT_up = functools.partial(
    layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
)
Conv_down = functools.partial(
    layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
)


def MAXIM(
    features: int = 64,
    depth: int = 3,
    num_stages: int = 2,
    num_groups: int = 1,
    use_bias: bool = True,
    num_supervision_scales: int = 1,
    lrelu_slope: float = 0.2,
    use_global_mlp: bool = True,
    use_cross_gating: bool = True,
    high_res_stages: int = 2,
    block_size_hr=(16, 16),
    block_size_lr=(8, 8),
    grid_size_hr=(16, 16),
    grid_size_lr=(8, 8),
    num_bottleneck_blocks: int = 1,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    num_outputs: int = 3,
    dropout_rate: float = 0.0,
):
    """The MAXIM model function with multi-stage and multi-scale supervision.

    For more model details, please check the CVPR paper:
    MAXIM: Multi-Axis MLP for Image Processing (https://arxiv.org/abs/2201.02973)

    Args:
      features: initial hidden dimension for the input resolution.
      depth: the number of downsampling levels in the model.
      num_stages: how many stages to use. It also affects the output list.
      num_groups: how many blocks each stage contains.
      use_bias: whether to use bias in all the conv/MLP layers.
      num_supervision_scales: the number of desired supervision scales.
      lrelu_slope: the negative slope parameter in leaky_relu layers.
      use_global_mlp: whether to use the multi-axis gated MLP block (MAB) in
        each layer.
      use_cross_gating: whether to use the cross-gating MLP block (CGB) in the
        skip connections and multi-stage feature fusion layers.
      high_res_stages: how many stages are specified as high-res stages. The
        rest (depth - high_res_stages) are called low_res_stages.
      block_size_hr: the block_size parameter for high-res stages.
      block_size_lr: the block_size parameter for low-res stages.
      grid_size_hr: the grid_size parameter for high-res stages.
      grid_size_lr: the grid_size parameter for low-res stages.
      num_bottleneck_blocks: how many bottleneck blocks to use.
      block_gmlp_factor: the input projection factor for block_gMLP layers.
      grid_gmlp_factor: the input projection factor for grid_gMLP layers.
      input_proj_factor: the input projection factor for the MAB block.
      channels_reduction: the channel reduction factor for the SE layer.
      num_outputs: the number of output channels.
      dropout_rate: dropout rate.

    Returns:
      A function that applies the model to an input tensor and returns a list
      of lists of multi-stage, multi-scale outputs. For example, if
      num_stages = num_supervision_scales = 3 (the model used in the paper),
      the output spec is:
        outputs = [
            [output_stage1_scale1, output_stage1_scale2, output_stage1_scale3],
            [output_stage2_scale1, output_stage2_scale2, output_stage2_scale3],
            [output_stage3_scale1, output_stage3_scale2, output_stage3_scale3],
        ]
      The final output can be retrieved by outputs[-1][-1].
    """

    def apply(x):
        n, h, w, c = K.int_shape(x)

        shortcuts = [x]
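        # Build progressively downsampled copies of the input, one per
        # supervision scale.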
        for i in range(1, num_supervision_scales):
            resizing_layer = Resizing(
                height=h // (2 ** i),
                width=w // (2 ** i),
                method="nearest",
                antialias=True,
                name=f"initial_resizing_{K.get_uid('Resizing')}",
            )
            shortcuts.append(resizing_layer(x))

        outputs_all = []
        sam_features, encs_prev, decs_prev = [], [], []
        for idx_stage in range(num_stages):
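            # Input convolution: get multi-scale input features.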
            x_scales = []
            for i in range(num_supervision_scales):
                x_scale = Conv3x3(
                    filters=(2 ** i) * features,
                    use_bias=use_bias,
                    name=f"stage_{idx_stage}_input_conv_{i}",
                )(shortcuts[i])
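                # For stages after the first, fuse the input features with the
                # SAM features propagated from the previous stage.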
                if idx_stage > 0:
                    # Use a larger block size at high-res scales.
                    if use_cross_gating:
                        block_size = (
                            block_size_hr if i < high_res_stages else block_size_lr
                        )
                        grid_size = (
                            grid_size_hr if i < high_res_stages else grid_size_lr
                        )
                        x_scale, _ = CrossGatingBlock(
                            features=(2 ** i) * features,
                            block_size=block_size,
                            grid_size=grid_size,
                            dropout_rate=dropout_rate,
                            input_proj_factor=input_proj_factor,
                            upsample_y=False,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_input_fuse_sam_{i}",
                        )(x_scale, sam_features.pop())
                    else:
                        x_scale = Conv1x1(
                            filters=(2 ** i) * features,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_input_catconv_{i}",
                        )(tf.concat([x_scale, sam_features.pop()], axis=-1))

                x_scales.append(x_scale)
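            # Encoder. The first `num_supervision_scales` levels also take the
            # corresponding multi-scale input features as skip inputs.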
            encs = []
            x = x_scales[0]
            for i in range(depth):
                # Use a larger block size at high-res levels, and vice versa.
                block_size = block_size_hr if i < high_res_stages else block_size_lr
                grid_size = grid_size_hr if i < high_res_stages else grid_size_lr
                use_cross_gating_layer = idx_stage > 0

                # Multi-scale input if multi-scale supervision is enabled.
                x_scale = x_scales[i] if i < num_supervision_scales else None

                # Encoder/decoder features from the previous stage, if any.
                enc_prev = encs_prev.pop() if idx_stage > 0 else None
                dec_prev = decs_prev.pop() if idx_stage > 0 else None
                x, bridge = UNetEncoderBlock(
                    num_channels=(2 ** i) * features,
                    num_groups=num_groups,
                    downsample=True,
                    lrelu_slope=lrelu_slope,
                    block_size=block_size,
                    grid_size=grid_size,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    channels_reduction=channels_reduction,
                    use_global_mlp=use_global_mlp,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    use_cross_gating=use_cross_gating_layer,
                    name=f"stage_{idx_stage}_encoder_block_{i}",
                )(x, skip=x_scale, enc=enc_prev, dec=dec_prev)

                encs.append(bridge)
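            # Global MLP bottleneck blocks at the coarsest resolution.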
            for i in range(num_bottleneck_blocks):
                x = BottleneckBlock(
                    block_size=block_size_lr,
                    grid_size=grid_size_lr,
                    features=(2 ** (depth - 1)) * features,
                    num_groups=num_groups,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    channels_reduction=channels_reduction,
                    name=f"stage_{idx_stage}_global_block_{i}",
                )(x)

            # Cache the global feature for cross-gating.
            global_feature = x
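            # Cross-gating: fuse encoder features across scales and modulate
            # them with the global feature, starting from the coarsest level.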
            skip_features = []
            for i in reversed(range(depth)):
                block_size = block_size_hr if i < high_res_stages else block_size_lr
                grid_size = grid_size_hr if i < high_res_stages else grid_size_lr
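                # Aggregate encoder features from every level, resized to the
                # current resolution.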
                signal = tf.concat(
                    [
                        UpSampleRatio(
                            num_channels=(2 ** i) * features,
                            ratio=2 ** (j - i),
                            use_bias=use_bias,
                            name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
                        )(enc)
                        for j, enc in enumerate(encs)
                    ],
                    axis=-1,
                )
                if use_cross_gating:
                    skips, global_feature = CrossGatingBlock(
                        features=(2 ** i) * features,
                        block_size=block_size,
                        grid_size=grid_size,
                        input_proj_factor=input_proj_factor,
                        dropout_rate=dropout_rate,
                        upsample_y=True,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_cross_gating_block_{i}",
                    )(signal, global_feature)
                else:
                    # Layer names must be unique across levels and stages.
                    skips = Conv1x1(
                        filters=(2 ** i) * features,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_cross_gating_conv1x1_{i}",
                    )(signal)
                    skips = Conv3x3(
                        filters=(2 ** i) * features,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_cross_gating_conv3x3_{i}",
                    )(skips)

                skip_features.append(skips)
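            # Decoder. Walk back up the resolution pyramid, fusing the
            # cross-gated skip features at each level.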
            outputs, decs, sam_features = [], [], []
            for i in reversed(range(depth)):
                block_size = block_size_hr if i < high_res_stages else block_size_lr
                grid_size = grid_size_hr if i < high_res_stages else grid_size_lr
                signal = tf.concat(
                    [
                        UpSampleRatio(
                            num_channels=(2 ** i) * features,
                            ratio=2 ** (depth - j - 1 - i),
                            use_bias=use_bias,
                            name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
                        )(skip)
                        for j, skip in enumerate(skip_features)
                    ],
                    axis=-1,
                )
                x = UNetDecoderBlock(
                    num_channels=(2 ** i) * features,
                    num_groups=num_groups,
                    lrelu_slope=lrelu_slope,
                    block_size=block_size,
                    grid_size=grid_size,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    channels_reduction=channels_reduction,
                    use_global_mlp=use_global_mlp,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    name=f"stage_{idx_stage}_decoder_block_{i}",
                )(x, bridge=signal)

                decs.append(x)
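                # Output head: intermediate stages use the supervised attention
                # module (SAM) so their features can be fed forward to the next
                # stage; the last stage uses a plain 3x3 conv plus a residual
                # connection to the (resized) input.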
                if i < num_supervision_scales:
                    if idx_stage < num_stages - 1:
                        sam, output = SAM(
                            num_channels=(2 ** i) * features,
                            output_channels=num_outputs,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_supervised_attention_module_{i}",
                        )(x, shortcuts[i])
                        outputs.append(output)
                        sam_features.append(sam)
                    else:
                        output = Conv3x3(
                            filters=num_outputs,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_output_conv_{i}",
                        )(x)
                        output = output + shortcuts[i]
                        outputs.append(output)
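            # Cache the encoder and decoder features for use by the next stage.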
            encs_prev = encs[::-1]
            decs_prev = decs

            outputs_all.append(outputs)

        return outputs_all

    return apply
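

# Usage sketch (not part of the library): build a Keras model from the `apply`
# function returned by MAXIM. It assumes this module sits inside its package so
# the relative `.blocks`/`.layers` imports resolve; the input size and the
# hyper-parameters below are illustrative, not an official configuration.
#
#     inputs = tf.keras.Input(shape=(256, 256, 3))
#     maxim_fn = MAXIM(features=32, depth=3, num_stages=2, num_supervision_scales=1)
#     model = tf.keras.Model(inputs, maxim_fn(inputs))
#     preds = model(tf.random.uniform((1, 256, 256, 3)))
#     final_output = preds[-1][-1]  # last stage, highest-resolution output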