import torch
from torch import nn


class AdaptiveEmbedding(nn.Module):
    """Adaptive input embedding (Transformer-XL style): the vocabulary is split
    into clusters by `cutoffs`, and later (rarer) clusters use smaller embedding
    dimensions (d_embed // div_val**i) that are projected back up to `d_proj`."""

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False):
        super().__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        # Embedding outputs are scaled by sqrt(d_proj) in forward().
        self.emb_scale = d_proj**0.5

        # Cluster i covers token ids in [cutoff_ends[i], cutoff_ends[i + 1]).
        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            # Single full-size table; project only if d_proj differs from d_embed.
            self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0))
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
        else:
            # One smaller table plus an up-projection per cluster. The projection
            # parameters are allocated uninitialized and are expected to be
            # initialized by the surrounding model.
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val**i)
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))

    def forward(self, inp):
        if self.div_val == 1:
            # Single table: straight lookup, then optional projection to d_proj.
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = nn.functional.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                # Positions whose token ids fall in cluster i.
                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                # squeeze(-1) keeps indices_i 1-D even when only one token matches.
                indices_i = mask_i.nonzero().squeeze(-1)

                if indices_i.numel() == 0:
                    continue

                # Shift ids to be relative to the cluster, embed, and project up to d_proj.
                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = nn.functional.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed_shape = inp.size() + (self.d_proj,)
            embed = emb_flat.view(embed_shape)

        embed.mul_(self.emb_scale)

        return embed
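

# The projection matrices above are allocated with torch.FloatTensor and are
# therefore uninitialized; the surrounding model is expected to initialize them
# (nn.Embedding weights get PyTorch's default init). A minimal sketch of such an
# init step, assuming normal initialization: the helper name
# `init_adaptive_embedding` and the std values are illustrative assumptions, not
# part of the original module.
def init_adaptive_embedding(module: AdaptiveEmbedding, emb_init_std: float = 0.02, proj_init_std: float = 0.01):
    """Re-initialize embedding weights and projection parameters in place."""
    for emb_layer in module.emb_layers:
        nn.init.normal_(emb_layer.weight, mean=0.0, std=emb_init_std)
    for proj in module.emb_projs:
        nn.init.normal_(proj, mean=0.0, std=proj_init_std)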


class PositionalEmbeddingAux(nn.Module):
    """Sinusoidal positional embedding for (relative) position sequences."""

    def __init__(self, demb):
        super().__init__()

        self.demb = demb

        # Standard sinusoid frequencies: 1 / 10000^(2i / demb) for i = 0 .. demb/2 - 1.
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, pos_seq, bsz=None):
        # (seq_len,) outer (demb/2,) -> (seq_len, demb/2); sin/cos concat -> (seq_len, demb).
        sinusoid_inp = torch.outer(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if bsz is not None:
            # Broadcast over the batch dimension: (seq_len, bsz, demb).
            return pos_emb[:, None, :].expand(-1, bsz, -1)
        else:
            return pos_emb[:, None, :]


class PositionalEmbedding(PositionalEmbeddingAux):
    def forward(self, pos_seq, bsz=None):
        # Accepts a (1, seq_len) position tensor; the parent returns
        # (seq_len, 1, demb) when bsz is None, and squeeze(1) drops that
        # singleton batch axis (a no-op when bsz > 1).
        return super().forward(pos_seq.squeeze(0), bsz=bsz).squeeze(1)
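

# Minimal smoke test, assuming the illustrative `init_adaptive_embedding` helper
# above; the vocabulary size, cutoffs, and dimensions are made up for
# demonstration and are not values from the original code.
if __name__ == "__main__":
    torch.manual_seed(0)

    emb = AdaptiveEmbedding(n_token=1000, d_embed=64, d_proj=32, cutoffs=[100, 500], div_val=2)
    init_adaptive_embedding(emb)

    # Token ids shaped (seq_len, batch); the output is (seq_len, batch, d_proj).
    tokens = torch.randint(0, 1000, (16, 4))
    print(emb(tokens).shape)  # torch.Size([16, 4, 32])

    # A (1, seq_len) tensor of positions; PositionalEmbedding squeezes the extra
    # axes, so the output is (seq_len, demb).
    pos = PositionalEmbedding(demb=32)
    pos_seq = torch.arange(15, -1, -1.0).unsqueeze(0)
    print(pos(pos_seq).shape)  # torch.Size([16, 32])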
|