GraCo / isegm /model /modeling /segformer.py
zhaoyian01's picture
Add application file
6d1366a
raw
history blame
18 kB
import math
import warnings
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
constant_init, normal_init, trunc_normal_init)
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import MultiheadAttention
from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint
from .transformer_helper import PatchEmbed, nchw_to_nlc, nlc_to_nchw, resize, \
get_root_logger, BaseDecodeHead, HEADS, BACKBONES
class MixFFN(BaseModule):
    """Convolutional feed-forward block used by Segformer.

    Compared with a plain Transformer FFN this block:
        1. Uses 1x1 convolutions in place of the Linear layers.
        2. Inserts a 3x3 depth-wise convolution between them, which the
           original authors use to encode positional information.

    Args:
        embed_dims (int): Channel count of the block input and output.
        feedforward_channels (int): Hidden channel count between the two
            1x1 convolutions.
        act_cfg (dict, optional): Config handed to
            `build_activation_layer`. Default: dict(type='GELU').
        ffn_drop (float, optional): Probability used by the `nn.Dropout`
            applied after the activation and after the last convolution.
            Default: 0.0.
        dropout_layer (dict, optional): Config handed to `build_dropout`
            for the layer applied to the branch output before it is added
            to the shortcut. `torch.nn.Identity` is used when this is
            falsy. Default: None.
        init_cfg (dict, optional): Initialization config forwarded to
            `BaseModule`. Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 dropout_layer=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        # Channel expansion: embed_dims -> feedforward_channels.
        expand = Conv2d(
            in_channels=embed_dims,
            out_channels=feedforward_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        # 3x3 depth-wise conv (one group per channel); padding of 1 keeps
        # the spatial size unchanged.
        depthwise = Conv2d(
            in_channels=feedforward_channels,
            out_channels=feedforward_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=feedforward_channels)
        # Channel reduction: feedforward_channels -> embed_dims.
        reduce = Conv2d(
            in_channels=feedforward_channels,
            out_channels=embed_dims,
            kernel_size=1,
            stride=1,
            bias=True)
        # The same Dropout instance is deliberately used at both positions.
        dropout = nn.Dropout(ffn_drop)

        self.layers = Sequential(expand, depthwise, self.activate, dropout,
                                 reduce, dropout)

        if dropout_layer:
            self.dropout_layer = build_dropout(dropout_layer)
        else:
            self.dropout_layer = torch.nn.Identity()

    def forward(self, x, hw_shape, identity=None):
        """Run the FFN branch and add it to the shortcut.

        Args:
            x: Token tensor accepted by `nlc_to_nchw` (presumably
                (batch, H*W, channels) -- confirm against the helper).
            hw_shape: Spatial size passed to `nlc_to_nchw` to rebuild the
                feature map.
            identity: Optional shortcut tensor; `x` is used when None.

        Returns:
            The shortcut plus the (optionally dropped) FFN output.
        """
        shortcut = x if identity is None else identity
        branch = nchw_to_nlc(self.layers(nlc_to_nchw(x, hw_shape)))
        return shortcut + self.dropout_layer(branch)
class EfficientMultiheadAttention(MultiheadAttention):
    """Efficient multi-head attention of Segformer.

    Extends `mmcv.cnn.bricks.transformer.MultiheadAttention` with an
    optional strided convolution that shrinks the key/value sequence
    before attention is computed.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Number of parallel attention heads.
        attn_drop (float): Forwarded to the parent as its attention
            dropout argument. Default: 0.0.
        proj_drop (float): Forwarded to the parent as its output dropout
            argument. Default: 0.0.
        dropout_layer (dict, optional): Dropout config forwarded to the
            parent; applied to the branch before the shortcut is added.
            Default: None.
        init_cfg (dict, optional): Initialization config forwarded to the
            parent. Default: None.
        batch_first (bool): Forwarded to the parent. Default: True.
        qkv_bias (bool): Forwarded to the parent as `bias`.
            Default: False.
        norm_cfg (dict): Config of the normalization applied to the
            reduced key/value tokens. Only used when `sr_ratio > 1`.
            Default: dict(type='LN').
        sr_ratio (int): Spatial reduction ratio; used both as kernel size
            and stride of the reduction convolution. Default: 1.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 init_cfg=None,
                 batch_first=True,
                 qkv_bias=False,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1):
        super().__init__(
            embed_dims,
            num_heads,
            attn_drop,
            proj_drop,
            dropout_layer=dropout_layer,
            init_cfg=init_cfg,
            batch_first=batch_first,
            bias=qkv_bias)

        self.sr_ratio = sr_ratio
        # The reduction conv and its norm only exist when a reduction is
        # actually requested.
        if sr_ratio > 1:
            self.sr = Conv2d(
                in_channels=embed_dims,
                out_channels=embed_dims,
                kernel_size=sr_ratio,
                stride=sr_ratio)
            # build_norm_layer returns (name, layer); keep the layer only.
            _, self.norm = build_norm_layer(norm_cfg, embed_dims)

    def forward(self, x, hw_shape, identity=None):
        """Attend from `x` to its (optionally reduced) self.

        Args:
            x: Token tensor used as the query.
            hw_shape: Spatial size passed to `nlc_to_nchw` when the
                key/value tokens are reduced.
            identity: Optional shortcut tensor; `x` is used when None.

        Returns:
            The shortcut plus the dropped attention output.
        """
        # NOTE(review): tensors are handed to `self.attn` without any
        # transpose regardless of `batch_first` -- confirm the layout the
        # wrapped attention module expects.
        if self.sr_ratio > 1:
            reduced = self.sr(nlc_to_nchw(x, hw_shape))
            kv = self.norm(nchw_to_nlc(reduced))
        else:
            kv = x

        shortcut = x if identity is None else identity

        # `need_weights=True` would make nn.MultiheadAttention also return
        # `attn_output_weights.sum(dim=1) / num_heads`. That sum over a
        # large tensor may trigger a CUDA error
        # (https://github.com/pytorch/pytorch/issues/37583), so the
        # weights are not requested.
        attn_out = self.attn(
            query=x, key=kv, value=kv, need_weights=False)[0]

        return shortcut + self.dropout_layer(self.proj_drop(attn_out))
