import functools

import tensorflow as tf
from tensorflow.keras import layers

from .others import MlpBlock

Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")


def CALayer(
    num_channels: int,
    reduction: int = 4,
    use_bias: bool = True,
    name: str = "channel_attention",
):
    """Squeeze-and-excitation block for channel attention.

    ref: https://arxiv.org/abs/1709.01507
    """

    def apply(x):
        # 2D global average pooling
        y = layers.GlobalAvgPool2D(keepdims=True)(x)
        # Squeeze (in Squeeze-Excitation)
        y = Conv1x1(
            filters=num_channels // reduction, use_bias=use_bias, name=f"{name}_Conv_0"
        )(y)
        y = tf.nn.relu(y)
        # Excitation (in Squeeze-Excitation)
        y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)
        y = tf.nn.sigmoid(y)
        return x * y

    return apply
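
# Usage sketch for CALayer (illustrative only; the shapes below are assumptions,
# not taken from this module):
#
#   feat = tf.random.normal((1, 64, 64, 32))
#   out = CALayer(num_channels=32, reduction=4, name="ca_demo")(feat)
#   # `out` keeps the shape (1, 64, 64, 32); each channel is rescaled by a learned
#   # weight in (0, 1) computed from the globally averaged channel descriptor.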


def RCAB(
    num_channels: int,
    reduction: int = 4,
    lrelu_slope: float = 0.2,
    use_bias: bool = True,
    name: str = "residual_ca",
):
    """Residual channel attention block. Contains LN,Conv,lRelu,Conv,SELayer."""

    def apply(x):
        shortcut = x
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv1")(x)
        x = tf.nn.leaky_relu(x, alpha=lrelu_slope)
        x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv2")(x)
        x = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )(x)
        return x + shortcut

    return apply
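
# Usage sketch for RCAB (illustrative only): the input must already have
# `num_channels` channels so the residual addition with the shortcut is valid.
#
#   feat = tf.random.normal((1, 64, 64, 32))
#   out = RCAB(num_channels=32, reduction=4, name="rcab_demo")(feat)
#   # `out` has the same shape as `feat`: LayerNorm -> Conv3x3 -> LeakyReLU ->
#   # Conv3x3 -> channel attention, plus the identity shortcut.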


def RDCAB(
    num_channels: int,
    reduction: int = 16,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "rdcab",
):
    """Residual dense channel attention block. Used in Bottlenecks."""

    def apply(x):
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        y = MlpBlock(
            mlp_dim=num_channels,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_channel_mixing",
        )(y)
        y = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )(y)
        x = x + y
        return x

    return apply
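
# Usage sketch for RDCAB (illustrative only). It assumes `MlpBlock` (imported from
# .others) projects its input back to the original channel count, so the residual
# addition is shape-compatible:
#
#   feat = tf.random.normal((1, 64, 64, 32))
#   out = RDCAB(num_channels=32, reduction=16, name="rdcab_demo")(feat)
#   # `out` has the same shape as `feat`: channel mixing (MLP) followed by
#   # channel attention, added back onto the input.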


def SAM(
    num_channels: int,
    output_channels: int = 3,
    use_bias: bool = True,
    name: str = "sam",
):
    """Supervised attention module for multi-stage training.

    Introduced by MPRNet [CVPR2021]: https://github.com/swz30/MPRNet
    """

    def apply(x, x_image):
        """Apply the SAM module to the input and num_channels.
        Args:
          x: the output num_channels from UNet decoder with shape (h, w, c)
          x_image: the input image with shape (h, w, 3)
        Returns:
          A tuple of tensors (x1, image) where (x1) is the sam num_channels used for the
            next stage, and (image) is the output restored image at current stage.
        """
        # Get num_channels
        x1 = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)

        # Output restored image X_s
        if output_channels == 3:
            image = (
                Conv3x3(
                    filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
                )(x)
                + x_image
            )
        else:
            image = Conv3x3(
                filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
            )(x)

        # Get attention maps for features
        x2 = tf.nn.sigmoid(
            Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_2")(image)
        )

        # Get attended feature maps
        x1 = x1 * x2

        # Residual connection
        x1 = x1 + x
        return x1, image

    return apply
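

# Minimal shape-check sketch (not part of the original module). Because of the
# relative import of MlpBlock above, it only runs when this file is executed
# within its package (e.g. `python -m <package>.<module>`); the sizes used here
# are illustrative assumptions.
if __name__ == "__main__":
    feat = tf.random.normal((1, 64, 64, 32))  # decoder features (h, w, c)
    img = tf.random.normal((1, 64, 64, 3))    # input RGB image (h, w, 3)

    print("CALayer:", CALayer(num_channels=32)(feat).shape)  # (1, 64, 64, 32)
    print("RCAB:   ", RCAB(num_channels=32)(feat).shape)     # (1, 64, 64, 32)
    print("RDCAB:  ", RDCAB(num_channels=32)(feat).shape)    # (1, 64, 64, 32)

    sam_feat, restored = SAM(num_channels=32, output_channels=3)(feat, img)
    print("SAM:    ", sam_feat.shape, restored.shape)         # (1, 64, 64, 32) (1, 64, 64, 3)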