# TokenLearner / TokenFuser modules built on a spatial-attention block.
import torch
import torch.nn as nn
import torch.nn.functional as F
class SpatialAttention(nn.Module):
    """Spatial attention over a channels-last feature map.

    From the channel-wise max and mean maps, a 1x1 conv produces a
    per-pixel attention weight. Returns both the attention-pooled
    feature vector and the re-weighted (channels-first) feature map.
    """

    def __init__(self) -> None:
        super().__init__()
        # 2 input channels: the channel-wise max map and mean map.
        self.conv = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=(1, 1), stride=1), nn.BatchNorm2d(1), nn.ReLU()
        )
        # NOTE(review): unused in forward(); kept so the module interface
        # is unchanged (AvgPool2d has no parameters, so state_dicts match).
        self.sgap = nn.AvgPool2d(2)

    def forward(self, x):
        """x: [B, H, W, C] channels-last. Returns ([B, C], [B, C, H, W])."""
        B, H, W, C = x.shape
        # BUG FIX: the original used reshape(B, C, H, W), which reinterprets
        # the memory layout and scrambles spatial/channel data; permute is
        # the correct channels-last -> channels-first move.
        x = x.permute(0, 3, 1, 2)  # [B, C, H, W]

        mx = torch.max(x, 1)[0].unsqueeze(1)    # [B, 1, H, W] channel-wise max
        avg = torch.mean(x, 1).unsqueeze(1)     # [B, 1, H, W] channel-wise mean
        combined = torch.cat([mx, avg], dim=1)  # [B, 2, H, W]
        fmap = self.conv(combined)              # [B, 1, H, W]
        weight_map = torch.sigmoid(fmap)
        # Attention-weighted spatial average -> one value per channel.
        out = (x * weight_map).mean(dim=(-2, -1))

        return out, x * weight_map
class TokenLearner(nn.Module):
    """Learn S tokens from a channels-last feature map.

    Each of the S tokens is produced by its own SpatialAttention module
    (TokenLearner-style tokenization).
    """

    def __init__(self, S) -> None:
        super().__init__()
        self.S = S
        # One independent SpatialAttention per learned token.
        self.tokenizers = nn.ModuleList([SpatialAttention() for _ in range(S)])

    def forward(self, x):
        """x: [B, H, W, C]. Returns learned tokens Z: [B, S, C]."""
        # BUG FIX: the original pre-allocated an *uninitialized* tensor via
        # torch.Tensor(B, S, C) and filled it in-place. Stacking the pooled
        # token vectors yields the same result without uninitialized memory
        # and is cleaner for autograd.
        tokens = [tokenizer(x)[0] for tokenizer in self.tokenizers]  # each [B, C]
        return torch.stack(tokens, dim=1)  # [B, S, C]
class TokenFuser(nn.Module):
    """Fuse learned tokens back into the spatial feature map
    (TokenFuser-style remapping)."""

    def __init__(self, H, W, C, S) -> None:
        super().__init__()
        # H and W are accepted for interface compatibility; the fusion
        # itself only needs C and S.
        self.projection = nn.Linear(S, S, bias=False)  # mixes the S tokens
        self.Bi = nn.Linear(C, S)                      # per-pixel token weights
        self.spatial_attn = SpatialAttention()
        self.S = S

    def forward(self, y, x):
        """y: tokens [B, S, C]; x: feature map [B, H, W, C].
        Returns the fused map [B, H, W, C]."""
        B, S, C = y.shape
        B, H, W, C = x.shape

        # BUG FIX: the original used y.reshape(B, C, S), which scrambles the
        # token/channel layout. transpose moves the token axis last so that
        # Linear(S, S) mixes tokens, then we transpose back.
        Y = self.projection(y.transpose(1, 2)).transpose(1, 2)  # [B, S, C]

        Bw = torch.sigmoid(self.Bi(x)).reshape(B, H * W, S)     # [B, HW, S]
        BwY = torch.matmul(Bw, Y)                               # [B, HW, C]

        # spatial_attn returns a channels-first map [B, C, H, W].
        # BUG FIX: permute back to channels-last before flattening the
        # spatial dims; a bare reshape would again scramble the layout.
        _, xj = self.spatial_attn(x)
        xj = xj.permute(0, 2, 3, 1).reshape(B, H * W, C)

        return (BwY + xj).reshape(B, H, W, C)