# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Image decoder."""

try:
    from flash_attn import flash_attn_func
except ImportError:
    flash_attn_func = None
import torch
from torch import nn


class TransposedLayerNorm(nn.LayerNorm):
    """LayerNorm with pre-transposed spatial axes."""

    def forward(self, input):
        """Normalize a 4D input over axis 1.

        Axis 1 is moved to the last position, ``nn.LayerNorm`` is applied,
        and the original axis order is restored.
        """
        return super().forward(input.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class MLP(nn.Module):
    """Two layers MLP."""

    def __init__(self, dim, mlp_dim, activation_type="ReLU"):
        """Build ``dim -> mlp_dim -> dim`` linear layers.

        Args:
            dim: Input and output feature size.
            mlp_dim: Hidden feature size.
            activation_type: Name of the ``torch.nn`` activation class
                instantiated (without arguments) between the two layers.
        """
        super().__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)
        self.activation = getattr(nn, activation_type)()
        # Request in-place activation; activations that do not read an
        # ``inplace`` attribute simply ignore it.
        self.activation.inplace = True

    def forward(self, x):
        """Return ``fc2(activation(fc1(x)))``."""
        return self.fc2(self.activation(self.fc1(x)))


class Attention(nn.Module):
    """Multi-head attention."""

    def __init__(self, dim=256, num_heads=8, attn_ratio=1):
        """Build the q/k/v and output projections.

        Args:
            dim: Feature size of the inputs and of the output.
            num_heads: Number of heads; a falsy value selects ``dim // 64``.
            attn_ratio: Scales the total inner width, so each head has
                ``int(dim * attn_ratio) // num_heads`` channels.
        """
        super().__init__()
        self.num_heads = num_heads or dim // 64
        self.head_dim = int(dim * attn_ratio) // self.num_heads
        self.q_proj = nn.Linear(dim, self.num_heads * self.head_dim)
        self.k_proj = nn.Linear(dim, self.num_heads * self.head_dim)
        self.v_proj = nn.Linear(dim, self.num_heads * self.head_dim)
        self.proj = nn.Linear(self.num_heads * self.head_dim, dim)
        self.scale = self.head_dim**-0.5

    def forward(self, q, k, v):
        """Attend from ``q`` to ``k``/``v`` and project back to ``dim``.

        Inputs are projected and viewed as ``(batch, length, heads, head_dim)``.
        The flash-attn kernel is used when it is importable and the projected
        tensors are CUDA half/bfloat16 tensors; otherwise PyTorch's
        ``scaled_dot_product_attention`` is used so the module also works
        without flash-attn, on CPU, and in float32.

        Returns:
            Tensor with the leading two axes of ``q`` and a last axis of
            size ``dim``.
        """
        q = self.q_proj(q).view((-1, q.size(1), self.num_heads, self.head_dim))
        k = self.k_proj(k).view((-1, k.size(1), self.num_heads, self.head_dim))
        v = self.v_proj(v).view((-1, v.size(1), self.num_heads, self.head_dim))
        use_flash = (
            flash_attn_func is not None
            and q.is_cuda
            and q.dtype in (torch.float16, torch.bfloat16)
        )
        if use_flash:
            o = flash_attn_func(q, k, v, softmax_scale=self.scale)
        else:
            # SDPA expects (batch, heads, length, head_dim). Its default
            # softmax scale is head_dim ** -0.5, identical to ``self.scale``.
            q, k, v = (x.transpose(1, 2) for x in (q, k, v))
            o = nn.functional.scaled_dot_product_attention(q, k, v)
            o = o.transpose(1, 2)
        return self.proj(o.flatten(2))


