|
import torch |
|
from torch import Tensor |
|
|
|
from transformer import Transformer |
|
from tokenizers import Tokenizer |
|
from dataset import causal_mask |
|
|
|
|
|
def greedy_decode(
    model: Transformer,
    src: Tensor,
    src_mask: Tensor,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    tgt_max_seq_len: int,
    device,
    give_attn: bool = False,
):
    """
    Greedily decode one target sequence for a single source sequence.

    The source is encoded once; the decoder then runs autoregressively,
    appending the highest-probability token at each step until it emits
    ``<eos>`` or the sequence reaches ``tgt_max_seq_len`` tokens.

    Args:
        model: Transformer exposing ``encode``, ``decode`` and ``project``.
        src: Source token ids — assumed shape (1, src_seq_len), TODO confirm.
        src_mask: Source padding mask, broadcastable to the attention scores.
        src_tokenizer: Source-language tokenizer (kept for interface
            compatibility; the special-token ids come from ``tgt_tokenizer``).
        tgt_tokenizer: Target-language tokenizer; supplies the ``<sos>`` and
            ``<eos>`` ids that drive decoding.
        tgt_max_seq_len: Hard cap on the decoded length (including ``<sos>``).
        device: Torch device on which intermediate tensors are created.
        give_attn: When True, also return the attention weights from the
            final decode step.

    Returns:
        1-D tensor of decoded token ids, or ``(ids, attn)`` when
        ``give_attn`` is True.
    """
    # BUG FIX: the <sos>/<eos> ids are fed to the decoder, so they belong to
    # the *target* vocabulary — look them up on tgt_tokenizer, not
    # src_tokenizer (the two vocabularies generally differ).
    sos_idx = tgt_tokenizer.token_to_id('<sos>')
    eos_idx = tgt_tokenizer.token_to_id('<eos>')

    # Encode the source once; the result is reused at every decode step.
    encoder_output = model.encode(src, src_mask)

    attn = None
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(device)

    # Length cap folded into the loop condition (also safe if the cap is
    # degenerate, i.e. < 1, where the original == check never fired).
    while decoder_input.size(1) < tgt_max_seq_len:
        # Causal mask keeps each position from attending to future tokens.
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(src).to(device)

        decoder_output, attn = model.decode(
            encoder_output, src_mask, decoder_input, decoder_mask
        )

        # Project only the last position and pick the argmax token.
        prob = model.project(decoder_output[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_token = next_word.item()

        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(src).fill_(next_token).to(device)],
            dim=1,
        )

        if next_token == eos_idx:
            break

    if give_attn:
        return (decoder_input.squeeze(0), attn)
    return decoder_input.squeeze(0)