# GraCo/isegm/model/is_text_graco_model.py
import torch.nn as nn
from isegm.utils.serialization import serialize
from .is_model import ISModel
from .is_plainvit_model import SimpleFPN
from .modeling.models_vit import VisionTransformer, PatchEmbed
from .modeling.twoway_transformer import TwoWayTransformer, PositionEmbeddingRandom
from .modeling.swin_transformer import SwinTransfomerSegHead
from .modeling.clip_text_encoding import ClipTextEncoder


class TextGraCoModel(ISModel):
    @serialize
    def __init__(
        self,
        image_encoder_params={},
        text_encoder_params={},
        cross_encoder_params={},
        neck_params={},
        head_params={},
        random_split=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.random_split = random_split

        # Patch-embed the click/coordinate maps (plus the previous mask,
        # when used) so they can be fused with the image patch tokens.
        self.patch_embed_coords = PatchEmbed(
            img_size=image_encoder_params['img_size'],
            patch_size=image_encoder_params['patch_size'],
            in_chans=3 if self.with_prev_mask else 2,
            embed_dim=image_encoder_params['embed_dim'],
        )

        self.image_encoder = VisionTransformer(**image_encoder_params)
        self.text_encoder = ClipTextEncoder(**text_encoder_params)
        self.cross_encoder = TwoWayTransformer(**cross_encoder_params)
        self.pe_layer = PositionEmbeddingRandom(cross_encoder_params["embedding_dim"] // 2)

        # Side length of the patch grid, e.g. 448 // 16 = 28.
        patch_size = image_encoder_params['patch_size'][0]
        self.image_embedding_size = image_encoder_params['img_size'][0] // (patch_size if patch_size > 0 else 1)

        self.neck = SimpleFPN(**neck_params)
        self.head = SwinTransfomerSegHead(**head_params)

    def backbone_forward(self, image, coord_features=None, text=None, gra=None):
        # Embed the click maps and run the ViT backbone with them fused in.
        coord_features = self.patch_embed_coords(coord_features)
        backbone_features = self.image_encoder.forward_backbone(
            image, coord_features, gra=gra, shuffle=self.random_split)

        # Fuse text and image tokens with the two-way cross-attention encoder,
        # using a random positional encoding over the patch grid.
        text_features = self.text_encoder(text)
        text_features, backbone_features = self.cross_encoder(
            backbone_features,
            self.pe_layer((self.image_embedding_size, self.image_embedding_size)).unsqueeze(0),
            text_features)

        # Extract 4-stage image_encoder feature maps: 1/4, 1/8, 1/16, 1/32.
        B, N, C = backbone_features.shape
        grid_size = self.image_encoder.patch_embed.grid_size
        backbone_features = backbone_features.transpose(-1, -2).view(B, C, grid_size[0], grid_size[1])
        multi_scale_features = self.neck(backbone_features)

        return {'instances': self.head(multi_scale_features), 'instances_aux': None}
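

# ---------------------------------------------------------------------------
# Minimal shape sketch (illustrative, not part of the original file): how
# backbone_forward turns the cross-encoder's (B, N, C) token sequence back
# into a 2D feature map for the SimpleFPN neck. The 448 input size, 16 patch
# size, and 768 embedding dim below are assumed values for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    B, img_size, patch_size, C = 2, 448, 16, 768   # assumed demo values
    grid = img_size // patch_size                  # 28 patches per side
    tokens = torch.randn(B, grid * grid, C)        # (B, N, C) token sequence
    fmap = tokens.transpose(-1, -2).view(B, C, grid, grid)
    assert fmap.shape == (B, C, grid, grid)        # (2, 768, 28, 28), ready for the neck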