class Block(nn.Module):
    """Transformer block."""

    def __init__(
        self,
        dim=256,
        num_heads=8,
        attn_ratio=0.5,
        mlp_dim=2048,
        activation_type="ReLU",
        skip_first_query_pos=False,
    ):
        """Build the self-attention, two cross-attentions and the MLP.

        Args:
            dim: Feature size of query and key.
            num_heads: Number of attention heads.
            attn_ratio: Inner-width ratio of the two cross-attentions
                (the self-attention keeps the ``Attention`` default).
            mlp_dim: Hidden size of the MLP.
            activation_type: ``torch.nn`` activation class name for the MLP.
            skip_first_query_pos: If True, the self-attention step uses
                ``query`` without ``query_pos``, without dropout and
                without a residual connection.
        """
        super().__init__()
        self.self_attn = Attention(dim, num_heads)
        self.norm1 = nn.LayerNorm(dim)
        self.cross_attn_token_to_image = Attention(dim, num_heads, attn_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_dim, activation_type)
        self.norm3 = nn.LayerNorm(dim)
        self.cross_attn_image_to_token = Attention(dim, num_heads, attn_ratio)
        self.norm4 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(0.1, inplace=True)
        self.skip_first_query_pos = skip_first_query_pos

    def forward(self, query, key, query_pos, key_pos):
        """Update ``query`` and ``key`` with one round of two-way attention.

        Order of operations: query self-attention, query-to-key
        cross-attention, MLP on the query, then key-to-query cross-attention.
        Each step is followed by a LayerNorm. The in-place ``add_`` calls
        write into freshly computed attention/MLP outputs, not into the
        arguments.

        Returns:
            Tuple ``(query, key)`` of updated tensors.
        """
        if self.skip_first_query_pos:
            query = self.norm1(self.self_attn(query, query, query))
        else:
            q = query + query_pos
            query = self.norm1(self.dropout(self.self_attn(q, q, query)).add_(query))
        q, k = query + query_pos, key + key_pos
        query = self.norm2(self.dropout(self.cross_attn_token_to_image(q, k, key)).add_(query))
        query = self.norm3(self.dropout(self.mlp(query)).add_(query))
        # The key attends to the refreshed query; no dropout on this branch.
        key = self.norm4(self.cross_attn_image_to_token(k, query + query_pos, query).add_(key))
        return query, key


class Transformer(nn.Module):
    """Two-way transformer decoder."""

    def __init__(
        self,
        embed_dim=256,
        num_heads=8,
        attn_ratio=0.5,
        mlp_dim=2048,
        activation_type="ReLU",
        depth=2,
    ):
        """Stack ``depth`` blocks followed by a final query-to-key attention.

        Only the first block is built with ``skip_first_query_pos=True``.
        """
        super().__init__()
        self.blocks = nn.ModuleList(
            Block(
                embed_dim,
                num_heads,
                attn_ratio=attn_ratio,
                mlp_dim=mlp_dim,
                activation_type=activation_type,
                skip_first_query_pos=i == 0,
            )
            for i in range(depth)
        )
        self.final_attn_token_to_image = Attention(embed_dim, num_heads, attn_ratio)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.1, inplace=True)

    def forward(self, query, key, query_pos, key_pos):
        """Run all blocks, then a final query-to-key attention on the query.

        Returns:
            Tuple ``(query, key)`` of decoded tensors.
        """
        for blk in self.blocks:
            query, key = blk(query, key, query_pos, key_pos)
        q, k = query + query_pos, key + key_pos
        query = self.norm(self.dropout(self.final_attn_token_to_image(q, k, key)).add_(query))
        return query, key


class Predictor(nn.Module):
    """MLP predictor."""

    def __init__(self, in_dim, out_dim, mlp_dim=None, depth=3):
        """Build ``depth`` linear layers mapping ``in_dim`` to ``out_dim``.

        Args:
            in_dim: Input feature size.
            out_dim: Output feature size.
            mlp_dim: Hidden feature size; falls back to ``in_dim`` if falsy.
            depth: Total number of linear layers.
        """
        super().__init__()
        mlp_dims = [mlp_dim or in_dim] * (depth - 1)
        in_dims, out_dims = [in_dim] + mlp_dims, mlp_dims + [out_dim]
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip(in_dims, out_dims))

    def forward(self, x):
        """Apply every layer, with ReLU after all but the last one."""
        for fc in self.layers[:-1]:
            x = nn.functional.relu(fc(x), inplace=True)
        return self.layers[-1](x)


