import math

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of the dimensions to encode grid_h, the other half for grid_w
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
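

# Usage sketch (hypothetical numbers, not taken from any config in this repo):
# for a 24x24 patch grid with 1024-dim embeddings,
#   get_2d_sincos_pos_embed(1024, 24).shape        -> (576, 1024)
#   get_2d_sincos_pos_embed(1024, 24, True).shape  -> (577, 1024)  # extra all-zero cls row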


class CrossAttention(nn.Module):

    def __init__(self, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Sequential(
            nn.LayerNorm(q_dim),
            nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.k_proj = nn.Sequential(
            nn.LayerNorm(kv_dim),
            nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.v_proj = nn.Sequential(
            nn.LayerNorm(kv_dim),
            nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, q_dim, bias=attention_bias
        )

    def forward(self, vision_latents, queries, attention_mask):

        bsz, q_len, _ = queries.size()
        bsz, v_len, _ = vision_latents.size()

        query_states = self.q_proj(queries)
        key_states = self.k_proj(vision_latents)
        value_states = self.v_proj(vision_latents)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, v_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
                )

        # SDPA's memory-efficient backend requires contiguous inputs on CUDA when a
        # custom attn_mask is passed.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output
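

# Usage sketch (hypothetical dimensions): queries of width q_dim attend over vision
# latents of width kv_dim and come back with width q_dim.
#   attn = CrossAttention(q_dim=1024, kv_dim=768, hidden_dim=1024, num_heads=16)
#   vision_latents = torch.randn(2, 576, 768)   # (bsz, v_len, kv_dim)
#   queries = torch.randn(2, 144, 1024)         # (bsz, q_len, q_dim)
#   out = attn(vision_latents, queries, attention_mask=None)   # -> (2, 144, 1024)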


class AggregationBlock(nn.Module):
    def __init__(
        self, attention, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.attention = attention
        if attention:
            self.attention_layer = CrossAttention(
                q_dim, kv_dim, hidden_dim, num_heads, attention_bias
            )
        else:
            self.attention_layer = MLP(kv_dim, q_dim, q_dim)

    def forward(self, vision_latents, queries, attention_mask):
        if self.attention:
            queries = self.attention_layer(vision_latents, queries, attention_mask)
        else:
            queries = self.attention_layer(vision_latents)

        return queries
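

# Usage sketch (hypothetical dimensions): with attention=True the block wraps
# CrossAttention; with attention=False it just runs the vision latents through an MLP
# and ignores the queries and mask.
#   block = AggregationBlock(True, q_dim=1024, kv_dim=768, hidden_dim=1024, num_heads=16)
#   out = block(torch.randn(2, 576, 768), torch.randn(2, 144, 1024), None)   # -> (2, 144, 1024)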


class MultiKVCrossAttention(nn.Module):

    def __init__(self, q_dim, kv_dim_list, hidden_dim, num_heads, attention_bias=False):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Sequential(
            nn.LayerNorm(q_dim),
            nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        # One key/value projection per vision source, registered as k_proj_i / v_proj_i.
        self.num_of_kvs = len(kv_dim_list)
        for i, kv_dim in enumerate(kv_dim_list):
            setattr(
                self,
                "k_proj_{}".format(i),
                nn.Sequential(
                    nn.LayerNorm(kv_dim),
                    nn.Linear(
                        kv_dim, self.num_heads * self.head_dim, bias=attention_bias
                    ),
                ),
            )
            setattr(
                self,
                "v_proj_{}".format(i),
                nn.Sequential(
                    nn.LayerNorm(kv_dim),
                    nn.Linear(
                        kv_dim, self.num_heads * self.head_dim, bias=attention_bias
                    ),
                ),
            )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, q_dim, bias=attention_bias
        )

    def forward(
        self,
        queries,
        *vision_latents_attention_mask_list,
    ):
        # The first num_of_kvs positional arguments are vision latents; the rest are
        # their per-source attention masks.
        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        bsz, q_len, _ = queries.size()

        query_states = self.q_proj(queries)
        key_states = torch.cat(
            [
                getattr(self, "k_proj_{}".format(i))(vision_latents_list[i])
                for i in range(self.num_of_kvs)
            ],
            dim=1,
        )
        value_states = torch.cat(
            [
                getattr(self, "v_proj_{}".format(i))(vision_latents_list[i])
                for i in range(self.num_of_kvs)
            ],
            dim=1,
        )

        v_len = key_states.shape[1]

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # Concatenate the per-source masks along the key dimension so they line up with
        # the concatenated keys/values; fall back to None if no masks were given.
        attention_mask = (
            torch.cat(attention_mask_list, dim=-1)
            if len(attention_mask_list) > 0
            else None
        )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, v_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
                )

        # SDPA's memory-efficient backend requires contiguous inputs on CUDA when a
        # custom attn_mask is passed.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output
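

# Usage sketch (hypothetical: two vision sources of different widths). The first
# num_of_kvs extra arguments are latents, the rest are their masks of shape
# (bsz, 1, q_len, v_len_i), either additive floats (0 = attend) or booleans
# (True = attend).
#   attn = MultiKVCrossAttention(q_dim=1024, kv_dim_list=[768, 1152], hidden_dim=1024, num_heads=16)
#   queries = torch.randn(2, 144, 1024)
#   lat_a, lat_b = torch.randn(2, 576, 768), torch.randn(2, 729, 1152)
#   mask_a = torch.zeros(2, 1, 144, 576)   # additive mask, 0 = attend everywhere
#   mask_b = torch.zeros(2, 1, 144, 729)
#   out = attn(queries, lat_a, lat_b, mask_a, mask_b)   # -> (2, 144, 1024)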


class MLP(nn.Module):
    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.linear_1 = nn.Linear(d_in, d_hidden, bias=False)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(d_hidden, d_out, bias=False)

    def forward(self, x):
        return self.linear_2(self.act(self.linear_1(x)))


class VisionCrossAttentionLayer(nn.Module):
    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        hidden_dim=1024,
        layer_idx=0,
    ):
        super().__init__()
        num_heads = 16
        self.num_of_kvs = len(kv_dim_list)

        self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
        self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)

        self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)

        self.norm = nn.LayerNorm(hidden_dim)

        self.cross_attn = MultiKVCrossAttention(
            hidden_dim, kv_dim_list, hidden_dim, num_heads
        )
        self.kv_size_list = kv_size_list
        # Learnable positional embeddings for every spatial (kv_size > 1) vision source.
        for i, kv_size in enumerate(kv_size_list):
            if kv_size > 1:
                setattr(
                    self,
                    "pos_embed_{}".format(i),
                    nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
                )

    def forward(
        self,
        queries,
        context_feature,
        *vision_latents_attention_mask_list,
    ) -> torch.FloatTensor:

        residual = queries

        context_feature = self.proj_context(context_feature)

        # Condition the queries on the context features before cross-attending.
        queries = torch.cat([queries, context_feature], -1)

        queries = self.proj_in(queries)

        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        # Broadcast each (bsz, v_len) mask to the (bsz, 1, q_len, v_len) shape expected
        # by the cross-attention module.
        attention_mask_list_reshaped = []
        if attention_mask_list:
            for attention_mask in attention_mask_list:
                attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
                attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
                attention_mask_list_reshaped.append(attention_mask)

        # Add positional embeddings to the spatial vision sources.
        vision_latents_pos_list = []
        for i, vision_latents in enumerate(vision_latents_list):
            if vision_latents.shape[1] > 1:
                vision_latents_pos_list.append(
                    vision_latents
                    + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
                        vision_latents.dtype
                    )
                )
            else:
                vision_latents_pos_list.append(vision_latents)

        attention_output = self.cross_attn(
            queries, *vision_latents_pos_list, *attention_mask_list_reshaped
        )

        queries = queries + attention_output

        queries = self.norm(queries)

        queries = self.proj_out(queries)

        queries = queries + residual

        return queries
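

# Usage sketch (hypothetical sizes; note that spatial sources are expected to already
# have feature width equal to hidden_dim, since pos_embed_{i} is hidden_dim wide):
#   layer = VisionCrossAttentionLayer(
#       q_dim=1024, context_dim=1024, kv_dim_list=[1024, 1024], kv_size_list=[24, 27]
#   )
#   queries = torch.randn(2, 144, 1024)
#   context = torch.randn(2, 144, 1024)
#   lat_a, lat_b = torch.randn(2, 576, 1024), torch.randn(2, 729, 1024)
#   mask_a = torch.ones(2, 576, dtype=torch.bool)
#   mask_b = torch.ones(2, 729, dtype=torch.bool)
#   out = layer(queries, context, lat_a, lat_b, mask_a, mask_b)   # -> (2, 144, 1024)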


class VisionAggregationLayer(nn.Module):
    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        hidden_dim=1024,
        layer_idx=0,
    ):
        super().__init__()
        num_heads = 16
        self.num_of_kvs = len(kv_dim_list)

        self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
        self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)

        self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)

        self.norm = nn.LayerNorm(hidden_dim)

        # When aggregating several vision sources, predict per-query mixing weights.
        if self.num_of_kvs > 1:
            self.weight_mlp = MLP(q_dim + hidden_dim, hidden_dim, self.num_of_kvs)

        # One aggregation block per source: cross-attention for spatial sources
        # (kv_size > 1), a plain MLP projection otherwise.
        for i, kv_size in enumerate(kv_size_list):
            if kv_size > 1:
                setattr(
                    self,
                    "pos_embed_{}".format(i),
                    nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
                )
                setattr(
                    self,
                    "aggregate_{}".format(i),
                    AggregationBlock(
                        True, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
                    ),
                )
            else:
                setattr(
                    self,
                    "aggregate_{}".format(i),
                    AggregationBlock(
                        False, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
                    ),
                )

    def forward(
        self,
        queries,
        context_feature,
        *vision_latents_attention_mask_list,
    ) -> torch.FloatTensor:

        residual = queries

        context_feature = self.proj_context(context_feature)

        queries = torch.cat([queries, context_feature], -1)

        if self.num_of_kvs > 1:
            combination_weight = self.weight_mlp(queries).softmax(-1)  # (bsz, q_len, num_of_kvs)
            combination_weight = combination_weight.unsqueeze(-1)
        else:
            combination_weight = 1

        queries = self.proj_in(queries)

        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        # Broadcast each (bsz, v_len) mask to the (bsz, 1, q_len, v_len) shape expected
        # by the aggregation blocks.
        attention_mask_list_reshaped = []
        if attention_mask_list:
            for attention_mask in attention_mask_list:
                attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
                attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
                attention_mask_list_reshaped.append(attention_mask)

        # Add positional embeddings to the spatial vision sources.
        vision_latents_pos_list = []
        for i, vision_latents in enumerate(vision_latents_list):
            if vision_latents.shape[1] > 1:
                vision_latents_pos_list.append(
                    vision_latents
                    + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
                        vision_latents.dtype
                    )
                )
            else:
                vision_latents_pos_list.append(vision_latents)

        # Aggregate each source separately, then mix with the predicted weights.
        aggregated_vision_latents_list = []
        for i, (vision_latents, attention_mask) in enumerate(
            zip(vision_latents_pos_list, attention_mask_list_reshaped)
        ):
            aggregated_vision_latents_list.append(
                getattr(self, "aggregate_{}".format(i))(
                    vision_latents, queries, attention_mask
                )
            )

        aggregated_vision_latents = torch.stack(aggregated_vision_latents_list, 2)

        queries = queries + (aggregated_vision_latents * combination_weight).sum(2)

        queries = self.norm(queries)

        queries = self.proj_out(queries)

        queries = queries + residual

        return queries


class VisionTokenSampler(nn.Module):
    # Stacks either "joint" layers (all vision sources attended at once via
    # MultiKVCrossAttention) or "sep" layers (one aggregation block per source,
    # mixed by learned weights).
    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        vision_hidden_size,
        num_of_layers=1,
        layer_type="joint",
    ):
        super().__init__()
        assert layer_type in ["joint", "sep"]
        if layer_type == "joint":
            self.layers = nn.ModuleList(
                [
                    VisionCrossAttentionLayer(
                        q_dim,
                        context_dim,
                        kv_dim_list,
                        kv_size_list,
                        vision_hidden_size,
                        idx,
                    )
                    for idx in range(num_of_layers)
                ]
            )
        else:
            self.layers = nn.ModuleList(
                [
                    VisionAggregationLayer(
                        q_dim,
                        context_dim,
                        kv_dim_list,
                        kv_size_list,
                        vision_hidden_size,
                        idx,
                    )
                    for idx in range(num_of_layers)
                ]
            )

    def forward(self, queries, context_feature, *vision_latents_attention_mask_list):
        for layer in self.layers:
            queries = layer(
                queries, context_feature, *vision_latents_attention_mask_list
            )
        return queries
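

# Usage sketch (illustrative sizes only): a single "joint" sampler layer refining 144
# query tokens against two 1024-dim vision feature maps.
#   sampler = VisionTokenSampler(
#       q_dim=1024, context_dim=1024,
#       kv_dim_list=[1024, 1024], kv_size_list=[24, 27],
#       vision_hidden_size=1024, num_of_layers=1, layer_type="joint",
#   )
#   queries = torch.randn(2, 144, 1024)
#   context = torch.randn(2, 144, 1024)
#   latents = [torch.randn(2, 576, 1024), torch.randn(2, 729, 1024)]
#   masks = [torch.ones(2, 576, dtype=torch.bool), torch.ones(2, 729, dtype=torch.bool)]
#   out = sampler(queries, context, *latents, *masks)   # -> (2, 144, 1024)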