import torch.nn as nn
from isegm.utils.serialization import serialize
from .is_model import ISModel
from .is_plainvit_model import SimpleFPN
from .modeling.models_vit import VisionTransformer, PatchEmbed
from .modeling.twoway_transformer import TwoWayTransformer, PositionEmbeddingRandom
from .modeling.swin_transformer import SwinTransfomerSegHead
from .modeling.clip_text_encoding import ClipTextEncoder


class TextGraCoModel(ISModel):
    """Interactive segmentation model combining a ViT image encoder, a CLIP
    text encoder, and a two-way cross-attention transformer that fuses the
    two modalities before a SimpleFPN neck and a Swin-style segmentation head.
    """

    @serialize
    def __init__(
        self,
        image_encoder_params=None,
        text_encoder_params=None,
        cross_encoder_params=None,
        neck_params=None,
        head_params=None,
        random_split=False,
        **kwargs
    ):
        """Build the encoders, cross-modal fusion module, neck, and head.

        Args:
            image_encoder_params: kwargs for ``VisionTransformer``; must
                contain ``'img_size'``, ``'patch_size'`` (both indexable,
                e.g. tuples) and ``'embed_dim'``.
            text_encoder_params: kwargs for ``ClipTextEncoder``.
            cross_encoder_params: kwargs for ``TwoWayTransformer``; must
                contain ``'embedding_dim'``.
            neck_params: kwargs for ``SimpleFPN``.
            head_params: kwargs for ``SwinTransfomerSegHead``.
            random_split: if True, the backbone shuffles patch tokens
                (forwarded as ``shuffle=`` to the image encoder).
            **kwargs: forwarded to ``ISModel.__init__``.
        """
        super().__init__(**kwargs)

        # Avoid mutable default arguments: a shared ``{}`` default would be
        # reused across every instantiation. ``None`` sentinels keep the
        # original empty-dict behavior for callers that omit a group.
        image_encoder_params = {} if image_encoder_params is None else image_encoder_params
        text_encoder_params = {} if text_encoder_params is None else text_encoder_params
        cross_encoder_params = {} if cross_encoder_params is None else cross_encoder_params
        neck_params = {} if neck_params is None else neck_params
        head_params = {} if head_params is None else head_params

        self.random_split = random_split

        # Separate patch embedding for the click/previous-mask maps; channel
        # count depends on whether the previous mask is fed in (ISModel flag).
        self.patch_embed_coords = PatchEmbed(
            img_size=image_encoder_params['img_size'],
            patch_size=image_encoder_params['patch_size'],
            in_chans=3 if self.with_prev_mask else 2,
            embed_dim=image_encoder_params['embed_dim'],
        )

        self.image_encoder = VisionTransformer(**image_encoder_params)
        self.text_encoder = ClipTextEncoder(**text_encoder_params)
        self.cross_encoder = TwoWayTransformer(**cross_encoder_params)
        # Random positional encoding over half the embedding dim per axis.
        self.pe_layer = PositionEmbeddingRandom(cross_encoder_params["embedding_dim"] // 2)

        # Side length of the token grid (assumes square input/patches —
        # only index [0] of img_size/patch_size is used).
        patch_size = image_encoder_params['patch_size'][0]
        # Guard against a zero patch size; falls back to divisor 1.
        self.image_embedding_size = image_encoder_params['img_size'][0] // (patch_size if patch_size > 0 else 1)

        self.neck = SimpleFPN(**neck_params)
        self.head = SwinTransfomerSegHead(**head_params)

    def backbone_forward(self, image, coord_features=None, text=None, gra=None):
        """Run the fused image/text backbone and produce segmentation logits.

        Args:
            image: input image batch for the ViT encoder.
            coord_features: click/prev-mask maps, embedded by
                ``patch_embed_coords`` and injected into the backbone.
            text: input to ``ClipTextEncoder`` (token ids — TODO confirm
                against the encoder's expected format).
            gra: granularity signal forwarded to the image encoder.

        Returns:
            dict with ``'instances'`` (head logits) and ``'instances_aux'``
            (always ``None`` — no auxiliary supervision here).
        """
        coord_features = self.patch_embed_coords(coord_features)
        backbone_features = self.image_encoder.forward_backbone(
            image, coord_features, gra=gra, shuffle=self.random_split)

        # Cross-modal fusion: image tokens attend to text tokens and vice
        # versa. The dense positional encoding matches the token grid.
        text_features = self.text_encoder(text)
        text_features, backbone_features = self.cross_encoder(
            backbone_features,
            self.pe_layer((self.image_embedding_size, self.image_embedding_size)).unsqueeze(0),
            text_features)

        # Extract 4 stage image_encoder feature map: 1/4, 1/8, 1/16, 1/32
        # Tokens (B, N, C) are folded back into a 2D map (B, C, H, W) using
        # the patch-embed grid, then the FPN neck builds the pyramid.
        B, N, C = backbone_features.shape
        grid_size = self.image_encoder.patch_embed.grid_size
        backbone_features = backbone_features.transpose(-1, -2).view(B, C, grid_size[0], grid_size[1])
        multi_scale_features = self.neck(backbone_features)

        return {'instances': self.head(multi_scale_features), 'instances_aux': None}