class ImageDecoder(nn.Module):
    """Module to decode region tokens and masks."""

    def __init__(self, depth, embed_dim, num_heads, num_mask_tokens=4, sem_embed_dim=1024):
        """Build the transformer, learned tokens, upscaler and prediction heads.

        Args:
            depth: Number of transformer blocks.
            embed_dim: Feature size of tokens and image embeddings.
            num_heads: Number of attention heads.
            num_mask_tokens: Number of mask tokens (and of semantic tokens).
            sem_embed_dim: Output size of the semantic prediction head.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_mask_tokens = num_mask_tokens
        self.transformer = Transformer(embed_dim, num_heads, depth=depth)
        self.iou_token = nn.Embedding(1, embed_dim)
        self.sem_tokens = nn.Embedding(num_mask_tokens, embed_dim)
        self.mask_tokens = nn.Embedding(num_mask_tokens, embed_dim)
        # Two stride-2 transposed convolutions: 4x spatial upscaling overall.
        self.output_conv = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim // 4, 2, 2),
            TransposedLayerNorm(embed_dim // 4),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dim // 4, embed_dim // 8, 2, 2),
            nn.GELU(),
        )
        self.mask_pred = nn.ModuleList(
            Predictor(embed_dim, embed_dim // 8) for _ in range(num_mask_tokens)
        )
        self.iou_pred = Predictor(embed_dim, num_mask_tokens)
        self.sem_pred = Predictor(embed_dim, sem_embed_dim, sem_embed_dim)

    def get_outputs(self, inputs):
        """Decode masks, IoU scores and semantic tokens for each prompt.

        Args:
            inputs: Mapping read for the keys ``"img_embeds"``,
                ``"sparse_embeds"`` and ``"img_pos"``.
                NOTE(review): the indexing below implies ``img_embeds`` is 5D
                with spatial axes at positions 2..-2 and channels last, with
                axis 1 expandable to prompts-per-image (presumably size 1),
                and ``sparse_embeds`` is (prompts, points, embed_dim) --
                confirm against the caller.

        Returns:
            Dict with ``"iou_pred"``, ``"mask_pred"``, ``"sem_tokens"`` and
            ``"sem_embeds"``.
        """
        img_embeds = inputs["img_embeds"]
        sparse_embeds = inputs["sparse_embeds"]
        ims_per_batch = img_embeds.size(0)
        prompts_per_batch = sparse_embeds.size(0)
        img_embed_size = img_embeds.shape[2:-1]
        # Prepare query.
        tokens = [self.sem_tokens.weight, self.iou_token.weight, self.mask_tokens.weight]
        query = torch.cat(tokens).unsqueeze_(0).expand(prompts_per_batch, -1, -1)
        query = torch.cat((query, sparse_embeds), dim=1)
        # Count of learned (non-prompt) tokens at the front of the query.
        num_tokens = query.shape[1] - sparse_embeds.shape[1]
        # Prepare key.
        key = img_embeds.expand(-1, prompts_per_batch // ims_per_batch, -1, -1, -1)
        key = key.flatten(0, 1).flatten(1, 2)
        # Decode. The initial query doubles as its own positional term.
        query, key = self.transformer(query, key, query, inputs["img_pos"])
        # Upscale key.
        key = key.transpose(1, 2).view((-1, self.embed_dim) + img_embed_size)
        mask_embeds = self.output_conv(key).flatten(2)
        # Unpack query.
        sem_tokens = query[:, : self.num_mask_tokens]
        sam_tokens = query[:, self.num_mask_tokens : num_tokens].unbind(1)
        iou_tokens, mask_tokens = sam_tokens[0], sam_tokens[1:]
        # Predict.
        mask_pred = [f(x) for f, x in zip(self.mask_pred, mask_tokens)]
        mask_pred = torch.stack(mask_pred, dim=1) @ mask_embeds
        # ``output_conv`` upscales each spatial axis by 4.
        mask_pred_size = [4 * embed_size for embed_size in img_embed_size]
        mask_pred = mask_pred.view([-1, self.num_mask_tokens] + mask_pred_size)
        outputs = {"iou_pred": self.iou_pred(iou_tokens), "mask_pred": mask_pred}
        outputs["sem_tokens"] = sem_tokens.unsqueeze_(2)
        outputs["sem_embeds"] = self.sem_pred(outputs["sem_tokens"].flatten(2))
        return outputs

    def forward(self, inputs):
        """Run ``get_outputs`` and cast ``"iou_pred"`` to float32."""
        outputs = self.get_outputs(inputs)
        outputs["iou_pred"] = outputs["iou_pred"].float()
        return outputs