import torch.nn as nn

from isegm.utils.serialization import serialize

from .is_model import ISModel
from .is_plainvit_model import SimpleFPN
from .modeling.models_vit import VisionTransformer, PatchEmbed
from .modeling.twoway_transformer import TwoWayTransformer, PositionEmbeddingRandom
from .modeling.swin_transformer import SwinTransfomerSegHead
from .modeling.clip_text_encoding import ClipTextEncoder
class TextGraCoModel(ISModel):
    """Interactive segmentation model with text guidance.

    Combines a ViT image encoder, a CLIP text encoder, and a two-way
    cross-attention transformer that fuses text tokens with image patch
    tokens, followed by a SimpleFPN neck and a Swin-style segmentation head.
    """

    def __init__(
        self,
        image_encoder_params=None,
        text_encoder_params=None,
        cross_encoder_params=None,
        neck_params=None,
        head_params=None,
        random_split=False,
        **kwargs
    ):
        """Build the model.

        Args:
            image_encoder_params: kwargs for ``VisionTransformer``; must
                contain ``img_size``, ``patch_size`` (both sequences) and
                ``embed_dim``.
            text_encoder_params: kwargs for ``ClipTextEncoder``.
            cross_encoder_params: kwargs for ``TwoWayTransformer``; must
                contain ``embedding_dim``.
            neck_params: kwargs for ``SimpleFPN``.
            head_params: kwargs for ``SwinTransfomerSegHead``.
            random_split: if True, patch tokens are shuffled inside the
                image-encoder backbone forward.
            **kwargs: forwarded to ``ISModel``.
        """
        super().__init__(**kwargs)
        # Fix for the mutable-default-argument anti-pattern: fall back to
        # fresh dicts instead of sharing one dict across all instances.
        # ``x or {}`` keeps behavior identical for callers passing ``{}``.
        image_encoder_params = image_encoder_params or {}
        text_encoder_params = text_encoder_params or {}
        cross_encoder_params = cross_encoder_params or {}
        neck_params = neck_params or {}
        head_params = head_params or {}

        self.random_split = random_split

        # Dedicated patch embedding for the click/coordinate maps;
        # 3 input channels when a previous mask is stacked in, else 2.
        self.patch_embed_coords = PatchEmbed(
            img_size=image_encoder_params['img_size'],
            patch_size=image_encoder_params['patch_size'],
            in_chans=3 if self.with_prev_mask else 2,
            embed_dim=image_encoder_params['embed_dim'],
        )

        self.image_encoder = VisionTransformer(**image_encoder_params)
        self.text_encoder = ClipTextEncoder(**text_encoder_params)
        self.cross_encoder = TwoWayTransformer(**cross_encoder_params)
        # Random positional encoding over half the embedding dim per axis.
        self.pe_layer = PositionEmbeddingRandom(cross_encoder_params["embedding_dim"] // 2)

        patch_size = image_encoder_params['patch_size'][0]
        # Side length of the patch-token grid; guard against a degenerate
        # patch_size of 0 to avoid ZeroDivisionError.
        self.image_embedding_size = image_encoder_params['img_size'][0] // (patch_size if patch_size > 0 else 1)

        self.neck = SimpleFPN(**neck_params)
        self.head = SwinTransfomerSegHead(**head_params)

    def backbone_forward(self, image, coord_features=None, text=None, gra=None):
        """Run the full backbone: encoders, cross-attention fusion, neck, head.

        Args:
            image: input image batch.
            coord_features: click/coordinate feature maps, embedded via
                ``patch_embed_coords`` and added inside the image encoder.
            text: tokenized text prompt for the CLIP text encoder.
            gra: granularity control passed through to the image encoder.

        Returns:
            dict with ``'instances'`` (head output) and ``'instances_aux'``
            (always ``None`` here).
        """
        coord_features = self.patch_embed_coords(coord_features)
        backbone_features = self.image_encoder.forward_backbone(
            image, coord_features, gra=gra, shuffle=self.random_split)

        text_features = self.text_encoder(text)
        # Two-way transformer fuses text tokens with image tokens; the image
        # tokens get a dense positional encoding over the patch grid.
        text_features, backbone_features = self.cross_encoder(
            backbone_features,
            self.pe_layer((self.image_embedding_size, self.image_embedding_size)).unsqueeze(0),
            text_features)

        # Reshape (B, N, C) token sequence back to a (B, C, H, W) feature map
        # so the FPN neck can extract 1/4, 1/8, 1/16, 1/32 scale features.
        B, N, C = backbone_features.shape
        grid_size = self.image_encoder.patch_embed.grid_size
        backbone_features = backbone_features.transpose(-1, -2).view(B, C, grid_size[0], grid_size[1])

        multi_scale_features = self.neck(backbone_features)

        return {'instances': self.head(multi_scale_features), 'instances_aux': None}