"""positional_encoding.py"""

from typing import Optional, Literal
from inspect import isfunction
from math import log, log2, pi, floor

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat

from model.RoPE.RoPE import RotaryEmbedding


class AlibiPositionalBias(nn.Module):
    """
    ALiBi positional bias for Transformer attention.

    Modified to support a trainable slope, similar to the LittleBird paper; based on
    https://github.com/lucidrains/x-transformers/
    https://github.com/ofirpress/attention_with_linear_biases/issues/5

    This ALiBi positional bias extension supports:
    - bi-directional self/cross attention
    - extrapolation to longer input lengths.

    References:
        Press, Ofir, Noah A. Smith, and Mike Lewis. "Train short, test long: Attention with linear
        biases enables input length extrapolation." arXiv preprint arXiv:2108.12409 (2021).

        Lee, Minchul, Kijong Han, and Myeong Cheol Shin. "LittleBird: Efficient Faster & Longer
        Transformer for Question Answering." arXiv preprint arXiv:2210.11870 (2022).
    """

    def __init__(self,
                 heads: int = 8,
                 total_heads: int = 8,
                 trainable_slope: bool = False,
                 trainable_slope_init: Literal['random', 'log'] = 'random',
                 **kwargs) -> None:
        super().__init__()
        self.heads = heads
        self.total_heads = total_heads
        self.trainable_slope = trainable_slope
        self.trainable_slope_init = trainable_slope_init

        if trainable_slope:
            # Uninitialized per-head slopes; initialised via reset_parameters().
            self.slopes = nn.Parameter(torch.Tensor(heads, 1, 1), requires_grad=True)
        else:
            slopes = torch.Tensor(self._get_slopes(heads))
            slopes = rearrange(slopes, 'h -> h 1 1')
            self.register_buffer('slopes', slopes, persistent=False)

        self.register_buffer('bias', None, persistent=False)

    def reset_parameters(self) -> None:
        if self.trainable_slope:
            if self.trainable_slope_init == 'random':
                nn.init.normal_(self.slopes, -2, 1)
            else:
                raise NotImplementedError(f'Unknown trainable_slope_init: {self.trainable_slope_init}')

    def get_bias(self, i, j, device):
        # Relative-distance matrix: entry (q, k) is -|pos(k) - pos(q)|, with the i query
        # positions aligned to the last i of the j key positions.
        i_arange = torch.arange(j - i, j, device=device)
        j_arange = torch.arange(j, device=device)
        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
        return bias

    @staticmethod
    def _get_slopes(heads):
        # Head-specific slopes form a geometric sequence; for 8 heads this is
        # 1/2, 1/4, ..., 1/256, as in the ALiBi paper.

        def get_slopes_power_of_2(n):
            start = (2**(-2**-(log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        # For head counts that are not a power of two, interleave slopes from the two
        # nearest powers of two, as in the reference implementation.
        closest_power_of_2 = 2**floor(log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(
            2 * closest_power_of_2)[0::2][:heads - closest_power_of_2]

    @staticmethod
    def pad_at_dim(t, pad, dim=-1, value=0.):
        dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
        zeros = ((0, 0) * dims_from_right)
        return F.pad(t, (*zeros, *pad), value=value)

    @property
    def device(self):
        if self.trainable_slope:
            return self.slopes.device
        else:
            return next(self.buffers()).device

    def forward(self, i, j):
        """
        Args:
            i (int): query sequence length
            j (int): key sequence length (queries are aligned with the last i key positions)

        Returns:
            torch.Tensor: (num_total_heads, i, j) positional bias for each head

        Usage:
            >>> alibi_bias = AlibiPositionalBias(heads=8, total_heads=8, trainable_slope=False)
            >>> pos_bias = alibi_bias(q.shape[-2], k.shape[-2])
            >>> q_dot_k = ...
            >>> q_dot_k += pos_bias
            >>> q_dot_k = q_dot_k.softmax(dim=-1)
        """
        h, device = self.total_heads, self.device

        if self.trainable_slope:
            if self.bias is not None and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
                bias = self.bias[..., :i, :j]
            else:
                bias = self.get_bias(i, j, device)
                num_heads_unalibied = h - bias.shape[0]
                bias = self.pad_at_dim(bias, (0, num_heads_unalibied), dim=0)
                self.register_buffer('bias', bias, persistent=False)

            # Scale the (possibly sliced) bias, not the full cached buffer, so the
            # returned shape is always (num_total_heads, i, j).
            return bias * torch.sigmoid(self.slopes)

        else:
            if self.bias is not None and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
                return self.bias[..., :i, :j]

            bias = self.get_bias(i, j, device)
            bias = bias * self.slopes

            num_heads_unalibied = h - bias.shape[0]
            bias = self.pad_at_dim(bias, (0, num_heads_unalibied), dim=0)
            self.register_buffer('bias', bias, persistent=False)

            return self.bias
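

def _alibi_bias_usage_example():
    # Hedged usage sketch (not part of the original module): shows how the (total_heads, i, j)
    # bias returned above would typically be added to raw attention scores before the softmax,
    # as in the docstring. All shapes and names here are illustrative assumptions.
    heads, q_len, k_len, head_dim = 8, 16, 16, 64
    alibi = AlibiPositionalBias(heads=heads, total_heads=heads, trainable_slope=False)

    q = torch.randn(1, heads, q_len, head_dim)
    k = torch.randn(1, heads, k_len, head_dim)

    scores = torch.matmul(q, k.transpose(-1, -2)) / head_dim**0.5  # (1, heads, q_len, k_len)
    scores = scores + alibi(q_len, k_len)                          # bias broadcasts over the batch dim
    return scores.softmax(dim=-1)

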
class FixedSinusoidalPositionalEmbedding(nn.Embedding):
    """
    Sinusoidal absolute positional embeddings (APE) of any length.

    Adapted from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding
    """

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter):
        """
        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
        the 2nd half of the vector. [dim // 2:]
        """
        n_pos, dim = out.shape
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
        out.requires_grad = False
        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, seq_len: int, past_key_values_length: int = 0):
        """Returns embeddings for positions `past_key_values_length` .. `past_key_values_length + seq_len - 1`."""
        positions = torch.arange(past_key_values_length,
                                 past_key_values_length + seq_len,
                                 dtype=torch.long,
                                 device=self.weight.device)
        return super().forward(positions)
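

def _fixed_sinusoidal_usage_example():
    # Hedged usage sketch (not part of the original module): look up absolute position embeddings
    # and add them to token embeddings, the standard way sinusoidal APE is consumed.
    # The dimensions below are illustrative assumptions.
    ape = FixedSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=256)
    token_embeddings = torch.randn(2, 100, 256)  # (batch, seq_len, d_model)
    positions = ape(seq_len=100)                 # (seq_len, d_model)
    return token_embeddings + positions.unsqueeze(0)

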
class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding

    Reference: https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
    """

    def __init__(self, config):
        super().__init__()
        dim = config.d_model // config.num_heads
        base = config.rotary_embedding_base

        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_sequence_length = None
        self.cached_rotary_positional_embedding = None

    def forward(self, hidden_states):
        sequence_length = hidden_states.shape[1]

        if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
            return self.cached_rotary_positional_embedding

        self.cached_sequence_length = sequence_length
        time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
        embeddings = torch.cat((freqs, freqs), dim=-1)

        cos_embeddings = embeddings.cos()[:, None, None, :]
        sin_embeddings = embeddings.sin()[:, None, None, :]
        self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
        return self.cached_rotary_positional_embedding
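

def _rotary_rotate_half_example():
    # Hedged sketch (not part of the original module): how a consumer of the cached [cos, sin]
    # tables above would typically rotate query/key features, using the standard "rotate half"
    # RoPE formulation x * cos + rotate_half(x) * sin. The config stand-in and all shapes are
    # illustrative assumptions; the actual attention module lives elsewhere in the repo.
    from types import SimpleNamespace

    config = SimpleNamespace(d_model=256, num_heads=8, rotary_embedding_base=10000)
    rotary = Wav2Vec2ConformerRotaryPositionalEmbedding(config)

    hidden_states = torch.randn(2, 50, 256)                  # (batch, seq_len, d_model)
    tables = rotary(hidden_states)                           # (2, seq_len, 1, 1, head_dim): stacked cos / sin
    cos, sin = tables[0], tables[1]

    q = hidden_states.view(2, 50, 8, 32).transpose(0, 1)     # (seq_len, batch, num_heads, head_dim)
    first_half, second_half = q.chunk(2, dim=-1)
    rotated = torch.cat((-second_half, first_half), dim=-1)  # rotate_half(q)
    return (q * cos + rotated * sin).transpose(0, 1)

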
class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
    """Relative positional encoding module (Transformer-XL style)."""

    def __init__(self, config):
        super().__init__()
        self.max_len = config.num_max_positions
        self.d_model = config.d_model
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))

    def extend_pe(self, x):
        # Rebuild the positional encodings only if the cached table is too short.
        if self.pe is not None:
            # self.pe holds both positive and negative offsets, so its length is 2 * input_len - 1.
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return

        # Suppose `i` is the position of the query vector and `j` the position of the key vector.
        # Positive relative positions are used when the key is to the left (i > j) and negative
        # relative positions otherwise (i < j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(log(10000.0) / self.d_model))
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of the positive indices and concatenate positive and negative parts.
        # This supports the shifting trick from https://arxiv.org/abs/1901.02860 (Transformer-XL).
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, hidden_states: torch.Tensor):
        self.extend_pe(hidden_states)
        start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
        end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
        relative_position_embeddings = self.pe[:, start_idx:end_idx]

        return relative_position_embeddings
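

def _relative_positional_embedding_shape_example():
    # Hedged shape sketch (not part of the original module). The real config class lives elsewhere
    # in the repo; here a stand-in object with only the two attributes read in __init__ above
    # (num_max_positions and d_model) is assumed.
    from types import SimpleNamespace

    config = SimpleNamespace(num_max_positions=128, d_model=64)
    rel_pe = Wav2Vec2ConformerRelPositionalEmbedding(config)

    hidden_states = torch.randn(2, 50, 64)  # (batch, seq_len, d_model)
    rel = rel_pe(hidden_states)

    # One embedding per relative offset from -(seq_len - 1) to +(seq_len - 1).
    assert rel.shape == (1, 2 * 50 - 1, 64)
    return rel

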
def get_rotary_emb(d_by_head: int,
                   freqs_for: Literal["l", "lang", "p", "pixel"],
                   partial_pe: bool = False,
                   learned_freq: bool = False):
    """Builds a RotaryEmbedding; with `partial_pe=True` the rotary dimension is half the per-head dimension."""
    if partial_pe:
        rdim = d_by_head // 2
    else:
        rdim = d_by_head

    if freqs_for in ["l", "lang"]:
        freqs_for = "lang"
    elif freqs_for in ["p", "pixel"]:
        freqs_for = "pixel"
    else:
        raise ValueError(f"freqs_for must be 'l' or 'lang' or 'p' or 'pixel', but got {freqs_for}")

    return RotaryEmbedding(dim=rdim, freqs_for=freqs_for, learned_freq=learned_freq)
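

def _rotary_qk_usage_example():
    # Hedged sketch (not part of the original module): rotary embeddings are applied to both
    # queries and keys before the dot product, so attention scores depend only on relative offsets.
    # `apply_rotary_custom` is the method used in the test below; its exact behaviour is defined in
    # model.RoPE.RoPE and is assumed here. Shapes are illustrative.
    d_by_head = 16
    rotary = get_rotary_emb(d_by_head, freqs_for="lang", partial_pe=False, learned_freq=False)

    q = torch.randn(1, 8, 32, d_by_head)  # (batch, heads, seq_len, head_dim)
    k = torch.randn(1, 8, 32, d_by_head)
    q, k = rotary.apply_rotary_custom(q), rotary.apply_rotary_custom(k)

    scores = torch.matmul(q, k.transpose(-1, -2)) / d_by_head**0.5
    return scores.softmax(dim=-1)

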
def test_rotary_embedding_lang():
    d = 128
    num_heads = 8
    d_by_head = d // num_heads

    rotary = get_rotary_emb(d_by_head, freqs_for="lang", partial_pe=False, learned_freq=False)
    q = torch.ones(1, 8, 110, d_by_head)
    q = rotary.apply_rotary_custom(q)

    import matplotlib.pyplot as plt
    plt.imshow(q[0, 0, :, :].detach().numpy().T, origin='lower')
    plt.show()