""" Learnable linear attention feature map classes and functions """ import torch import torch.nn as nn import torch.nn.functional as F def init_feature_map(name: str, mlp: nn.Module, **kwargs: dict): """ Initialize feature map final activation for linear attention """ return FeatureMap(activation_name=name, mlp=mlp, **kwargs) def init_feature_map_act(name: str, fullspace: bool = True, **kwargs): """ Initialize feature map final activation for linear attention """ if name == 'softmax_dim' and fullspace: return SoftmaxDim(**kwargs) elif name == 'softmax_dim' and not fullspace: return SoftmaxDimHalfspace(**kwargs) elif name == 'exp_dim' and fullspace: return Exp(**kwargs) elif name == 'exp_dim' and not fullspace: return ExpHalfspace(**kwargs) elif name == 'pos_elu': return PosELU(**kwargs) elif name == 'relu': return ReLU(**kwargs) else: raise NotImplementedError def init_learned_kernel(name: str, **kwargs: any): """ Initialize feature map MLP for linear attention """ if name == 'untied_head_einsum': return FeatureMapMLP(**kwargs) elif name == 'untied_head_adapter': return FeatureMapAdapter(**kwargs) else: raise NotImplementedError class FeatureMap(nn.Module): """ Final 'activation' of feature map. Can probably be combined with `FeatureMapMLP` below Full feature map is like f(xW + b) -> This is the `f` part """ def __init__(self, activation_name: str, head_dim_idx: int = -1, eps: float = 1e-12, mlp: nn.Module = None, fullspace: bool = True,): super().__init__() self.head_dim_idx = head_dim_idx self.eps = eps self.mlp = mlp if mlp is not None else nn.Identity() self.activation = init_feature_map_act(activation_name, fullspace, eps=eps) def forward(self, x: torch.Tensor, *mlp_args: any, **mlp_kwargs: any): """ Assume x.shape is (batch_size, n_heads, seq_len, head_dim) """ return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x) def q_map(self, *args: any, **kwargs: any): """ Use for inference in case q and k feature maps differ """ return self.forward(*args, **kwargs) def k_map(self, *args: any, **kwargs: any): """ Use for inference in case q and k feature maps differ """ return self.forward(*args, **kwargs) # ----------------------- # Feature map activations # ----------------------- class FeatureMapAct(nn.Module): """ Base class for feature map activations """ def __init__(self, eps: float = 1e-12): super().__init__() self.eps = eps def forward(self, x: torch.Tensor, *args: any, **kwargs: any): """ x.shape is (batch_size, n_heads, seq_len, head_dim) """ return x class PosELU(FeatureMapAct): """ 1 + ELU activation as in https://arxiv.org/abs/2006.16236 """ def forward(self, x: torch.Tensor, *args: any, **kwargs: any): return (1 + F.elu(x)).clamp(min=self.eps) class ReLU(FeatureMapAct): """ ReLU activation as in https://arxiv.org/abs/2103.13076 """ def forward(self, x: torch.Tensor, *args: any, **kwargs: any): return F.relu(x).clamp(min=self.eps) class SoftmaxDim(FeatureMapAct): """ Softmax activation as in https://arxiv.org/abs/2402.04347 """ def forward(self, x: torch.Tensor, *args: any, **kwargs: any): return torch.cat([ torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1) ], dim=-1).clamp(min=self.eps) class SoftmaxDimHalfspace(FeatureMapAct): """ Softmax activation as in https://arxiv.org/abs/2402.04347 """ def forward(self, x: torch.Tensor, *args: any, **kwargs: any): return torch.softmax(x, dim=-1).clamp(min=self.eps) class Exp(FeatureMapAct): """ Exp activation as in https://arxiv.org/abs/2402.04347 """ def forward(self, x: torch.Tensor, *args: any, **kwargs: any): x_max = 
# -----------------------
# Feature map activations
# -----------------------
class FeatureMapAct(nn.Module):
    """
    Base class for feature map activations
    """
    def __init__(self, eps: float = 1e-12):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        """
        x.shape is (batch_size, n_heads, seq_len, head_dim)
        """
        return x


class PosELU(FeatureMapAct):
    """
    1 + ELU activation as in https://arxiv.org/abs/2006.16236
    """
    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        return (1 + F.elu(x)).clamp(min=self.eps)


class ReLU(FeatureMapAct):
    """
    ReLU activation as in https://arxiv.org/abs/2103.13076
    """
    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        return F.relu(x).clamp(min=self.eps)


class SoftmaxDim(FeatureMapAct):
    """
    Softmax activation as in https://arxiv.org/abs/2402.04347
    """
    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        return torch.cat([
            torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)
        ], dim=-1).clamp(min=self.eps)


class SoftmaxDimHalfspace(FeatureMapAct):
    """
    Softmax activation as in https://arxiv.org/abs/2402.04347
    """
    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        return torch.softmax(x, dim=-1).clamp(min=self.eps)


class Exp(FeatureMapAct):
    """
    Exp activation as in https://arxiv.org/abs/2402.04347
    """
    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        x_max = torch.amax(x, dim=-1, keepdim=True)
        x_min = torch.amin(x, dim=-1, keepdim=True)
        return torch.cat([
            torch.exp(x - x_max), torch.exp(-x + x_min)
        ], dim=-1).clamp(min=self.eps)


class ExpHalfspace(FeatureMapAct):
    """
    Exp activation as in https://arxiv.org/abs/2402.04347
    """
    def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
        x_max = torch.amax(x, dim=-1, keepdim=True)
        return torch.exp(x - x_max).clamp(min=self.eps)
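# ---------------------------------------------------------------------------
# Shape sketch (illustrative, sizes are arbitrary assumptions): the 'fullspace'
# activations (SoftmaxDim, Exp) concatenate features for +x and -x and so
# double the last dimension, while the 'halfspace' variants keep it unchanged.
# ---------------------------------------------------------------------------
def _example_activation_shapes():
    x = torch.randn(2, 4, 8, 16)  # (batch_size, n_heads, seq_len, head_dim)
    assert SoftmaxDim()(x).shape == (2, 4, 8, 32)           # fullspace: 2 * head_dim
    assert SoftmaxDimHalfspace()(x).shape == (2, 4, 8, 16)  # halfspace: head_dim
    assert Exp()(x).shape == (2, 4, 8, 32)
    assert ExpHalfspace()(x).shape == (2, 4, 8, 16)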
# ----------------
# Feature map MLPs
# ----------------
class FeatureMapMLP(nn.Module):
    """
    Learnable MLP in feature map.

    Full feature map is like f(xW + b)
    -> This is the `W` and (optional) `b` part
    """
    def __init__(self,
                 num_heads: int,
                 head_dim: int,     # input dim
                 feature_dim: int,  # output dim
                 dtype: torch.dtype,
                 device: torch.device,
                 skip_connection: bool = False,
                 bias: bool = False,
                 zero_init: bool = False,
                 normal_init: bool = False,):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.feature_dim = feature_dim
        self.dtype = dtype
        self.device = device
        self.skip_connection = skip_connection
        self.bias = bias
        self.zero_init = zero_init
        self.normal_init = normal_init
        self.init_weights_()

        if self.zero_init:  # Zero-out weights or set as identity post-initialization
            self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()

        if self.normal_init:
            with torch.no_grad():
                nn.init.normal_(self.layer)

        if self.skip_connection:
            assertion_fail = f'If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}'
            assert self.head_dim == self.feature_dim, assertion_fail

    def init_weights_(self):
        """
        Initialize (W)eights and (b)iases
        """
        self.layer = nn.Parameter(torch.zeros(
            (self.num_heads, self.head_dim, self.feature_dim),
            dtype=self.dtype, device=self.device,
        ))
        nn.init.kaiming_uniform_(self.layer)

        if self.bias:
            self.bias = nn.Parameter(torch.zeros(
                (1, self.num_heads, 1, 1),  # self.feature_dim),
                dtype=self.dtype, device=self.device,
            ))
            nn.init.kaiming_uniform_(self.bias)
        else:
            self.bias = 0.  # hack: adding a float 0. in forward() is a no-op

    def zero_init_with_skip_(self):
        """
        Initialize weights to zero matrix if skip connection
        """
        with torch.no_grad():
            nn.init.zeros_(self.layer)

    def zero_init_(self):
        """
        Initialize weights to identity matrix if no skip connection
        """
        with torch.no_grad():
            for i in range(self.layer.shape[0]):
                try:
                    nn.init.eye_(self.layer[i])
                except RuntimeError:  # fallback if in-place eye_ init fails
                    with torch.no_grad():
                        dtype = self.layer[i].dtype
                        weight = torch.eye(*self.layer[i].shape,
                                           requires_grad=self.layer[i].requires_grad,
                                           device=self.layer[i].device)
                        self.layer[i] = weight.to(dtype=dtype)

    def forward(self, x: torch.Tensor):
        """
        Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
        """
        _x = torch.einsum('hdf,bhld->bhlf', self.layer, x) + self.bias
        return x + _x if self.skip_connection else _x


class FeatureMapAdapter(FeatureMapMLP):
    """
    Learnable feature map with bottleneck adapter as in
    https://arxiv.org/abs/1902.00751

    We don't use this, but it could be fun to try
    """
    def __init__(self, hidden_dim: int, *args, **kwargs):
        kwargs['skip_connection'] = True
        kwargs['bias'] = True
        kwargs['zero_init'] = True
        self.hidden_dim = hidden_dim  # set before super().__init__() so init_weights_() can see it
        super().__init__(*args, **kwargs)

    def init_weights_(self):
        """
        Initialize (W)eights and (b)iases
        """
        kwargs = {'dtype': self.dtype, 'device': self.device}
        self.layer0 = nn.Parameter(
            torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
        )
        self.layer1 = nn.Parameter(
            torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
        )
        nn.init.kaiming_uniform_(self.layer0)
        nn.init.kaiming_uniform_(self.layer1)

        self.bias0 = nn.Parameter(torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs))
        self.bias1 = nn.Parameter(torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs))
        nn.init.kaiming_uniform_(self.bias0)
        nn.init.kaiming_uniform_(self.bias1)

    def zero_init_with_skip_(self):
        with torch.no_grad():
            nn.init.zeros_(self.layer0)
            nn.init.zeros_(self.layer1)
            nn.init.zeros_(self.bias0)
            nn.init.zeros_(self.bias1)

    def zero_init_(self):
        raise NotImplementedError

    def forward(self, x: torch.Tensor):
        """
        Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
        -> Down-project, apply nonlinearity, up-project; add skip connection
        """
        _x = torch.einsum('hde,bhld->bhle', self.layer0, x) + self.bias0
        _x = F.relu(_x)
        _x = torch.einsum('hef,bhle->bhlf', self.layer1, _x) + self.bias1
        return x + _x if self.skip_connection else _x
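# ---------------------------------------------------------------------------
# End-to-end usage sketch (illustrative only, not part of the library API):
# pair a learnable FeatureMapMLP kernel with a 'softmax_dim' activation via the
# init_* helpers and run a forward pass. All sizes, dtype, and device below are
# arbitrary assumptions chosen just to exercise the code.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    num_heads, head_dim, feature_dim = 4, 16, 16
    mlp = init_learned_kernel(
        'untied_head_einsum',
        num_heads=num_heads, head_dim=head_dim, feature_dim=feature_dim,
        dtype=torch.float32, device=torch.device('cpu'),
    )
    feature_map = init_feature_map('softmax_dim', mlp=mlp)

    x = torch.randn(2, num_heads, 8, head_dim)  # (batch_size, n_heads, seq_len, head_dim)
    phi_q, phi_k = feature_map.q_map(x), feature_map.k_map(x)
    # 'softmax_dim' concatenates softmax(+x) and softmax(-x), doubling the feature dim
    print(phi_q.shape, phi_k.shape)  # each: torch.Size([2, 4, 8, 32])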