from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import warnings

from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_

from .dcnv3_func import dcnv3_core_pytorch

class to_channels_first(nn.Module):
    """Permute a (N, H, W, C) tensor to (N, C, H, W)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 3, 1, 2)

class to_channels_last(nn.Module):
    """Permute a (N, C, H, W) tensor to (N, H, W, C)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 2, 3, 1)

def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    """Build a 'BN' or 'LN' normalization layer, inserting permutes so the
    norm sees the memory format it expects and the output matches
    `out_format`."""
    layers = []
    if norm_layer == 'BN':
        if in_format == 'channels_last':
            layers.append(to_channels_first())
        layers.append(nn.BatchNorm2d(dim))
        if out_format == 'channels_last':
            layers.append(to_channels_last())
    elif norm_layer == 'LN':
        if in_format == 'channels_first':
            layers.append(to_channels_last())
        layers.append(nn.LayerNorm(dim, eps=eps))
        if out_format == 'channels_first':
            layers.append(to_channels_first())
    else:
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')
    return nn.Sequential(*layers)

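# A minimal usage sketch (the shapes here are illustrative, not part of the
# module):
#   norm = build_norm_layer(64, 'LN', 'channels_first', 'channels_last')
#   y = norm(torch.randn(2, 64, 14, 14))  # permute, LayerNorm -> (2, 14, 14, 64)
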
def build_act_layer(act_layer):
    """Build a 'ReLU', 'SiLU', or 'GELU' activation layer."""
    if act_layer == 'ReLU':
        return nn.ReLU(inplace=True)
    elif act_layer == 'SiLU':
        return nn.SiLU(inplace=True)
    elif act_layer == 'GELU':
        return nn.GELU()
    raise NotImplementedError(f'build_act_layer does not support {act_layer}')

def _is_power_of_2(n):
    """Return True if n is a positive power of 2, via the n & (n - 1) trick."""
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError(
            f'invalid input for _is_power_of_2: {n} (type: {type(n)})')
    return (n & (n - 1) == 0) and n != 0

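# For example: _is_power_of_2(16) -> True, _is_power_of_2(12) -> False, and
# _is_power_of_2(0) -> False (a power of 2 has exactly one bit set, so
# n & (n - 1) clears it to zero).
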
class DCNv3_pytorch(nn.Module):
    def __init__(
            self, channels=64, kernel_size=3, stride=1,
            pad=1, dilation=1, group=4, offset_scale=1.0,
            act_layer='GELU', norm_layer='LN'):
        """
        DCNv3 module (pure-PyTorch implementation)

        :param channels: number of input/output channels
        :param kernel_size: size of the sampling kernel
        :param stride: stride of the sampling grid
        :param pad: zero-padding added to each side of the input
        :param dilation: spacing between sampled kernel locations
        :param group: number of deformable groups
        :param offset_scale: scale factor applied to the predicted offsets
        :param act_layer: activation used in the depthwise branch
        :param norm_layer: normalization used in the depthwise branch
        """
        super().__init__()
        if channels % group != 0:
            raise ValueError(
                f'channels must be divisible by group, '
                f'but got {channels} and {group}')
        _d_per_group = channels // group

        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 to make the dimension of "
                'each attention head a power of 2, which is more efficient '
                'in our CUDA implementation.')

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = channels // group

        # Depthwise conv branch that extracts features for offset/mask
        # prediction; its norm layer converts channels_first to channels_last.
        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
                groups=channels),
            build_norm_layer(
                channels,
                norm_layer,
                'channels_first',
                'channels_last'),
            build_act_layer(act_layer))
        # Per group: 2 offset coordinates and 1 modulation scalar for each
        # of the kernel_size * kernel_size sampling locations.
        self.offset = nn.Linear(
            channels,
            group * kernel_size * kernel_size * 2)
        self.mask = nn.Linear(
            channels,
            group * kernel_size * kernel_size)
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

    def _reset_parameters(self):
        # Zero-init offsets (start from the regular sampling grid) and masks
        # (softmax then yields uniform modulation); Xavier-init the
        # input/output projections.
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param input: (N, H, W, C)
        :return output: (N, H, W, C)
        """
        N, H, W, _ = input.shape

        x = self.input_proj(input)

        # Predict sampling offsets and modulation masks from the
        # depthwise-conv features.
        x1 = input.permute(0, 3, 1, 2)  # (N, C, H, W) for the conv
        x1 = self.dw_conv(x1)  # back to (N, H, W, C)
        offset = self.offset(x1)
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        mask = F.softmax(mask, -1).reshape(N, H, W, -1)

        x = dcnv3_core_pytorch(
            x, offset, mask,
            self.kernel_size, self.kernel_size,
            self.stride, self.stride,
            self.pad, self.pad,
            self.dilation, self.dilation,
            self.group, self.group_channels,
            self.offset_scale)
        x = self.output_proj(x)

        return x
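

# A minimal smoke test, assuming `torch` is installed and this file sits in a
# package next to `dcnv3_func` (the shapes below are illustrative):
if __name__ == '__main__':
    import torch

    m = DCNv3_pytorch(channels=64, kernel_size=3, stride=1, pad=1, group=4)
    x = torch.randn(2, 14, 14, 64)  # (N, H, W, C), channels_last layout
    y = m(x)
    print(y.shape)  # expected: torch.Size([2, 14, 14, 64])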