Spaces:
Running
Running
import math | |
import torch.nn as nn | |
from isegm.utils.serialization import serialize | |
from .is_model import ISModel | |
from .modeling.models_vit_lora import VisionTransformer_lora, PatchEmbed | |
from .modeling.swin_transformer import SwinTransfomerSegHead | |
class SimpleFPN(nn.Module):
    """Turn one single-scale feature map into a four-level feature pyramid.

    Each branch resamples the same input independently:

    * ``down_4``  -- two stride-2 transposed convs (4x upsampling), then a
      1x1 conv to ``out_dims[0]`` channels;
    * ``down_8``  -- one stride-2 transposed conv (2x upsampling), then a
      1x1 conv to ``out_dims[1]`` channels;
    * ``down_16`` -- a 1x1 conv to ``out_dims[2]`` channels at the input
      resolution;
    * ``down_32`` -- a kernel-2 stride-2 conv (2x downsampling, floor for odd
      sizes), then a 1x1 conv to ``out_dims[3]`` channels.

    If the input is at 1/16 of the image resolution (as the branch names
    suggest), the outputs are at 1/4, 1/8, 1/16 and 1/32.

    ``nn.GroupNorm(1, C)`` (a single group) normalises over all channels and
    spatial positions of each sample.

    Args:
        in_dim: Channel count of the input feature map.
        out_dims: Output channel counts of the four levels, finest first.
            Any indexable sequence works; only the first four entries are used.
    """

    def __init__(self, in_dim=768, out_dims=(128, 256, 512, 1024)):
        super().__init__()
        # The default is a tuple (it used to be a list) so it can never be
        # shared and mutated across instances; lists are still accepted.

        # Hidden widths never fall below in_dim // 2 so the upsampling path
        # does not squeeze wide ViT features too early.
        self.down_4_chan = max(out_dims[0] * 2, in_dim // 2)
        # NOTE: there is no activation between the second transposed conv's
        # GroupNorm and the 1x1 conv. It is deliberately left unchanged:
        # inserting a module would renumber the Sequential indices and break
        # loading of existing state_dicts.
        self.down_4 = nn.Sequential(
            nn.ConvTranspose2d(in_dim, self.down_4_chan, 2, stride=2),
            nn.GroupNorm(1, self.down_4_chan),
            nn.GELU(),
            nn.ConvTranspose2d(self.down_4_chan, self.down_4_chan // 2, 2, stride=2),
            nn.GroupNorm(1, self.down_4_chan // 2),
            nn.Conv2d(self.down_4_chan // 2, out_dims[0], 1),
            nn.GroupNorm(1, out_dims[0]),
            nn.GELU(),
        )

        self.down_8_chan = max(out_dims[1], in_dim // 2)
        self.down_8 = nn.Sequential(
            nn.ConvTranspose2d(in_dim, self.down_8_chan, 2, stride=2),
            nn.GroupNorm(1, self.down_8_chan),
            nn.Conv2d(self.down_8_chan, out_dims[1], 1),
            nn.GroupNorm(1, out_dims[1]),
            nn.GELU(),
        )

        self.down_16 = nn.Sequential(
            nn.Conv2d(in_dim, out_dims[2], 1),
            nn.GroupNorm(1, out_dims[2]),
            nn.GELU(),
        )

        # Wider intermediate (>= 2 * in_dim) before projecting to out_dims[3].
        self.down_32_chan = max(out_dims[3], in_dim * 2)
        self.down_32 = nn.Sequential(
            nn.Conv2d(in_dim, self.down_32_chan, 2, stride=2),
            nn.GroupNorm(1, self.down_32_chan),
            nn.Conv2d(self.down_32_chan, out_dims[3], 1),
            nn.GroupNorm(1, out_dims[3]),
            nn.GELU(),
        )

        self.init_weights()

    def init_weights(self):
        """Hook for custom initialisation; currently a no-op.

        All layers keep PyTorch's default initialisation.
        """
        pass

    def forward(self, x):
        """Apply the four pyramid branches to ``x``.

        Args:
            x: Feature map, assumed ``(B, in_dim, H, W)``.

        Returns:
            ``[x_down_4, x_down_8, x_down_16, x_down_32]`` with spatial sizes
            ``4H x 4W``, ``2H x 2W``, ``H x W`` and ``floor(H/2) x floor(W/2)``.
        """
        branches = (self.down_4, self.down_8, self.down_16, self.down_32)
        return [branch(x) for branch in branches]
class PlainVitModel_lora(ISModel):
    """Interactive-segmentation model built on the ``_lora`` ViT backbone.

    Components, each constructed from a keyword dictionary:

    * ``patch_embed_coords`` -- a ``PatchEmbed`` for the coordinate features,
      sharing the backbone's ``img_size``, ``patch_size`` and ``embed_dim``;
    * ``backbone`` -- ``VisionTransformer_lora(**backbone_params)``;
    * ``neck`` -- ``SimpleFPN(**neck_params)``, producing four feature levels;
    * ``head`` -- ``SwinTransfomerSegHead(**head_params)``.

    NOTE(review): ``serialize`` is imported in this file but is not applied to
    ``__init__`` here. Confirm whether recording the constructor config is
    expected for this model.
    """

    def __init__(
        self,
        backbone_params=None,
        neck_params=None,
        head_params=None,
        random_split=False,
        **kwargs
    ):
        """Build the coordinate embedder, backbone, neck and head.

        Args:
            backbone_params: Keyword arguments for ``VisionTransformer_lora``.
                Must contain ``'img_size'``, ``'patch_size'`` and
                ``'embed_dim'``, which are reused for the coordinate patch
                embedding. Omitting it raises ``KeyError``.
            neck_params: Keyword arguments for ``SimpleFPN``.
            head_params: Keyword arguments for ``SwinTransfomerSegHead``.
            random_split: Stored and later forwarded to the backbone as
                ``shuffle`` in ``backbone_forward``.
            **kwargs: Passed through to ``ISModel.__init__``.
        """
        super().__init__(**kwargs)
        # None defaults replace the former shared mutable ``{}`` defaults.
        # A missing backbone_params still fails on the 'img_size' lookup below.
        backbone_params = {} if backbone_params is None else backbone_params
        neck_params = {} if neck_params is None else neck_params
        head_params = {} if head_params is None else head_params

        self.random_split = random_split
        # 3 input channels when ``self.with_prev_mask`` is truthy, else 2.
        # with_prev_mask is presumably set by ISModel; the channel meaning
        # (click maps + previous mask) is assumed -- TODO confirm.
        self.patch_embed_coords = PatchEmbed(
            img_size=backbone_params['img_size'],
            patch_size=backbone_params['patch_size'],
            in_chans=3 if self.with_prev_mask else 2,
            embed_dim=backbone_params['embed_dim'],
        )
        self.backbone = VisionTransformer_lora(**backbone_params)
        self.neck = SimpleFPN(**neck_params)
        self.head = SwinTransfomerSegHead(**head_params)

    def backbone_forward(self, image, coord_features=None, gra=None):
        """Run the backbone, neck and head on an image plus coordinate features.

        Args:
            image: Input passed to ``self.backbone.forward_backbone``.
            coord_features: Coordinate feature map embedded by
                ``patch_embed_coords``. It is passed through unchecked, so the
                ``None`` default presumably only works if ``PatchEmbed``
                accepts ``None`` -- TODO confirm; callers appear expected to
                supply it.
            gra: Forwarded to ``forward_backbone`` as ``gra=``. Its meaning is
                defined by the backbone, not here.

        Returns:
            dict with ``'instances'`` (the head's output on the four-level
            pyramid) and ``'instances_aux'`` (always ``None``).
        """
        coord_features = self.patch_embed_coords(coord_features)
        backbone_features = self.backbone.forward_backbone(
            image, coord_features, gra=gra, shuffle=self.random_split
        )

        # Token sequence (B, N, C) -> feature map (B, C, H, W).
        # N must equal grid_size[0] * grid_size[1] for the view to succeed.
        # ``view`` is valid on the non-contiguous transpose because it only
        # splits the last dimension.
        B, N, C = backbone_features.shape
        grid_size = self.backbone.patch_embed.grid_size
        backbone_features = backbone_features.transpose(-1, -2).view(
            B, C, grid_size[0], grid_size[1]
        )

        # Extract 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32
        multi_scale_features = self.neck(backbone_features)
        return {'instances': self.head(multi_scale_features), 'instances_aux': None}