import torch.nn as nn
from isegm.utils.serialization import serialize
from .is_model import ISModel
from .is_plainvit_model import SimpleFPN
from .modeling.models_vit import VisionTransformer, PatchEmbed
from .modeling.twoway_transformer import TwoWayTransformer, PositionEmbeddingRandom
from .modeling.swin_transformer import SwinTransfomerSegHead
from .modeling.clip_text_encoding import ClipTextEncoder


class TextGraCoModel(ISModel):
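    """Text-guided interactive segmentation model.

    Combines a ViT image encoder, a CLIP text encoder, and a two-way
    cross-attention transformer, followed by a SimpleFPN neck and a
    Swin-style segmentation head that predicts the instance mask.
    """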
    @serialize
    def __init__(
        self,
        image_encoder_params={},
        text_encoder_params={},
        cross_encoder_params={},
        neck_params={},
        head_params={},
        random_split=False,
        **kwargs
        ):

        super().__init__(**kwargs)
        self.random_split = random_split

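        # Embed the click / previous-mask coordinate maps into patch tokens
        # (3 input channels when a previous mask is used, 2 otherwise).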
        self.patch_embed_coords = PatchEmbed(
            img_size=image_encoder_params['img_size'],
            patch_size=image_encoder_params['patch_size'], 
            in_chans=3 if self.with_prev_mask else 2, 
            embed_dim=image_encoder_params['embed_dim'],
        )

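        # Image, text, and cross-modal encoders.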
        self.image_encoder = VisionTransformer(**image_encoder_params)
        self.text_encoder = ClipTextEncoder(**text_encoder_params)
        self.cross_encoder = TwoWayTransformer(**cross_encoder_params)

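        # Random positional encoding for the cross-encoder;
        # image_embedding_size is the side length of the patch-token grid.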
        self.pe_layer = PositionEmbeddingRandom(cross_encoder_params["embedding_dim"] // 2)
        patch_size = image_encoder_params['patch_size'][0]
        self.image_embedding_size = image_encoder_params['img_size'][0] // (patch_size if patch_size > 0 else 1)

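        # Multi-scale feature neck and segmentation head.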
        self.neck = SimpleFPN(**neck_params)
        self.head = SwinTransfomerSegHead(**head_params)

    def backbone_forward(self, image, coord_features=None, text=None, gra=None):
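        """Encode the image, clicks, and text, fuse them, and return mask logits."""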
        coord_features = self.patch_embed_coords(coord_features)
        backbone_features = self.image_encoder.forward_backbone(image, coord_features, gra=gra, shuffle=self.random_split)
        text_features = self.text_encoder(text)

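        # Fuse image patch tokens and text tokens with two-way cross attention,
        # using dense positional encodings over the patch-token grid.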
        text_features, backbone_features = self.cross_encoder(
            backbone_features, 
            self.pe_layer((self.image_embedding_size, self.image_embedding_size)).unsqueeze(0),
            text_features)

        # Reshape the token sequence back into a 2D feature map, then build the
        # 4-stage feature pyramid (1/4, 1/8, 1/16, 1/32 scale) with the neck.
        B, N, C = backbone_features.shape
        grid_size = self.image_encoder.patch_embed.grid_size

        backbone_features = backbone_features.transpose(-1,-2).view(B, C, grid_size[0], grid_size[1])
        multi_scale_features = self.neck(backbone_features)

        return {'instances': self.head(multi_scale_features), 'instances_aux': None}