import math

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn


# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
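

# Illustrative usage of the helpers above (a sketch; the 1024-dim / 24x24 grid
# numbers are assumptions, not values taken from this module):
#
#     pos = get_2d_sincos_pos_embed(1024, 24)        # np.ndarray of shape (576, 1024)
#     pos = torch.from_numpy(pos).float()            # ready to add to a 24x24 token grid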


class CrossAttention(nn.Module):
    """Single-source cross-attention: `queries` attend to `vision_latents`."""

    def __init__(self, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Sequential(
            nn.LayerNorm(q_dim),
            nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.k_proj = nn.Sequential(
            nn.LayerNorm(kv_dim),
            nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.v_proj = nn.Sequential(
            nn.LayerNorm(kv_dim),
            nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, q_dim, bias=attention_bias
        )

    def forward(self, vision_latents, queries, attention_mask):
        bsz, q_len, _ = queries.size()
        bsz, v_len, _ = vision_latents.size()

        query_states = self.q_proj(queries)
        key_states = self.k_proj(vision_latents)
        value_states = self.v_proj(vision_latents)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, v_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
                )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output
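

# Shape sketch (assumed dims for illustration): with q_dim=4096, kv_dim=1024,
# hidden_dim=1024, num_heads=16, `queries` of shape (B, Q, 4096) and
# `vision_latents` of shape (B, V, 1024) yield an output of shape (B, Q, 4096);
# `attention_mask` must be either None or of size exactly (B, 1, Q, V).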


class AggregationBlock(nn.Module):
    """Wraps either a CrossAttention layer (attention=True) or an MLP over the
    vision latents (attention=False) behind a single interface."""

    def __init__(
        self, attention, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.attention = attention
        if attention:
            self.attention_layer = CrossAttention(
                q_dim, kv_dim, hidden_dim, num_heads, attention_bias
            )
        else:
            self.attention_layer = MLP(kv_dim, q_dim, q_dim)

    def forward(self, vision_latents, queries, attention_mask):
        if self.attention:
            queries = self.attention_layer(vision_latents, queries, attention_mask)
        else:
            queries = self.attention_layer(vision_latents)

        return queries


class MultiKVCrossAttention(nn.Module):
    """Cross-attention over several vision sources at once: each source gets its own
    key/value projection, the keys and values are concatenated along the sequence
    dimension, and `queries` attend to the combined sequence."""

    def __init__(self, q_dim, kv_dim_list, hidden_dim, num_heads, attention_bias=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Sequential(
            nn.LayerNorm(q_dim),
            nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.num_of_kvs = len(kv_dim_list)
        for i, kv_dim in enumerate(kv_dim_list):
            setattr(
                self,
                "k_proj_{}".format(i),
                nn.Sequential(
                    nn.LayerNorm(kv_dim),
                    nn.Linear(
                        kv_dim, self.num_heads * self.head_dim, bias=attention_bias
                    ),
                ),
            )
            setattr(
                self,
                "v_proj_{}".format(i),
                nn.Sequential(
                    nn.LayerNorm(kv_dim),
                    nn.Linear(
                        kv_dim, self.num_heads * self.head_dim, bias=attention_bias
                    ),
                ),
            )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, q_dim, bias=attention_bias
        )

    def forward(
        self,
        queries,
        *vision_latents_attention_mask_list,
    ):
        # The varargs pack the per-source vision latents first, then the matching
        # attention masks: (latents_0, ..., latents_{n-1}, mask_0, ..., mask_{n-1}).
        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        bsz, q_len, _ = queries.size()

        query_states = self.q_proj(queries)
        key_states = torch.cat(
            [
                getattr(self, "k_proj_{}".format(i))(vision_latents_list[i])
                for i in range(self.num_of_kvs)
            ],
            dim=1,
        )
        value_states = torch.cat(
            [
                getattr(self, "v_proj_{}".format(i))(vision_latents_list[i])
                for i in range(self.num_of_kvs)
            ],
            dim=1,
        )

        v_len = key_states.shape[1]

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # if kv_weight is not None:
        #     kv_weight = kv_weight.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

        attention_mask = torch.cat(attention_mask_list, dim=-1)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, v_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
                )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )
        # attn_output = spda(
        #     query_states,
        #     key_states,
        #     value_states,
        #     attn_mask=attention_mask,
        #     additional_score=kv_weight
        # )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output
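

# Calling convention sketch (dims are assumptions, not values from this module):
# for two sources with kv_dim_list=[1024, 1024] and per-source masks of shape
# (B, 1, Q, V_i),
#
#     attn = MultiKVCrossAttention(1024, [1024, 1024], 1024, 16)
#     out = attn(queries, latents_0, latents_1, mask_0, mask_1)   # (B, Q, 1024)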


class MLP(nn.Module):
    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.linear_1 = nn.Linear(d_in, d_hidden, bias=False)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(d_hidden, d_out, bias=False)

    def forward(self, x):
        return self.linear_2(self.act(self.linear_1(x)))


class VisionCrossAttentionLayer(nn.Module):
    """One "joint" sampler layer: project queries and context, add learnable
    positional embeddings to each spatial vision source, run a single
    MultiKVCrossAttention over all sources, then apply a residual MLP."""

    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        hidden_dim=1024,
        layer_idx=0,
    ):
        super().__init__()
        num_heads = 16
        self.num_of_kvs = len(kv_dim_list)

        self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
        self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
        # if self.num_of_kvs > 1:
        #     self.weight_mlp = MLP(q_dim+hidden_dim, hidden_dim, self.num_of_kvs)
        #     self.tower_weight = nn.Parameter(torch.zeros((self.num_of_kvs)))
        self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)

        self.norm = nn.LayerNorm(hidden_dim)

        self.cross_attn = MultiKVCrossAttention(
            hidden_dim, kv_dim_list, hidden_dim, num_heads
        )
        self.kv_size_list = kv_size_list
        for i, kv_size in enumerate(kv_size_list):
            if kv_size > 1:
                setattr(
                    self,
                    "pos_embed_{}".format(i),
                    nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
                )
                # self.register_buffer("pos_embed_{}".format(i), torch.from_numpy(get_2d_sincos_pos_embed(hidden_dim, kv_size)).float(), persistent=False)

    def forward(
        self,
        queries,
        context_feature,
        *vision_latents_attention_mask_list,
    ) -> torch.FloatTensor:

        residual = queries
        # queries = self.proj_in(queries)
        context_feature = self.proj_context(context_feature)
        # queries = queries + context_feature
        queries = torch.cat([queries, context_feature], -1)

        # if self.num_of_kvs > 1:
        #     kv_weight = self.weight_mlp(queries)  # B * 1 * num_tower
        #     kv_weight = kv_weight + self.tower_weight.view(1, 1, -1)
        #     kv_weight = kv_weight.softmax(-1)
        #     kv_number_list = [size**2 for size in self.kv_size_list]
        #     kv_weight = torch.repeat_interleave(kv_weight, torch.tensor(kv_number_list).to(kv_weight.device), dim=-1)
        # else:
        #     kv_weight = None

        queries = self.proj_in(queries)

        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        attention_mask_list_reshaped = []
        if attention_mask_list is not None:
            for attention_mask in attention_mask_list:
                attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
                attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
                attention_mask_list_reshaped.append(attention_mask)

        vision_latents_pos_list = []
        for i, vision_latents in enumerate(vision_latents_list):
            if vision_latents.shape[1] > 1:
                vision_latents_pos_list.append(
                    vision_latents
                    + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
                        vision_latents.dtype
                    )
                )
            else:
                vision_latents_pos_list.append(vision_latents)

        # Cross Attention
        attention_output = self.cross_attn(
            queries, *vision_latents_pos_list, *attention_mask_list_reshaped
        )

        # attention_output = (attention_output * combination_weight).sum(2)
        queries = queries + attention_output

        queries = self.norm(queries)

        queries = self.proj_out(queries)

        queries = queries + residual

        return queries


class VisionAggregationLayer(nn.Module):
    """One "sep" sampler layer: each vision source is aggregated by its own
    AggregationBlock, and the per-source results are mixed with query-dependent
    softmax weights before the residual MLP."""

    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        hidden_dim=1024,
        layer_idx=0,
    ):
        super().__init__()
        num_heads = 16
        self.num_of_kvs = len(kv_dim_list)

        self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
        self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)

        self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)

        self.norm = nn.LayerNorm(hidden_dim)

        if self.num_of_kvs > 1:
            self.weight_mlp = MLP(q_dim + hidden_dim, hidden_dim, self.num_of_kvs)

        for i, kv_size in enumerate(kv_size_list):
            if kv_size > 1:
                setattr(
                    self,
                    "pos_embed_{}".format(i),
                    nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
                )
                setattr(
                    self,
                    "aggregate_{}".format(i),
                    AggregationBlock(
                        True, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
                    ),
                )
            else:
                setattr(
                    self,
                    "aggregate_{}".format(i),
                    AggregationBlock(
                        False, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
                    ),
                )

    def forward(
        self,
        queries,
        context_feature,
        *vision_latents_attention_mask_list,
    ) -> torch.FloatTensor:

        residual = queries
        # queries = self.proj_in(queries)
        context_feature = self.proj_context(context_feature)
        # queries = queries + context_feature
        queries = torch.cat([queries, context_feature], -1)

        if self.num_of_kvs > 1:
            combination_weight = self.weight_mlp(queries).softmax(
                -1
            )  # B * 1 * num_tower
            combination_weight = combination_weight.unsqueeze(-1)
        else:
            combination_weight = 1

        queries = self.proj_in(queries)

        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        attention_mask_list_reshaped = []
        if attention_mask_list is not None:
            for attention_mask in attention_mask_list:
                attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
                attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
                attention_mask_list_reshaped.append(attention_mask)

        vision_latents_pos_list = []
        for i, vision_latents in enumerate(vision_latents_list):
            if vision_latents.shape[1] > 1:
                vision_latents_pos_list.append(
                    vision_latents
                    + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
                        vision_latents.dtype
                    )
                )
            else:
                vision_latents_pos_list.append(vision_latents)

        aggregated_vision_latents_list = []
        for i, (vision_latents, attention_mask) in enumerate(
            zip(vision_latents_pos_list, attention_mask_list_reshaped)
        ):
            aggregated_vision_latents_list.append(
                getattr(self, "aggregate_{}".format(i))(
                    vision_latents, queries, attention_mask
                )
            )

        aggregated_vision_latents = torch.stack(aggregated_vision_latents_list, 2)

        queries = queries + (aggregated_vision_latents * combination_weight).sum(2)

        queries = self.norm(queries)

        queries = self.proj_out(queries)

        queries = queries + residual

        return queries


class VisionTokenSampler(nn.Module):
    """Stack of sampler layers. `layer_type="joint"` uses VisionCrossAttentionLayer
    (all vision sources in one cross-attention); `layer_type="sep"` uses
    VisionAggregationLayer (one aggregation block per source)."""

    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        vision_hidden_size,
        num_of_layers=1,
        layer_type="joint",
    ):
        super().__init__()
        assert layer_type in ["joint", "sep"]
        if layer_type == "joint":
            self.layers = nn.ModuleList(
                [
                    VisionCrossAttentionLayer(
                        q_dim,
                        context_dim,
                        kv_dim_list,
                        kv_size_list,
                        vision_hidden_size,
                        idx,
                    )
                    for idx in range(num_of_layers)
                ]
            )
        else:
            self.layers = nn.ModuleList(
                [
                    VisionAggregationLayer(
                        q_dim,
                        context_dim,
                        kv_dim_list,
                        kv_size_list,
                        vision_hidden_size,
                        idx,
                    )
                    for idx in range(num_of_layers)
                ]
            )

    def forward(self, queries, context_feature, *vision_latents_attention_mask_list):
        for layer in self.layers:
            queries = layer(
                queries, context_feature, *vision_latents_attention_mask_list
            )
        return queries
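

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: every dimension below (q_dim=4096, a
    # 1024-dim hidden size, a 24x24 grid plus one global token) is an assumed
    # configuration for illustration, not the original model's settings.
    bsz, q_len = 2, 64
    q_dim, context_dim, hidden_dim = 4096, 1024, 1024
    kv_size_list = [24, 1]  # one spatial tower, one single-token source
    # The spatial tower uses kv_dim == hidden_dim here because the learnable
    # pos_embed (hidden_dim wide) is added directly to the raw latents.
    kv_dim_list = [hidden_dim, hidden_dim]

    sampler = VisionTokenSampler(
        q_dim, context_dim, kv_dim_list, kv_size_list, hidden_dim, num_of_layers=1
    )

    queries = torch.randn(bsz, q_len, q_dim)
    context = torch.randn(bsz, q_len, context_dim)
    vision_latents = [
        torch.randn(bsz, size * size, dim)
        for size, dim in zip(kv_size_list, kv_dim_list)
    ]
    # Boolean masks (True = attend), one per vision source, flattened over tokens.
    masks = [torch.ones(bsz, size * size, dtype=torch.bool) for size in kv_size_list]

    out = sampler(queries, context, *vision_latents, *masks)
    print(out.shape)  # expected: torch.Size([2, 64, 4096])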