import torch  # needed for the mask helpers below
import torch.nn as nn

from transformer import *
class Transformer(nn.Module):
    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,
                 ffn_hidden, n_layers, drop_prob, learnable_pos_emb=True):
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.encoder = Encoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               enc_voc_size=enc_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers,
                               padding_idx=src_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)
        # Encoder-only setup: the "decoder" is just a linear projection from the
        # encoder hidden states to the target vocabulary.
        self.decoder = nn.Linear(d_model, dec_voc_size)

    def get_device(self):
        return next(self.parameters()).device

    def forward(self, src):
        # src: [batch_size, src_len] token indices
        device = self.get_device()
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx).to(device)
        enc_src = self.encoder(src, src_mask)
        output = self.decoder(enc_src)
        return output

    def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
        len_q, len_k = q.size(1), k.size(1)

        # batch_size x 1 x 1 x len_k
        k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
        # batch_size x 1 x len_q x len_k
        k = k.repeat(1, 1, len_q, 1)

        # batch_size x 1 x len_q x 1
        q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
        # batch_size x 1 x len_q x len_k
        q = q.repeat(1, 1, 1, len_k)

        mask = k & q
        return mask

    def make_no_peak_mask(self, q, k):
        len_q, len_k = q.size(1), k.size(1)

        # lower-triangular (causal) mask: len_q x len_k
        mask = torch.tril(torch.ones(len_q, len_k)).bool()
        return mask
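
For reference, a minimal usage sketch of the class above. The hyperparameter values, vocabulary sizes, and padding index are illustrative placeholders rather than the project's actual configuration, and the Encoder is assumed to be the one exported by the local transformer module imported at the top.

import torch

# Illustrative hyperparameters only, not the project's real configuration.
model = Transformer(src_pad_idx=0, trg_pad_idx=0,
                    enc_voc_size=1000, dec_voc_size=1000,
                    d_model=128, n_head=4, max_len=256,
                    ffn_hidden=512, n_layers=2, drop_prob=0.1)

src = torch.randint(1, 1000, (2, 16))  # two sequences of 16 token ids
src[:, 12:] = 0                        # pad the tails with the padding index

logits = model(src)                              # (2, 16, dec_voc_size)
pad_mask = model.make_pad_mask(src, src, 0, 0)   # (2, 1, 16, 16) boolean mask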