class TransformerEncoderLayer(BaseModule):
    """One pre-norm encoder layer of Segformer.

    The layer applies `norm -> EfficientMultiheadAttention` followed by
    `norm -> MixFFN`; each sub-block adds its output to the un-normalized
    input it received.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Number of parallel attention heads.
        feedforward_channels (int): Hidden dimension of the MixFFN.
        drop_rate (float): Dropout probability used for the attention
            output projection and inside the FFN. Default: 0.0.
        attn_drop_rate (float): Dropout rate of the attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate used by both
            sub-blocks. Default: 0.0.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        act_cfg (dict): Activation config of the FFN.
            Default: dict(type='GELU').
        norm_cfg (dict): Config of the normalization layers.
            Default: dict(type='LN').
        batch_first (bool): Forwarded to the attention module.
            Default: True.
        sr_ratio (int): Spatial reduction ratio of the attention module.
            Default: 1.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 batch_first=True,
                 sr_ratio=1):
        super().__init__()

        # build_norm_layer returns (name, layer); only the layer is kept.
        _, self.norm1 = build_norm_layer(norm_cfg, embed_dims)

        attn_kwargs = dict(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            batch_first=batch_first,
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)
        self.attn = EfficientMultiheadAttention(**attn_kwargs)

        _, self.norm2 = build_norm_layer(norm_cfg, embed_dims)

        ffn_kwargs = dict(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)
        self.ffn = MixFFN(**ffn_kwargs)

    def forward(self, x, hw_shape):
        """Apply the attention and FFN sub-blocks in sequence.

        Args:
            x: Token tensor fed to both sub-blocks.
            hw_shape: Spatial size forwarded to both sub-blocks.

        Returns:
            The tensor produced by the FFN sub-block.
        """
        # Each sub-block sees the normalized tokens but adds its result to
        # the un-normalized ones passed as `identity`.
        attended = self.attn(self.norm1(x), hw_shape, identity=x)
        return self.ffn(self.norm2(attended), hw_shape, identity=attended)
@BACKBONES.register_module()
class MixVisionTransformer(BaseModule):
    """The backbone of Segformer.

    A PyTorch implementation of `SegFormer: Simple and Efficient Design for
    Semantic Segmentation with Transformers` -
    https://arxiv.org/pdf/2105.15203.pdf

    Args:
        in_channels (int): Number of input channels. Default: 64.
        embed_dims (int): Base embedding dimension; stage ``i`` uses
            ``embed_dims * num_heads[i]`` channels. Default: 64.
        num_stages (int): The number of stages. Default: 4.
        num_layers (Sequence[int]): Number of encoder layers in each
            stage. Default: [3, 4, 6, 3].
        num_heads (Sequence[int]): Attention heads of each stage.
            Default: [1, 2, 4, 8].
        patch_sizes (Sequence[int]): Kernel size of each overlapped patch
            embedding. Default: [7, 3, 3, 3].
        strides (Sequence[int]): Stride of each overlapped patch
            embedding. Default: [2, 2, 2, 2].
        sr_ratios (Sequence[int]): Spatial reduction ratio of the
            attention in each stage. Default: [8, 4, 2, 1].
        out_indices (Sequence[int] | int): Stages whose output is
            returned. Default: (0, 1, 2, 3).
        mlp_ratio (int): Ratio of FFN hidden dim to embedding dim.
            Default: 4.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.0.
        attn_drop_rate (float): Dropout rate of the attention layers.
            Default: 0.0.
        drop_path_rate (float): Maximum stochastic depth rate; the rate
            grows linearly from 0 to this value over all layers.
            Default: 0.0.
        act_cfg (dict): Activation config of the FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config of the normalization layers.
            Default: dict(type='LN', eps=1e-6).
        pretrain_style (str): Either 'official' or 'mmcls'.
            Default: 'official'.
        pretrained (str, optional): Path of a pretrained checkpoint
            (deprecated in favour of ``init_cfg``). Default: None.
        init_cfg (dict or list[dict], optional): Initialization config.
            Default: None.

    Raises:
        TypeError: If ``pretrained`` is neither a str nor None.
    """

    def __init__(self,
                 in_channels=64,
                 embed_dims=64,
                 num_stages=4,
                 num_layers=[3, 4, 6, 3],
                 num_heads=[1, 2, 4, 8],
                 patch_sizes=[7, 3, 3, 3],
                 strides=[2, 2, 2, 2],
                 sr_ratios=[8, 4, 2, 1],
                 out_indices=(0, 1, 2, 3),
                 mlp_ratio=4,
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 pretrain_style='official',
                 pretrained=None,
                 init_cfg=None):
        super().__init__()

        assert pretrain_style in [
            'official', 'mmcls'
        ], 'we only support official weights or mmcls weights.'

        if not (pretrained is None or isinstance(pretrained, str)):
            raise TypeError('pretrained must be a str or None')
        # Only nag about the deprecated argument when it is actually used.
        if pretrained is not None:
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                          'please use "init_cfg" instead')

        # The documented contract allows a single int; normalize it so the
        # `max(...)` check below and the `in` test in forward() both work.
        if isinstance(out_indices, int):
            out_indices = (out_indices, )

        self.embed_dims = embed_dims

        self.num_stages = num_stages
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_sizes = patch_sizes
        self.strides = strides
        self.sr_ratios = sr_ratios
        assert num_stages == len(num_layers) == len(num_heads) \
            == len(patch_sizes) == len(strides) == len(sr_ratios)

        self.out_indices = out_indices
        assert max(out_indices) < self.num_stages
        self.pretrain_style = pretrain_style
        self.pretrained = pretrained
        self.init_cfg = init_cfg

        # Stochastic depth decay rule: one linearly increasing rate per
        # encoder layer, counted across all stages.
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
        ]

        cur = 0
        self.layers = ModuleList()
        for i, num_layer in enumerate(num_layers):
            embed_dims_i = embed_dims * num_heads[i]
            patch_embed = PatchEmbed(
                in_channels=in_channels,
                embed_dims=embed_dims_i,
                kernel_size=patch_sizes[i],
                stride=strides[i],
                padding=patch_sizes[i] // 2,
                pad_to_patch_size=False,
                norm_cfg=norm_cfg)
            layer = ModuleList([
                TransformerEncoderLayer(
                    embed_dims=embed_dims_i,
                    num_heads=num_heads[i],
                    feedforward_channels=mlp_ratio * embed_dims_i,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=dpr[cur + idx],
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    sr_ratio=sr_ratios[i]) for idx in range(num_layer)
            ])
            # The next stage consumes this stage's output channels.
            in_channels = embed_dims_i
            # The ret[0] of build_norm_layer is norm name.
            norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
            # Each stage is stored as [patch_embed, encoder_layers, norm].
            self.layers.append(ModuleList([patch_embed, layer, norm]))
            cur += num_layer

    def init_weights(self):
        """Initialize weights randomly or from ``self.pretrained``.

        Without a pretrained path, Linear layers get a truncated normal
        init (std 0.02), LayerNorm gets weight 1 / bias 0 and Conv2d gets
        a fan-out scaled normal init. With a path, the checkpoint is
        loaded strictly after adapting the first projection weight.
        """
        if self.pretrained is None:
            for m in self.modules():
                # The mmcv init helpers take the *module* and look up its
                # `weight` / `bias` themselves; handing them a tensor
                # makes them silently do nothing.
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, val=1.0, bias=0)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
        elif isinstance(self.pretrained, str):
            logger = get_root_logger()
            checkpoint = _load_checkpoint(
                self.pretrained, logger=logger, map_location='cpu')
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            # only use this code if when adopt v3
            # NOTE(review): the first projection weight is duplicated
            # along dim 1 unconditionally, so the strict load below only
            # succeeds when the model's first projection has exactly twice
            # the checkpoint's input channels -- confirm this is intended
            # for every configuration.
            ori_proj_weight = state_dict['layers.0.0.projection.weight']
            state_dict['layers.0.0.projection.weight'] = torch.cat(
                [ori_proj_weight, ori_proj_weight], dim=1)

            self.load_state_dict(state_dict, True)

    def forward(self, x, additional_features=None):
        """Run all stages and collect the selected stage outputs.

        Args:
            x: Input tensor handed to the first patch embedding.
            additional_features: Accepted for interface compatibility;
                not used by this method.

        Returns:
            list: One feature map (as produced by `nlc_to_nchw`) for every
            stage index contained in ``self.out_indices``.
        """
        outs = []

        for i, layer in enumerate(self.layers):
            # The patch embedding exposes its output size through the
            # DH / DW attributes, which are read after it has been called.
            x, H, W = layer[0](x), layer[0].DH, layer[0].DW
            hw_shape = (H, W)
            for block in layer[1]:
                x = block(x, hw_shape)
            x = layer[2](x)
            x = nlc_to_nchw(x, hw_shape)
            if i in self.out_indices:
                outs.append(x)

        return outs
@HEADS.register_module()
class SegformerHead(BaseDecodeHead):
    """The all-MLP decode head of Segformer.

    Implementation of the head described in
    `Segformer <https://arxiv.org/abs/2105.15203>`_.

    Every selected input is projected to ``self.channels`` by a 1x1
    ConvModule, resized to the spatial size of the first input,
    concatenated along the channel axis and fused by another 1x1
    ConvModule before classification.

    Args:
        interpolate_mode (str): Interpolation mode passed to `resize`.
            Default: 'bilinear'.
        **kwargs: Forwarded to `BaseDecodeHead`; ``input_transform`` is
            always set to 'multiple_select'.
    """

    def __init__(self, interpolate_mode='bilinear', **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)

        self.interpolate_mode = interpolate_mode
        num_inputs = len(self.in_channels)

        assert num_inputs == len(self.in_index)

        # One 1x1 projection per selected input.
        self.convs = nn.ModuleList()
        for stage in range(num_inputs):
            projection = ConvModule(
                in_channels=self.in_channels[stage],
                out_channels=self.channels,
                kernel_size=1,
                stride=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            self.convs.append(projection)

        # Fuses the concatenation of all projected inputs.
        self.fusion_conv = ConvModule(
            in_channels=self.channels * num_inputs,
            out_channels=self.channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg)

    def forward(self, inputs):
        """Fuse the multi-stage features and classify them.

        Args:
            inputs: Backbone outputs; `_transform_inputs` presumably
                selects the entries listed in ``self.in_index`` --
                confirm against `BaseDecodeHead`.

        Returns:
            The output of ``self.cls_seg`` applied to the fused features.
        """
        feats = self._transform_inputs(inputs)

        # Project every feature map and bring it to the spatial size of
        # the first one so they can be concatenated.
        aligned = [
            resize(
                input=self.convs[idx](feat),
                size=feats[0].shape[2:],
                mode=self.interpolate_mode,
                align_corners=self.align_corners)
            for idx, feat in enumerate(feats)
        ]

        fused = self.fusion_conv(torch.cat(aligned, dim=1))
        return self.cls_seg(fused)