# GraCo: isegm/model/is_plainvit_model.py
# Last commit: 6d1366a "Add application file" (zhaoyian01).
import math
import torch.nn as nn
from isegm.utils.serialization import serialize
from .is_model import ISModel
from .modeling.models_vit import VisionTransformer, PatchEmbed
from .modeling.swin_transformer import SwinTransfomerSegHead
class SimpleFPN(nn.Module):
    """Build a four-level feature pyramid from one single-scale feature map.

    Every branch reads the same input ``x`` of shape (B, in_dim, H, W):

    * ``down_4``  -- two stride-2 transposed convs: 4x upsampled, ``out_dims[0]`` channels.
    * ``down_8``  -- one stride-2 transposed conv: 2x upsampled, ``out_dims[1]`` channels.
    * ``down_16`` -- a 1x1 conv: same resolution, ``out_dims[2]`` channels.
    * ``down_32`` -- one 2x2 stride-2 conv: 2x downsampled (floor for odd sizes),
      ``out_dims[3]`` channels.

    The branch names presume the input sits at 1/16 of the image resolution
    (e.g. a ViT with 16-pixel patches -- TODO confirm for other configs), which
    places the outputs at 1/4, 1/8, 1/16 and 1/32 scale.

    Args:
        in_dim: Number of channels of the input feature map.
        out_dims: Output channels of the four levels, finest first. Only the
            first four entries are read.
    """

    def __init__(self, in_dim=768, out_dims=(128, 256, 512, 1024)):
        super().__init__()
        # The default is a tuple rather than a list so the shared default object
        # can never be mutated between instantiations; lists are still accepted.
        #
        # NOTE: attribute names and the layer order inside each nn.Sequential
        # define the state_dict keys (e.g. ``down_4.5.weight``). Keep them stable
        # so existing checkpoints keep loading.

        # 4x upsampling: widen first, then halve channels in the second upsample.
        self.down_4_chan = max(out_dims[0] * 2, in_dim // 2)
        self.down_4 = nn.Sequential(
            nn.ConvTranspose2d(in_dim, self.down_4_chan, 2, stride=2),
            nn.GroupNorm(1, self.down_4_chan),
            nn.GELU(),
            nn.ConvTranspose2d(self.down_4_chan, self.down_4_chan // 2, 2, stride=2),
            nn.GroupNorm(1, self.down_4_chan // 2),
            # No activation here; inserting one would shift the Sequential
            # indices below and break checkpoint loading.
            nn.Conv2d(self.down_4_chan // 2, out_dims[0], 1),
            nn.GroupNorm(1, out_dims[0]),
            nn.GELU(),
        )

        # 2x upsampling followed by a 1x1 projection to out_dims[1].
        self.down_8_chan = max(out_dims[1], in_dim // 2)
        self.down_8 = nn.Sequential(
            nn.ConvTranspose2d(in_dim, self.down_8_chan, 2, stride=2),
            nn.GroupNorm(1, self.down_8_chan),
            nn.Conv2d(self.down_8_chan, out_dims[1], 1),
            nn.GroupNorm(1, out_dims[1]),
            nn.GELU(),
        )

        # Same resolution: only a 1x1 channel projection.
        self.down_16 = nn.Sequential(
            nn.Conv2d(in_dim, out_dims[2], 1),
            nn.GroupNorm(1, out_dims[2]),
            nn.GELU(),
        )

        # 2x downsampling with a strided 2x2 conv, then a 1x1 projection.
        self.down_32_chan = max(out_dims[3], in_dim * 2)
        self.down_32 = nn.Sequential(
            nn.Conv2d(in_dim, self.down_32_chan, 2, stride=2),
            nn.GroupNorm(1, self.down_32_chan),
            nn.Conv2d(self.down_32_chan, out_dims[3], 1),
            nn.GroupNorm(1, out_dims[3]),
            nn.GELU(),
        )

        self.init_weights()

    def init_weights(self):
        """Hook for custom initialisation.

        Intentionally a no-op, so all layers keep PyTorch's default
        initialisation. Subclasses may override it.
        """
        pass

    def forward(self, x):
        """Return the pyramid ``[x_down_4, x_down_8, x_down_16, x_down_32]``.

        Args:
            x: Feature map of shape (B, in_dim, H, W).

        Returns:
            A list of four tensors, finest resolution first.
        """
        return [self.down_4(x), self.down_8(x), self.down_16(x), self.down_32(x)]
class PlainVitModel(ISModel):
    """Interactive segmentation model on top of a plain ViT backbone.

    Pipeline:

    1. ``patch_embed_coords`` embeds the coordinate features.
    2. ``VisionTransformer.forward_backbone`` consumes the image together with
       those embedded coordinate tokens.
    3. The resulting token sequence is folded back into a 2-D map.
    4. ``SimpleFPN`` expands the map into a four-level pyramid.
    5. ``SwinTransfomerSegHead`` decodes the pyramid.
    """

    @serialize
    def __init__(
        self,
        backbone_params={},
        neck_params={},
        head_params={},
        random_split=False,
        **kwargs
    ):
        """Build the coordinate embedding, backbone, neck and head.

        Args:
            backbone_params: Keyword arguments for ``VisionTransformer``. The keys
                ``'img_size'``, ``'patch_size'`` and ``'embed_dim'`` are also
                reused for the coordinate ``PatchEmbed``.
            neck_params: Keyword arguments for ``SimpleFPN``.
            head_params: Keyword arguments for ``SwinTransfomerSegHead``.
            random_split: Passed as ``shuffle`` to ``forward_backbone`` on every
                forward call.
            **kwargs: Forwarded to ``ISModel.__init__``.
        """
        # NOTE(review): the dict defaults are only read/unpacked here, never
        # mutated. They are left as-is because ``serialize`` presumably records
        # the signature defaults in the saved model config -- confirm before
        # changing them.
        super().__init__(**kwargs)
        self.random_split = random_split

        # Coordinate input has 3 channels when ``with_prev_mask`` (set by
        # ISModel) is truthy, otherwise 2 -- presumably click maps plus the
        # previous mask; confirm in ISModel.
        coord_embed_kwargs = dict(
            img_size=backbone_params['img_size'],
            patch_size=backbone_params['patch_size'],
            in_chans=3 if self.with_prev_mask else 2,
            embed_dim=backbone_params['embed_dim'],
        )
        self.patch_embed_coords = PatchEmbed(**coord_embed_kwargs)

        self.backbone = VisionTransformer(**backbone_params)
        self.neck = SimpleFPN(**neck_params)
        self.head = SwinTransfomerSegHead(**head_params)

    def backbone_forward(self, image, coord_features=None, gra=None):
        """Run backbone, neck and head on one batch.

        Args:
            image: Image batch passed to ``forward_backbone``.
            coord_features: Coordinate feature maps. NOTE(review): despite the
                ``None`` default, this is always passed to
                ``patch_embed_coords``, so ``None`` will presumably fail there.
            gra: Forwarded unchanged to ``forward_backbone`` as ``gra``
                (presumably the granularity input -- confirm in the backbone).

        Returns:
            ``{'instances': <head output>, 'instances_aux': None}``.
        """
        coord_tokens = self.patch_embed_coords(coord_features)
        tokens = self.backbone.forward_backbone(
            image, coord_tokens, gra=gra, shuffle=self.random_split
        )

        # Fold the token sequence (B, N, C) back into a (B, C, H, W) map, with
        # H and W taken from the backbone's patch grid.
        batch_size, _, num_channels = tokens.shape
        grid = self.backbone.patch_embed.grid_size
        feature_map = tokens.transpose(-1, -2).view(
            batch_size, num_channels, grid[0], grid[1]
        )

        # Four-level pyramid (1/4, 1/8, 1/16, 1/32 per the original design).
        pyramid = self.neck(feature_map)
        return {'instances': self.head(pyramid), 'instances_aux': None}