|
""" |
|
Code based on timm https://github.com/huggingface/pytorch-image-models |
|
|
|
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
from timm.layers.bottleneck_attn import PosEmbedRel |
|
from timm.layers.helpers import make_divisible |
|
from timm.layers.mlp import Mlp |
|
from timm.layers.trace_utils import _assert |
|
from timm.layers.weight_init import trunc_normal_ |
|
|
|
|
|
class CrossBottleneckAttn(nn.Module): |
|
def __init__( |
|
self, |
|
dim, |
|
dim_out=None, |
|
feat_size=None, |
|
stride=1, |
|
num_heads=4, |
|
dim_head=None, |
|
qk_ratio=1.0, |
|
qkv_bias=False, |
|
scale_pos_embed=False, |
|
): |
|
super().__init__() |
|
assert feat_size is not None, "A concrete feature size matching expected input (H, W) is required" |
|
dim_out = dim_out or dim |
|
assert dim_out % num_heads == 0 |
|
|
|
self.num_heads = num_heads |
|
self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads |
|
self.dim_head_v = dim_out // self.num_heads |
|
self.dim_out_qk = num_heads * self.dim_head_qk |
|
self.dim_out_v = num_heads * self.dim_head_v |
|
self.scale = self.dim_head_qk**-0.5 |
|
self.scale_pos_embed = scale_pos_embed |
|
|
|
self.qkv_f = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) |
|
self.qkv_p = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) |
|
|
|
|
|
self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale) |
|
|
|
self.norm = nn.LayerNorm([self.dim_out_v * 2, *feat_size]) |
|
mlp_ratio = 4 |
|
self.mlp = Mlp( |
|
in_features=self.dim_out_v * 2, |
|
hidden_features=int(dim * mlp_ratio), |
|
act_layer=nn.GELU, |
|
out_features=dim_out, |
|
drop=0, |
|
use_conv=True, |
|
) |
|
|
|
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
trunc_normal_(self.qkv_f.weight, std=self.qkv_f.weight.shape[1] ** -0.5) |
|
trunc_normal_(self.qkv_p.weight, std=self.qkv_p.weight.shape[1] ** -0.5) |
|
trunc_normal_(self.pos_embed.height_rel, std=self.scale) |
|
trunc_normal_(self.pos_embed.width_rel, std=self.scale) |
|
|
|
def get_qkv(self, x, qvk_conv): |
|
B, C, H, W = x.shape |
|
|
|
x = qvk_conv(x) |
|
|
|
q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1) |
|
|
|
q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2) |
|
k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) |
|
v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2) |
|
|
|
return q, k, v |
|
|
|
def apply_attn(self, q, k, v, B, H, W, dropout=None): |
|
if self.scale_pos_embed: |
|
attn = (q @ k + self.pos_embed(q)) * self.scale |
|
else: |
|
attn = (q @ k) * self.scale + self.pos_embed(q) |
|
attn = attn.softmax(dim=-1) |
|
if dropout: |
|
attn = dropout(attn) |
|
|
|
out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) |
|
return out |
|
|
|
def forward(self, x): |
|
B, C, H, W = x.shape |
|
|
|
dim = int(C / 2) |
|
x1 = x[:, :dim, :, :] |
|
x2 = x[:, dim:, :, :] |
|
|
|
_assert(H == self.pos_embed.height, "") |
|
_assert(W == self.pos_embed.width, "") |
|
|
|
q_f, k_f, v_f = self.get_qkv(x1, self.qkv_f) |
|
q_p, k_p, v_p = self.get_qkv(x2, self.qkv_p) |
|
|
|
|
|
out_f = self.apply_attn(q_f, k_p, v_p, B, H, W) |
|
|
|
out_p = self.apply_attn(q_p, k_f, v_f, B, H, W) |
|
|
|
x_pf = torch.cat((out_f, out_p), dim=1) |
|
x_pf = self.norm(x_pf) |
|
x_pf = self.mlp(x_pf) |
|
|
|
out = self.pool(x_pf) |
|
return out |
|
|