import math
import torch.nn as nn
from isegm.utils.serialization import serialize
from .is_model import ISModel
from .modeling.models_vit import VisionTransformer, PatchEmbed
from .modeling.swin_transformer import SwinTransfomerSegHead


class SimpleFPN(nn.Module):
    """Build a four-level feature pyramid from a single feature map.

    Every branch reads the same ``(B, in_dim, H, W)`` input and rescales it by
    a fixed factor:

    * ``down_4``:  two 2x2/stride-2 transposed convs -> 4x upsampled,
      ``out_dims[0]`` channels
    * ``down_8``:  one 2x2/stride-2 transposed conv  -> 2x upsampled,
      ``out_dims[1]`` channels
    * ``down_16``: 1x1 conv                           -> same size,
      ``out_dims[2]`` channels
    * ``down_32``: 2x2/stride-2 conv                  -> half size,
      ``out_dims[3]`` channels

    The ``down_*`` names presumably assume the input is at 1/16 of the image
    resolution (e.g. a ViT with patch size 16) -- TODO confirm for other
    patch sizes.

    ``nn.GroupNorm(1, C)`` uses a single group, i.e. it normalizes over all
    channels and spatial positions of each sample.

    Args:
        in_dim: Channel count of the input feature map.
        out_dims: Output channel counts of the four levels, finest first. Any
            indexable sequence with at least four ints; only the first four
            entries are used.

    Raises:
        ValueError: If ``out_dims`` has fewer than four entries.
    """

    def __init__(self, in_dim=768, out_dims=(128, 256, 512, 1024)):
        super().__init__()
        # Tuple default instead of a list: defaults are evaluated once and shared
        # between calls, so an immutable one cannot leak state.
        if len(out_dims) < 4:
            raise ValueError(
                f"SimpleFPN needs 4 output dims, got {len(out_dims)}: {out_dims!r}"
            )

        # NOTE: nn.Sequential names its parameters by position, so inserting or
        # reordering layers below would change state_dict keys. The layer lists
        # are therefore kept exactly as they were (including the missing GELU
        # after the second GroupNorm in down_4).

        # 4x upsampling. The intermediate width is at least twice the final width
        # and at least half the input width.
        self.down_4_chan = max(out_dims[0] * 2, in_dim // 2)
        self.down_4 = nn.Sequential(
            nn.ConvTranspose2d(in_dim, self.down_4_chan, 2, stride=2),
            nn.GroupNorm(1, self.down_4_chan),
            nn.GELU(),
            nn.ConvTranspose2d(self.down_4_chan, self.down_4_chan // 2, 2, stride=2),
            nn.GroupNorm(1, self.down_4_chan // 2),
            nn.Conv2d(self.down_4_chan // 2, out_dims[0], 1),
            nn.GroupNorm(1, out_dims[0]),
            nn.GELU()
        )

        # 2x upsampling, then a 1x1 projection to out_dims[1].
        self.down_8_chan = max(out_dims[1], in_dim // 2)
        self.down_8 = nn.Sequential(
            nn.ConvTranspose2d(in_dim, self.down_8_chan, 2, stride=2),
            nn.GroupNorm(1, self.down_8_chan),
            nn.Conv2d(self.down_8_chan, out_dims[1], 1),
            nn.GroupNorm(1, out_dims[1]),
            nn.GELU()
        )

        # Same resolution: channel projection only.
        self.down_16 = nn.Sequential(
            nn.Conv2d(in_dim, out_dims[2], 1),
            nn.GroupNorm(1, out_dims[2]),
            nn.GELU()
        )

        # 2x downsampling (floor(H/2) x floor(W/2) for odd sizes, since no padding).
        self.down_32_chan = max(out_dims[3], in_dim * 2)
        self.down_32 = nn.Sequential(
            nn.Conv2d(in_dim, self.down_32_chan, 2, stride=2),
            nn.GroupNorm(1, self.down_32_chan),
            nn.Conv2d(self.down_32_chan, out_dims[3], 1),
            nn.GroupNorm(1, out_dims[3]),
            nn.GELU()
        )

        self.init_weights()

    def init_weights(self):
        """Initialization hook; intentionally a no-op.

        Every layer keeps the default initialization PyTorch gives it at
        construction time.
        """
        pass

    def forward(self, x):
        """Return ``[x_down_4, x_down_8, x_down_16, x_down_32]``, finest first.

        Args:
            x: Feature map of shape ``(B, in_dim, H, W)``.
        """
        x_down_4 = self.down_4(x)
        x_down_8 = self.down_8(x)
        x_down_16 = self.down_16(x)
        x_down_32 = self.down_32(x)
        return [x_down_4, x_down_8, x_down_16, x_down_32]


class PlainVitModel(ISModel):
    """Interactive segmentation model built on a plain (single-scale) ViT.

    Extra per-pixel inputs (``coord_features``: 2 channels, or 3 when
    ``with_prev_mask`` is set) are patch-embedded by their own ``PatchEmbed``
    and passed to the ViT together with the image. The ViT's token output is
    folded back into a 2-D map, expanded into a four-level pyramid by
    :class:`SimpleFPN` and decoded by ``SwinTransfomerSegHead``.
    """

    # NOTE(review): the dict defaults below are intentionally left as-is. They are
    # only read here, never mutated, and @serialize may record default values in
    # the model config, so replacing them with None could change saved configs.
    @serialize
    def __init__(
        self,
        backbone_params={},
        neck_params={},
        head_params={},
        random_split=False,
        **kwargs
    ):
        """Create the coordinate embedding, ViT backbone, FPN neck and head.

        Args:
            backbone_params: Keyword arguments for ``VisionTransformer``. Must
                contain ``'img_size'``, ``'patch_size'`` and ``'embed_dim'``,
                which are also read directly here to size
                ``patch_embed_coords``.
            neck_params: Keyword arguments for :class:`SimpleFPN`.
            head_params: Keyword arguments for ``SwinTransfomerSegHead``.
            random_split: Stored and forwarded as ``shuffle=`` to
                ``VisionTransformer.forward_backbone``.
            **kwargs: Forwarded to ``ISModel.__init__``.
        """
        super().__init__(**kwargs)

        self.random_split = random_split

        # `with_prev_mask` is not set in this class; presumably ISModel.__init__
        # sets it -- TODO confirm. It adds one extra input channel.
        self.patch_embed_coords = PatchEmbed(
            img_size=backbone_params['img_size'],
            patch_size=backbone_params['patch_size'],
            in_chans=3 if self.with_prev_mask else 2,
            embed_dim=backbone_params['embed_dim'],
        )

        self.backbone = VisionTransformer(**backbone_params)
        self.neck = SimpleFPN(**neck_params)
        self.head = SwinTransfomerSegHead(**head_params)

    def backbone_forward(self, image, coord_features=None, gra=None):
        """Run the ViT backbone, FPN neck and segmentation head.

        Args:
            image: Image batch, passed to ``self.backbone.forward_backbone``.
            coord_features: Extra input channels, embedded by
                ``patch_embed_coords`` before being given to the backbone. It is
                passed to ``patch_embed_coords`` unconditionally, so ``None`` is
                presumably unsupported despite the default -- TODO confirm.
            gra: Passed through unchanged to ``forward_backbone``.

        Returns:
            dict: ``{'instances': <head output>, 'instances_aux': None}``.
        """
        coord_features = self.patch_embed_coords(coord_features)
        backbone_features = self.backbone.forward_backbone(
            image, coord_features, gra=gra, shuffle=self.random_split
        )

        # The backbone output is a 3-D token tensor (B, N, C). Fold it back onto
        # the patch grid as a (B, C, H, W) map so the convolutional neck can use it.
        B, _, C = backbone_features.shape
        grid_size = self.backbone.patch_embed.grid_size
        # reshape instead of view: it returns the same view whenever one is
        # possible, but still works if forward_backbone returns a non-contiguous
        # tensor.
        backbone_features = backbone_features.transpose(-1, -2).reshape(
            B, C, grid_size[0], grid_size[1]
        )

        # Extract the 4-stage feature pyramid (named 1/4, 1/8, 1/16, 1/32).
        multi_scale_features = self.neck(backbone_features)

        return {'instances': self.head(multi_scale_features), 'instances_aux': None}