import math
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn


# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
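

# ---------------------------------------------------------------------------
# Quick shape check (illustrative only): the embedding width of 16 and the 4x4
# grid below are assumed example values, not taken from any model config.
#   get_2d_sincos_pos_embed(16, 4).shape                  == (16, 16)
#   get_2d_sincos_pos_embed(16, 4, cls_token=True).shape  == (17, 16)
# ---------------------------------------------------------------------------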


class CrossAttention(nn.Module):
    """Multi-head cross-attention from queries to a single set of vision latents."""

    def __init__(self, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Sequential(
            nn.LayerNorm(q_dim),
            nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.k_proj = nn.Sequential(
            nn.LayerNorm(kv_dim),
            nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.v_proj = nn.Sequential(
            nn.LayerNorm(kv_dim),
            nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, q_dim, bias=attention_bias
        )

    def forward(self, vision_latents, queries, attention_mask):
        bsz, q_len, _ = queries.size()
        bsz, v_len, _ = vision_latents.size()

        query_states = self.q_proj(queries)
        key_states = self.k_proj(vision_latents)
        value_states = self.v_proj(vision_latents)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, v_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
                )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output


class AggregationBlock(nn.Module):
    """Aggregate vision latents into queries via cross-attention or an MLP."""

    def __init__(
        self, attention, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.attention = attention
        if attention:
            self.attention_layer = CrossAttention(
                q_dim, kv_dim, hidden_dim, num_heads, attention_bias
            )
        else:
            self.attention_layer = MLP(kv_dim, q_dim, q_dim)

    def forward(self, vision_latents, queries, attention_mask):
        if self.attention:
            queries = self.attention_layer(vision_latents, queries, attention_mask)
        else:
            queries = self.attention_layer(vision_latents)

        return queries


class MultiKVCrossAttention(nn.Module):
    """Cross-attention over keys/values concatenated from several vision sources."""

    def __init__(self, q_dim, kv_dim_list, hidden_dim, num_heads, attention_bias=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Sequential(
            nn.LayerNorm(q_dim),
            nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.num_of_kvs = len(kv_dim_list)
        for i, kv_dim in enumerate(kv_dim_list):
            setattr(
                self,
                "k_proj_{}".format(i),
                nn.Sequential(
                    nn.LayerNorm(kv_dim),
                    nn.Linear(
                        kv_dim, self.num_heads * self.head_dim, bias=attention_bias
                    ),
                ),
            )
            setattr(
                self,
                "v_proj_{}".format(i),
                nn.Sequential(
                    nn.LayerNorm(kv_dim),
                    nn.Linear(
                        kv_dim, self.num_heads * self.head_dim, bias=attention_bias
                    ),
                ),
            )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, q_dim, bias=attention_bias
        )

    def forward(
        self,
        queries,
        *vision_latents_attention_mask_list,
    ):
        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        bsz, q_len, _ = queries.size()

        query_states = self.q_proj(queries)
        key_states = torch.cat(
            [
                getattr(self, "k_proj_{}".format(i))(vision_latents_list[i])
                for i in range(self.num_of_kvs)
            ],
            dim=1,
        )
        value_states = torch.cat(
            [
                getattr(self, "v_proj_{}".format(i))(vision_latents_list[i])
                for i in range(self.num_of_kvs)
            ],
            dim=1,
        )

        v_len = key_states.shape[1]

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # if kv_weight is not None:
        #     kv_weight = kv_weight.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

        attention_mask = torch.cat(attention_mask_list, dim=-1)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, v_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
                )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )
        # attn_output = spda(
        #     query_states,
        #     key_states,
        #     value_states,
        #     attn_mask=attention_mask,
        #     additional_score=kv_weight
        # )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output


class MLP(nn.Module):
    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.linear_1 = nn.Linear(d_in, d_hidden, bias=False)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(d_hidden, d_out, bias=False)

    def forward(self, x):
        return self.linear_2(self.act(self.linear_1(x)))


class VisionCrossAttentionLayer(nn.Module):
    """Sampler layer where queries cross-attend to all vision sources jointly."""

    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        hidden_dim=1024,
        layer_idx=0,
    ):
        super().__init__()
        num_heads = 16
        self.num_of_kvs = len(kv_dim_list)

        self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
        self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
        # if self.num_of_kvs > 1:
        #     self.weight_mlp = MLP(q_dim+hidden_dim, hidden_dim, self.num_of_kvs)
        #     self.tower_weight = nn.Parameter(torch.zeros((self.num_of_kvs)))
        self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)

        self.norm = nn.LayerNorm(hidden_dim)

        self.cross_attn = MultiKVCrossAttention(
            hidden_dim, kv_dim_list, hidden_dim, num_heads
        )
        self.kv_size_list = kv_size_list
        for i, kv_size in enumerate(kv_size_list):
            if kv_size > 1:
                setattr(
                    self,
                    "pos_embed_{}".format(i),
                    nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
                )
                # self.register_buffer("pos_embed_{}".format(i), torch.from_numpy(get_2d_sincos_pos_embed(hidden_dim, kv_size)).float(), persistent=False)

    def forward(
        self,
        queries,
        context_feature,
        *vision_latents_attention_mask_list,
    ) -> torch.FloatTensor:
        residual = queries
        # queries = self.proj_in(queries)
        context_feature = self.proj_context(context_feature)
        # queries = queries + context_feature
        queries = torch.cat([queries, context_feature], -1)

        # if self.num_of_kvs > 1:
        #     kv_weight = self.weight_mlp(queries)  # B * 1 * num_tower
        #     kv_weight = kv_weight + self.tower_weight.view(1, 1, -1)
        #     kv_weight = kv_weight.softmax(-1)
        #     kv_number_list = [size**2 for size in self.kv_size_list]
        #     kv_weight = torch.repeat_interleave(kv_weight, torch.tensor(kv_number_list).to(kv_weight.device), dim=-1)
        # else:
        #     kv_weight = None

        queries = self.proj_in(queries)

        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        attention_mask_list_reshaped = []
        if attention_mask_list is not None:
            for attention_mask in attention_mask_list:
                attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
                attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
                attention_mask_list_reshaped.append(attention_mask)

        vision_latents_pos_list = []
        for i, vision_latents in enumerate(vision_latents_list):
            if vision_latents.shape[1] > 1:
                vision_latents_pos_list.append(
                    vision_latents
                    + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
                        vision_latents.dtype
                    )
                )
            else:
                vision_latents_pos_list.append(vision_latents)

        # Cross Attention
        attention_output = self.cross_attn(
            queries, *vision_latents_pos_list, *attention_mask_list_reshaped
        )

        # attention_output = (attention_output * combination_weight).sum(2)
        queries = queries + attention_output

        queries = self.norm(queries)

        queries = self.proj_out(queries)

        queries = queries + residual

        return queries


class VisionAggregationLayer(nn.Module):
    """Sampler layer aggregating each vision source separately before mixing."""

    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        hidden_dim=1024,
        layer_idx=0,
    ):
        super().__init__()
        num_heads = 16
        self.num_of_kvs = len(kv_dim_list)

        self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
        self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)

        self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)

        self.norm = nn.LayerNorm(hidden_dim)

        if self.num_of_kvs > 1:
            self.weight_mlp = MLP(q_dim + hidden_dim, hidden_dim, self.num_of_kvs)

        for i, kv_size in enumerate(kv_size_list):
            if kv_size > 1:
                setattr(
                    self,
                    "pos_embed_{}".format(i),
                    nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
                )
                setattr(
                    self,
                    "aggregate_{}".format(i),
                    AggregationBlock(
                        True, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
                    ),
                )
            else:
                setattr(
                    self,
                    "aggregate_{}".format(i),
                    AggregationBlock(
                        False, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
                    ),
                )

    def forward(
        self,
        queries,
        context_feature,
        *vision_latents_attention_mask_list,
    ) -> torch.FloatTensor:
        residual = queries
        # queries = self.proj_in(queries)
        context_feature = self.proj_context(context_feature)
        # queries = queries + context_feature
        queries = torch.cat([queries, context_feature], -1)

        if self.num_of_kvs > 1:
            # (B, q_len, num_of_kvs): per-query softmax weights over the vision sources
            combination_weight = self.weight_mlp(queries).softmax(-1)
            combination_weight = combination_weight.unsqueeze(-1)
        else:
            combination_weight = 1

        queries = self.proj_in(queries)

        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        attention_mask_list_reshaped = []
        if attention_mask_list is not None:
            for attention_mask in attention_mask_list:
                attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
                attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
                attention_mask_list_reshaped.append(attention_mask)

        vision_latents_pos_list = []
        for i, vision_latents in enumerate(vision_latents_list):
            if vision_latents.shape[1] > 1:
                vision_latents_pos_list.append(
                    vision_latents
                    + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
                        vision_latents.dtype
                    )
                )
            else:
                vision_latents_pos_list.append(vision_latents)

        aggregated_vision_latents_list = []
        for i, (vision_latents, attention_mask) in enumerate(
            zip(vision_latents_pos_list, attention_mask_list_reshaped)
        ):
            aggregated_vision_latents_list.append(
                getattr(self, "aggregate_{}".format(i))(
                    vision_latents, queries, attention_mask
                )
            )

        aggregated_vision_latents = torch.stack(aggregated_vision_latents_list, 2)

        queries = queries + (aggregated_vision_latents * combination_weight).sum(2)

        queries = self.norm(queries)

        queries = self.proj_out(queries)

        queries = queries + residual

        return queries


class VisionTokenSampler(nn.Module):
    """Stack of 'joint' or 'sep' vision token sampling layers."""

    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        vision_hidden_size,
        num_of_layers=1,
        layer_type="joint",
    ):
        super().__init__()
        assert layer_type in ["joint", "sep"]
        if layer_type == "joint":
            self.layers = nn.ModuleList(
                [
                    VisionCrossAttentionLayer(
                        q_dim,
                        context_dim,
                        kv_dim_list,
                        kv_size_list,
                        vision_hidden_size,
                        idx,
                    )
                    for idx in range(num_of_layers)
                ]
            )
        else:
            self.layers = nn.ModuleList(
                [
                    VisionAggregationLayer(
                        q_dim,
                        context_dim,
                        kv_dim_list,
                        kv_size_list,
                        vision_hidden_size,
                        idx,
                    )
                    for idx in range(num_of_layers)
                ]
            )

    def forward(self, queries, context_feature, *vision_latents_attention_mask_list):
        for layer in self.layers:
            queries = layer(
                queries, context_feature, *vision_latents_attention_mask_list
            )
        return queries
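

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch only). The batch size, token counts,
# dimensions, grid sizes, and all-True attention masks below are assumptions
# made up for a quick shape check, not values from any model configuration.
# Because the learnable pos_embed_{i} parameters (of width vision_hidden_size)
# are added directly to the raw vision latents, this sketch also assumes each
# entry of kv_dim_list equals vision_hidden_size.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    bsz, q_len = 2, 16
    q_dim, context_dim, vision_hidden_size = 64, 64, 32
    kv_size_list = [4, 1]  # a 4x4 latent grid plus a single global token
    kv_dim_list = [vision_hidden_size] * len(kv_size_list)

    sampler = VisionTokenSampler(
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        vision_hidden_size,
        num_of_layers=1,
        layer_type="joint",
    )

    queries = torch.randn(bsz, q_len, q_dim)
    context_feature = torch.randn(bsz, q_len, context_dim)
    vision_latents = [
        torch.randn(bsz, size**2, dim) for size, dim in zip(kv_size_list, kv_dim_list)
    ]
    # Boolean masks, one per vision source: True marks positions that may be attended.
    masks = [torch.ones(bsz, size**2, dtype=torch.bool) for size in kv_size_list]

    out = sampler(queries, context_feature, *vision_latents, *masks)
    print(out.shape)  # expected: torch.Size([2, 16, 64])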