# rap-sam/app/models/necks/ramsam_neck.py
import math
import torch
from torch import nn
import torch.nn.functional as F
from mmengine.model import kaiming_init
from mmdet.registry import MODELS
from mmcv.ops import DeformConv2d, ModulatedDeformConv2d


class DeformLayer(nn.Module):
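    """Deformable upsampling block: (modulated) 3x3 DCN -> BN -> ReLU -> transposed conv.

    Each block refines a feature map with a deformable convolution and then
    doubles its spatial resolution (with the default deconv_stride=2) via a
    bilinear-initialized transposed convolution.
    """
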
def __init__(self,
in_planes,
out_planes,
deconv_kernel=4,
deconv_stride=2,
deconv_pad=1,
deconv_out_pad=0,
modulate_deform=True,
num_groups=1,
deform_num_groups=1,
dilation=1):
        super().__init__()
        self.deform_modulated = modulate_deform
        if modulate_deform:
            deform_conv_op = ModulatedDeformConv2d
            # 3x3 kernel: (x, y) offsets plus a modulation mask per sampling point -> 3 * 9 = 27.
            offset_channels = 27
        else:
            deform_conv_op = DeformConv2d
            # 3x3 kernel: (x, y) offsets per sampling point -> 2 * 9 = 18.
            offset_channels = 18
        self.dcn_offset = nn.Conv2d(in_planes, offset_channels * deform_num_groups, kernel_size=3, stride=1, padding=1 * dilation, dilation=dilation)
        # `deformable_groups` is a deprecated mmcv alias; the current keyword is `deform_groups`.
        self.dcn = deform_conv_op(in_planes, out_planes, kernel_size=3, stride=1, padding=1 * dilation, bias=False, groups=num_groups, dilation=dilation, deform_groups=deform_num_groups)
        kaiming_init(self.dcn)
        # Zero-init the offset/mask predictor so the DCN starts out as a plain 3x3 conv.
        nn.init.constant_(self.dcn_offset.weight, 0)
        nn.init.constant_(self.dcn_offset.bias, 0)
        # SyncBatchNorm assumes distributed training; swap in nn.BatchNorm2d or
        # nn.GroupNorm for single-GPU use.
        self.dcn_bn = nn.SyncBatchNorm(out_planes)
        self.up_sample = nn.ConvTranspose2d(in_channels=out_planes, out_channels=out_planes, kernel_size=deconv_kernel, stride=deconv_stride, padding=deconv_pad, output_padding=deconv_out_pad, bias=False)
        self._deconv_init()  # initialize the transposed conv as bilinear upsampling
        self.up_bn = nn.SyncBatchNorm(out_planes)
        self.relu = nn.ReLU()

def forward(self, x):
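        """Run DCN -> BN -> ReLU, then upsample with the transposed conv."""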
out = x
        if self.deform_modulated:
            # One conv predicts x-offsets, y-offsets and the modulation mask; split and regroup.
            offset_mask = self.dcn_offset(out)
            offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1)
            offset = torch.cat((offset_x, offset_y), dim=1)
            mask = mask.sigmoid()
            out = self.dcn(out, offset, mask)
else:
offset = self.dcn_offset(out)
out = self.dcn(out, offset)
        x = self.dcn_bn(out)
        x = self.relu(x)
        x = self.up_sample(x)
        x = self.up_bn(x)
        x = self.relu(x)
        return x

def _deconv_init(self):
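        """Initialize the transposed conv as a bilinear upsampling kernel."""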
        w = self.up_sample.weight.data
        f = math.ceil(w.size(2) / 2)
        c = (2 * f - 1 - f % 2) / (2. * f)
        # Fill channel 0 with a bilinear interpolation kernel.
        for i in range(w.size(2)):
            for j in range(w.size(3)):
                w[0, 0, i, j] = \
                    (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
        # Copy the bilinear kernel to every output channel.
        for ch in range(1, w.size(0)):
            w[ch, 0, :, :] = w[0, 0, :, :]


class LiteDeformConv(nn.Module):
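    """Lightweight FPN-style decoder built from DeformLayer upsampling blocks.

    Features are fused top-down through lateral 1x1 convs; all levels are then
    projected to a common width and summed at the finest resolution (the CFA
    step in forward()).
    """
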
def __init__(self, agg_dim, backbone_shape):
        super().__init__()
        in_channels = []
        out_channels = [agg_dim]
        # Each decoder stage halves the corresponding backbone width.
        for feat in backbone_shape:
            in_channels.append(feat)
            out_channels.append(feat // 2)
self.lateral_conv0 = nn.Conv2d(in_channels=in_channels[-1], out_channels=out_channels[-1], kernel_size=1, stride=1, padding=0)
self.deform_conv1 = DeformLayer(in_planes=out_channels[-1], out_planes=out_channels[-2])
self.lateral_conv1 = nn.Conv2d(in_channels=in_channels[-2], out_channels=out_channels[-2], kernel_size=1, stride=1, padding=0)
self.deform_conv2 = DeformLayer(in_planes=out_channels[-2], out_planes=out_channels[-3])
self.lateral_conv2 = nn.Conv2d(in_channels=in_channels[-3], out_channels=out_channels[-3], kernel_size=1, stride=1, padding=0)
self.deform_conv3 = DeformLayer(in_planes=out_channels[-3], out_planes=out_channels[-4])
self.lateral_conv3 = nn.Conv2d(in_channels=in_channels[-4], out_channels=out_channels[-4], kernel_size=1, stride=1, padding=0)
        self.output_conv = nn.Conv2d(in_channels=out_channels[-5], out_channels=out_channels[-5], kernel_size=3, stride=1, padding=1)
        # Learnable per-channel bias added to the aggregated features in forward().
        self.bias = nn.Parameter(torch.zeros(1, out_channels[-5], 1, 1))
        # 1x1 projections that bring every level to the output width before aggregation.
        self.conv_a5 = nn.Conv2d(in_channels=out_channels[-1], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_a4 = nn.Conv2d(in_channels=out_channels[-2], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_a3 = nn.Conv2d(in_channels=out_channels[-3], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_a2 = nn.Conv2d(in_channels=out_channels[-4], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)

def forward(self, features_list):
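        """Fuse `features_list` (finest first, deepest last) into a single map.

        Returns the fused feature map at the finest input resolution together
        with the projected intermediate levels (x5, x4, x3).
        """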
p5 = self.lateral_conv0(features_list[-1])
x5 = p5
x = self.deform_conv1(x5)
p4 = self.lateral_conv1(features_list[-2])
x4 = p4 + x
x = self.deform_conv2(x4)
p3 = self.lateral_conv2(features_list[-3])
x3 = p3 + x
x = self.deform_conv3(x3)
p2 = self.lateral_conv3(features_list[-4])
x2 = p2 + x
        # CFA: project all levels to a common width and sum them at the finest resolution.
x5 = self.conv_a5(x5)
x4 = self.conv_a4(x4)
x3 = self.conv_a3(x3)
_x5 = F.interpolate(x5, scale_factor=8, align_corners=False, mode='bilinear')
_x4 = F.interpolate(x4, scale_factor=4, align_corners=False, mode='bilinear')
_x3 = F.interpolate(x3, scale_factor=2, align_corners=False, mode='bilinear')
x2 = self.conv_a2(x2)
x = _x5 + _x4 + _x3 + x2 + self.bias
x = self.output_conv(x)
        return x, (x5, x4, x3)


@MODELS.register_module()
class YOSONeck(nn.Module):
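    """YOSO-style neck: a LiteDeformConv decoder plus CoordConv-like location features.

    Coordinate channels are concatenated to the fused feature map before a
    final 1x1 conv projects it to `hidden_dim`.
    """
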
def __init__(self,
agg_dim,
hidden_dim,
backbone_shape,
return_multi_scale=False,
return_single_scale=False,
                 # Just for compatibility with Mask2Former; not actually used.
in_channels=None,
feat_channels=None,
out_channels=None
):
super().__init__()
# in_channels == backbone_shape
# hidden_dim == feat_channels == out_channels == 256
self.return_single_scale = return_single_scale
self.return_multi_scale = return_multi_scale
self.deconv = LiteDeformConv(agg_dim=agg_dim, backbone_shape=backbone_shape)
        # +2 input channels for the (x, y) coordinate map appended in forward().
        self.loc_conv = nn.Conv2d(in_channels=agg_dim + 2, out_channels=hidden_dim, kernel_size=1, stride=1)
        self.init_weights()

def init_weights(self) -> None:
for p in self.parameters():
if p.dim() > 1:
                nn.init.xavier_uniform_(p)

def generate_coord(self, input_feat):
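        """Build a CoordConv-style (x, y) grid in [-1, 1] matching `input_feat`."""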
x_range = torch.linspace(-1, 1, input_feat.shape[-1], device=input_feat.device)
y_range = torch.linspace(-1, 1, input_feat.shape[-2], device=input_feat.device)
        # 'ij' indexing matches the historical default and avoids the PyTorch warning.
        y, x = torch.meshgrid(y_range, x_range, indexing='ij')
y = y.expand([input_feat.shape[0], 1, -1, -1])
x = x.expand([input_feat.shape[0], 1, -1, -1])
coord_feat = torch.cat([x, y], 1)
        return coord_feat

    def forward(self,
                features_list,
                batch_img_metas=None,
                num_frames=None):
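        """Fuse backbone features and return them in the configured output format.

        `batch_img_metas` and `num_frames` are accepted only for interface
        compatibility and are not used here.
        """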
features, multi_scale = self.deconv(features_list)
coord_feat = self.generate_coord(features)
features = torch.cat([features, coord_feat], 1)
features = self.loc_conv(features)
if self.return_single_scale: # maskformer
return features, multi_scale[0]
if self.return_multi_scale: # mask2former
return features, multi_scale
return features
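

# Minimal smoke test -- a hypothetical sketch, not part of the released module.
# Shapes assume ResNet-50-style C2..C5 features for a 256x256 input; it needs
# an mmcv build with deformable-conv support for the current device, and the
# SyncBatchNorm layers require eval() mode outside distributed training.
if __name__ == '__main__':
    backbone_shape = [256, 512, 1024, 2048]  # assumed backbone channel widths
    neck = YOSONeck(agg_dim=128, hidden_dim=256,
                    backbone_shape=backbone_shape).eval()
    feats = [torch.randn(1, c, 64 // 2 ** i, 64 // 2 ** i)
             for i, c in enumerate(backbone_shape)]
    with torch.no_grad():
        out = neck(feats)
    print(out.shape)  # expected: torch.Size([1, 256, 64, 64])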