Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,644 Bytes
bcc0c7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import torch
import torch.nn as nn

from transformer import *
class Transformer(nn.Module):
    """Encoder-decoder model that builds its own attention masks.

    ``forward`` derives three boolean masks from the raw token-id tensors
    (source padding, target padding combined with a causal mask, and
    target-to-source padding) and passes them to ``self.encoder`` and
    ``self.decoder``.

    ``Encoder`` and ``Decoder`` are not defined in this block; presumably they
    come from ``from transformer import *`` -- confirm against that module.
    """

    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,
                 ffn_hidden, n_layers, drop_prob, learnable_pos_emb=True):
        """Store the padding ids and construct the encoder and decoder.

        Args:
            src_pad_idx: Token id treated as padding in ``src`` when building
                masks; also forwarded to ``Encoder`` as ``padding_idx``.
            trg_pad_idx: Token id treated as padding in ``trg`` when building
                masks; also forwarded to ``Decoder`` as ``padding_idx``.
            enc_voc_size: Forwarded unchanged to ``Encoder``.
            dec_voc_size: Forwarded unchanged to ``Decoder``.
            d_model, n_head, max_len, ffn_hidden, n_layers, drop_prob,
            learnable_pos_emb: Forwarded unchanged to both ``Encoder`` and
                ``Decoder``.
        """
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.encoder = Encoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               enc_voc_size=enc_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers,
                               padding_idx=src_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)
        self.decoder = Decoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               dec_voc_size=dec_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers,
                               padding_idx=trg_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)

    def get_device(self):
        """Return the device of the first parameter of this module."""
        return next(self.parameters()).device

    def forward(self, src, trg):
        """Run the encoder on ``src`` and the decoder on ``trg``.

        Args:
            src: Source token ids; indexed as ``(batch, src_len)`` by the
                mask helpers.
            trg: Target token ids; indexed as ``(batch, trg_len)`` by the
                mask helpers.

        Returns:
            Whatever ``self.decoder`` returns for
            ``(trg, enc_src, trg_mask, src_trg_mask)``.
        """
        device = self.get_device()

        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx).to(device)
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx).to(device)

        # Target self-attention mask: a pair is kept only if neither token is
        # padding AND the key position is not after the query position.
        # The causal part is created directly on ``device`` so no host-side
        # tensor has to be allocated and copied on every call.
        trg_pad_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx).to(device)
        trg_mask = trg_pad_mask & self.make_no_peak_mask(trg, trg, device=device)

        enc_src = self.encoder(src, src_mask)
        output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)
        return output

    def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
        """Build a boolean mask that is True where neither token is padding.

        Args:
            q: Query-side token ids, ``(batch, len_q)``.
            k: Key-side token ids, ``(batch, len_k)``.
            q_pad_idx: Value in ``q`` that marks padding.
            k_pad_idx: Value in ``k`` that marks padding.

        Returns:
            Bool tensor of shape ``(batch, 1, len_q, len_k)`` on the device of
            the inputs; entry ``[b, 0, i, j]`` is True iff
            ``q[b, i] != q_pad_idx`` and ``k[b, j] != k_pad_idx``.
        """
        # batch_size x 1 x 1 x len_k
        key_keep = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
        # batch_size x 1 x len_q x 1
        query_keep = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
        # Broadcasting expands both operands to batch_size x 1 x len_q x len_k,
        # so the explicit ``repeat`` copies are unnecessary.
        return key_keep & query_keep

    def make_no_peak_mask(self, q, k, device=None):
        """Build a lower-triangular (causal) boolean mask.

        Args:
            q: Tensor whose size along dim 1 gives ``len_q``.
            k: Tensor whose size along dim 1 gives ``len_k``.
            device: Optional device on which to create the mask. ``None``
                uses PyTorch's default device.

        Returns:
            Bool tensor of shape ``(len_q, len_k)`` where entry ``[i, j]`` is
            True iff ``j <= i``.
        """
        len_q, len_k = q.size(1), k.size(1)
        query_pos = torch.arange(len_q, device=device).unsqueeze(1)  # len_q x 1
        key_pos = torch.arange(len_k, device=device).unsqueeze(0)    # 1 x len_k
        # Same pattern as tril(ones(len_q, len_k)), produced directly as bool
        # without a float intermediate.
        return key_pos <= query_pos
|