import math import warnings import torch import torch.nn as nn from mmcv.cnn import ConvModule from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer, constant_init, normal_init, trunc_normal_init) from mmcv.cnn.bricks.drop import build_dropout from mmcv.cnn.bricks.transformer import MultiheadAttention from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint from .transformer_helper import PatchEmbed, nchw_to_nlc, nlc_to_nchw, resize, \ get_root_logger, BaseDecodeHead, HEADS, BACKBONES class MixFFN(BaseModule): """An implementation of MixFFN of Segformer. The differences between MixFFN & FFN: 1. Use 1X1 Conv to replace Linear layer. 2. Introduce 3X3 Conv to encode positional information. Args: embed_dims (int): The feature dimension. Same as `MultiheadAttention`. Defaults: 256. feedforward_channels (int): The hidden dimension of FFNs. Defaults: 1024. act_cfg (dict, optional): The activation config for FFNs. Default: dict(type='ReLU') ffn_drop (float, optional): Probability of an element to be zeroed in FFN. Default 0.0. dropout_layer (obj:`ConfigDict`): The dropout_layer used when adding the shortcut. init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. Default: None. """ def __init__(self, embed_dims, feedforward_channels, act_cfg=dict(type='GELU'), ffn_drop=0., dropout_layer=None, init_cfg=None): super(MixFFN, self).__init__(init_cfg) self.embed_dims = embed_dims self.feedforward_channels = feedforward_channels self.act_cfg = act_cfg self.activate = build_activation_layer(act_cfg) in_channels = embed_dims fc1 = Conv2d( in_channels=in_channels, out_channels=feedforward_channels, kernel_size=1, stride=1, bias=True) # 3x3 depth wise conv to provide positional encode information pe_conv = Conv2d( in_channels=feedforward_channels, out_channels=feedforward_channels, kernel_size=3, stride=1, padding=(3 - 1) // 2, bias=True, groups=feedforward_channels) fc2 = Conv2d( in_channels=feedforward_channels, out_channels=in_channels, kernel_size=1, stride=1, bias=True) drop = nn.Dropout(ffn_drop) layers = [fc1, pe_conv, self.activate, drop, fc2, drop] self.layers = Sequential(*layers) self.dropout_layer = build_dropout( dropout_layer) if dropout_layer else torch.nn.Identity() def forward(self, x, hw_shape, identity=None): out = nlc_to_nchw(x, hw_shape) out = self.layers(out) out = nchw_to_nlc(out) if identity is None: identity = x return identity + self.dropout_layer(out) class EfficientMultiheadAttention(MultiheadAttention): """An implementation of Efficient Multi-head Attention of Segformer. This module is modified from MultiheadAttention which is a module from mmcv.cnn.bricks.transformer. Args: embed_dims (int): The embedding dimension. num_heads (int): Parallel attention heads. attn_drop (float): A Dropout layer on attn_output_weights. Default: 0.0. proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. Default: 0.0. dropout_layer (obj:`ConfigDict`): The dropout_layer used when adding the shortcut. Default: None. init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. Default: None. batch_first (bool): Key, Query and Value are shape of (batch, n, embed_dim) or (n, batch, embed_dim). Default: False. qkv_bias (bool): enable bias for qkv if True. Default True. norm_cfg (dict): Config dict for normalization layer. Default: dict(type='LN'). sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head Attention of Segformer. Default: 1. """ def __init__(self, embed_dims, num_heads, attn_drop=0., proj_drop=0., dropout_layer=None, init_cfg=None, batch_first=True, qkv_bias=False, norm_cfg=dict(type='LN'), sr_ratio=1): super().__init__( embed_dims, num_heads, attn_drop, proj_drop, dropout_layer=dropout_layer, init_cfg=init_cfg, batch_first=batch_first, bias=qkv_bias) self.sr_ratio = sr_ratio if sr_ratio > 1: self.sr = Conv2d( in_channels=embed_dims, out_channels=embed_dims, kernel_size=sr_ratio, stride=sr_ratio) # The ret[0] of build_norm_layer is norm name. self.norm = build_norm_layer(norm_cfg, embed_dims)[1] def forward(self, x, hw_shape, identity=None): x_q = x if self.sr_ratio > 1: x_kv = nlc_to_nchw(x, hw_shape) x_kv = self.sr(x_kv) x_kv = nchw_to_nlc(x_kv) x_kv = self.norm(x_kv) else: x_kv = x if identity is None: identity = x_q # `need_weights=True` will let nn.MultiHeadAttention # `return attn_output, attn_output_weights.sum(dim=1) / num_heads` # The `attn_output_weights.sum(dim=1)` may cause cuda error. So, we set # `need_weights=False` to ignore `attn_output_weights.sum(dim=1)`. # This issue - `https://github.com/pytorch/pytorch/issues/37583` report # the error that large scale tensor sum operation may cause cuda error. out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0] return identity + self.dropout_layer(self.proj_drop(out)) class TransformerEncoderLayer(BaseModule): """Implements one encoder layer in Segformer. Args: embed_dims (int): The feature dimension. num_heads (int): Parallel attention heads. feedforward_channels (int): The hidden dimension for FFNs. drop_rate (float): Probability of an element to be zeroed. after the feed forward layer. Default 0.0. attn_drop_rate (float): The drop out rate for attention layer. Default 0.0. drop_path_rate (float): stochastic depth rate. Default 0.0. qkv_bias (bool): enable bias for qkv if True. Default: True. act_cfg (dict): The activation config for FFNs. Defalut: dict(type='GELU'). norm_cfg (dict): Config dict for normalization layer. Default: dict(type='LN'). batch_first (bool): Key, Query and Value are shape of (batch, n, embed_dim) or (n, batch, embed_dim). Default: False. init_cfg (dict, optional): Initialization config dict. Default:None. sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head Attention of Segformer. Default: 1. """ def __init__(self, embed_dims, num_heads, feedforward_channels, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., qkv_bias=True, act_cfg=dict(type='GELU'), norm_cfg=dict(type='LN'), batch_first=True, sr_ratio=1): super(TransformerEncoderLayer, self).__init__() # The ret[0] of build_norm_layer is norm name. self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] self.attn = EfficientMultiheadAttention( embed_dims=embed_dims, num_heads=num_heads, attn_drop=attn_drop_rate, proj_drop=drop_rate, dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), batch_first=batch_first, qkv_bias=qkv_bias, norm_cfg=norm_cfg, sr_ratio=sr_ratio) # The ret[0] of build_norm_layer is norm name. self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] self.ffn = MixFFN( embed_dims=embed_dims, feedforward_channels=feedforward_channels, ffn_drop=drop_rate, dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), act_cfg=act_cfg) def forward(self, x, hw_shape): x = self.attn(self.norm1(x), hw_shape, identity=x) x = self.ffn(self.norm2(x), hw_shape, identity=x) return x @BACKBONES.register_module() class MixVisionTransformer(BaseModule): """The backbone of Segformer. A PyTorch implement of : `SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers` - https://arxiv.org/pdf/2105.15203.pdf Args: in_channels (int): Number of input channels. Default: 3. embed_dims (int): Embedding dimension. Default: 768. num_stags (int): The num of stages. Default: 4. num_layers (Sequence[int]): The layer number of each transformer encode layer. Default: [3, 4, 6, 3]. num_heads (Sequence[int]): The attention heads of each transformer encode layer. Default: [1, 2, 4, 8]. patch_sizes (Sequence[int]): The patch_size of each overlapped patch embedding. Default: [7, 3, 3, 3]. strides (Sequence[int]): The stride of each overlapped patch embedding. Default: [4, 2, 2, 2]. sr_ratios (Sequence[int]): The spatial reduction rate of each transformer encode layer. Default: [8, 4, 2, 1]. out_indices (Sequence[int] | int): Output from which stages. Default: (0, 1, 2, 3). mlp_ratio (int): ratio of mlp hidden dim to embedding dim. Default: 4. qkv_bias (bool): Enable bias for qkv if True. Default: True. drop_rate (float): Probability of an element to be zeroed. Default 0.0 attn_drop_rate (float): The drop out rate for attention layer. Default 0.0 drop_path_rate (float): stochastic depth rate. Default 0.0 norm_cfg (dict): Config dict for normalization layer. Default: dict(type='LN') act_cfg (dict): The activation config for FFNs. Defalut: dict(type='GELU'). pretrain_style (str): Choose to use official or mmcls pretrain weights. Default: official. pretrained (str, optional): model pretrained path. Default: None. init_cfg (dict or list[dict], optional): Initialization config dict. Default: None. """ def __init__(self, in_channels=64, embed_dims=64, num_stages=4, num_layers=[3, 4, 6, 3], num_heads=[1, 2, 4, 8], patch_sizes=[7, 3, 3, 3], strides=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], out_indices=(0, 1, 2, 3), mlp_ratio=4, qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., act_cfg=dict(type='GELU'), norm_cfg=dict(type='LN', eps=1e-6), pretrain_style='official', pretrained=None, init_cfg=None): super().__init__() assert pretrain_style in [ 'official', 'mmcls' ], 'we only support official weights or mmcls weights.' if isinstance(pretrained, str) or pretrained is None: warnings.warn('DeprecationWarning: pretrained is a deprecated, ' 'please use "init_cfg" instead') else: raise TypeError('pretrained must be a str or None') self.embed_dims = embed_dims self.num_stages = num_stages self.num_layers = num_layers self.num_heads = num_heads self.patch_sizes = patch_sizes self.strides = strides self.sr_ratios = sr_ratios assert num_stages == len(num_layers) == len(num_heads) \ == len(patch_sizes) == len(strides) == len(sr_ratios) self.out_indices = out_indices assert max(out_indices) < self.num_stages self.pretrain_style = pretrain_style self.pretrained = pretrained self.init_cfg = init_cfg # transformer encoder dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers)) ] # stochastic num_layer decay rule cur = 0 self.layers = ModuleList() for i, num_layer in enumerate(num_layers): embed_dims_i = embed_dims * num_heads[i] patch_embed = PatchEmbed( in_channels=in_channels, embed_dims=embed_dims_i, kernel_size=patch_sizes[i], stride=strides[i], padding=patch_sizes[i] // 2, pad_to_patch_size=False, norm_cfg=norm_cfg) layer = ModuleList([ TransformerEncoderLayer( embed_dims=embed_dims_i, num_heads=num_heads[i], feedforward_channels=mlp_ratio * embed_dims_i, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=dpr[cur + idx], qkv_bias=qkv_bias, act_cfg=act_cfg, norm_cfg=norm_cfg, sr_ratio=sr_ratios[i]) for idx in range(num_layer) ]) in_channels = embed_dims_i # The ret[0] of build_norm_layer is norm name. norm = build_norm_layer(norm_cfg, embed_dims_i)[1] self.layers.append(ModuleList([patch_embed, layer, norm])) cur += num_layer def init_weights(self): if self.pretrained is None: for m in self.modules(): if isinstance(m, nn.Linear): trunc_normal_init(m.weight, std=.02) if m.bias is not None: constant_init(m.bias, 0) elif isinstance(m, nn.LayerNorm): constant_init(m.bias, 0) constant_init(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[ 1] * m.out_channels fan_out //= m.groups normal_init(m.weight, 0, math.sqrt(2.0 / fan_out)) if m.bias is not None: constant_init(m.bias, 0) elif isinstance(self.pretrained, str): logger = get_root_logger() checkpoint = _load_checkpoint( self.pretrained, logger=logger, map_location='cpu') if 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] else: state_dict = checkpoint # only use this code if when adopt v3 ori_proj_weight = state_dict['layers.0.0.projection.weight'] state_dict['layers.0.0.projection.weight'] = torch.cat([ori_proj_weight, ori_proj_weight], dim=1) self.load_state_dict(state_dict, True) def forward(self, x, additional_features=None): outs = [] for i, layer in enumerate(self.layers): x, H, W = layer[0](x), layer[0].DH, layer[0].DW hw_shape = (H, W) for block in layer[1]: x = block(x, hw_shape) x = layer[2](x) x = nlc_to_nchw(x, hw_shape) if i in self.out_indices: outs.append(x) return outs @HEADS.register_module() class SegformerHead(BaseDecodeHead): """The all mlp Head of segformer. This head is the implementation of `Segformer ` _. Args: interpolate_mode: The interpolate mode of MLP head upsample operation. Default: 'bilinear'. """ def __init__(self, interpolate_mode='bilinear', **kwargs): super().__init__(input_transform='multiple_select', **kwargs) self.interpolate_mode = interpolate_mode num_inputs = len(self.in_channels) assert num_inputs == len(self.in_index) self.convs = nn.ModuleList() for i in range(num_inputs): self.convs.append( ConvModule( in_channels=self.in_channels[i], out_channels=self.channels, kernel_size=1, stride=1, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) self.fusion_conv = ConvModule( in_channels=self.channels * num_inputs, out_channels=self.channels, kernel_size=1, norm_cfg=self.norm_cfg) def forward(self, inputs): # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 inputs = self._transform_inputs(inputs) outs = [] for idx in range(len(inputs)): x = inputs[idx] conv = self.convs[idx] outs.append( resize( input=conv(x), size=inputs[0].shape[2:], mode=self.interpolate_mode, align_corners=self.align_corners)) out = self.fusion_conv(torch.cat(outs, dim=1)) out = self.cls_seg(out) return out