import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from efficientat.models.utils import collapse_dim class MultiHeadAttentionPooling(nn.Module): """Multi-Head Attention as used in PSLA paper (https://arxiv.org/pdf/2102.01243.pdf) """ def __init__(self, in_dim, out_dim, att_activation: str = 'sigmoid', clf_activation: str = 'ident', num_heads: int = 4, epsilon: float = 1e-7): super(MultiHeadAttentionPooling, self).__init__() self.in_dim = in_dim self.out_dim = out_dim self.num_heads = num_heads self.epsilon = epsilon self.att_activation = att_activation self.clf_activation = clf_activation # out size: out dim x 2 (att and clf paths) x num_heads self.subspace_proj = nn.Linear(self.in_dim, self.out_dim * 2 * self.num_heads) self.head_weight = nn.Parameter(torch.tensor([1.0 / self.num_heads] * self.num_heads).view(1, -1, 1)) def activate(self, x, activation): if activation == 'linear': return x elif activation == 'relu': return F.relu(x) elif activation == 'sigmoid': return torch.sigmoid(x) elif activation == 'softmax': return F.softmax(x, dim=1) elif activation == 'ident': return x def forward(self, x) -> Tensor: """x: Tensor of size (batch_size, channels, frequency bands, sequence length) """ x = collapse_dim(x, dim=2) # results in tensor of size (batch_size, channels, sequence_length) x = x.transpose(1, 2) # results in tensor of size (batch_size, sequence_length, channels) b, n, c = x.shape x = self.subspace_proj(x).reshape(b, n, 2, self.num_heads, self.out_dim).permute(2, 0, 3, 1, 4) att, val = x[0], x[1] val = self.activate(val, self.clf_activation) att = self.activate(att, self.att_activation) att = torch.clamp(att, self.epsilon, 1. - self.epsilon) att = att / torch.sum(att, dim=2, keepdim=True) out = torch.sum(att * val, dim=2) * self.head_weight out = torch.sum(out, dim=1) return out