# homemade_lo_vi/decode_method.py
import torch
from torch import Tensor
from transformer import Transformer
from tokenizers import Tokenizer
from dataset import causal_mask


def greedy_decode(
model: Transformer,
src: Tensor,
src_mask: Tensor,
src_tokenizer: Tokenizer,
tgt_tokenizer: Tokenizer,
tgt_max_seq_len: int,
device,
give_attn: bool = False,
):
"""
Decodes greedily.
"""
    # Special-token ids come from the target tokenizer, since the decoder
    # generates target-language tokens.
    sos_idx = tgt_tokenizer.token_to_id('<sos>')
    eos_idx = tgt_tokenizer.token_to_id('<eos>')
    # Encode the source once; the encoder output is reused at every decoding step.
    encoder_output = model.encode(src, src_mask)
    attn = None
    # Start decoding from a single '<sos>' token.
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(device)
    while True:
        # Stop once the target length limit is reached.
        if decoder_input.size(1) == tgt_max_seq_len:
            break
        # Build the causal (look-ahead) mask for the tokens generated so far.
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(src).to(device)
        # Run the decoder and project the last position onto the target vocabulary.
        decoder_output, attn = model.decode(encoder_output, src_mask, decoder_input, decoder_mask)
        prob = model.project(decoder_output[:, -1])
        # Greedy step: append the most probable next token to the decoder input.
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(src).fill_(next_word.item()).to(device)], dim=1
        )
        # Stop as soon as the end-of-sequence token is produced.
        if next_word == eos_idx:
            break
if give_attn:
return (decoder_input.squeeze(0), attn)
return decoder_input.squeeze(0)
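

# ---------------------------------------------------------------------------
# Illustrative usage (a sketch only, not part of the original module). It
# assumes a trained `model`, source/target `Tokenizer`s sharing the '<sos>',
# '<eos>' and '<pad>' special tokens, and a `device`; the variable names and
# the length limit below are placeholders, not project APIs.
#
#   sentence = "..."  # raw source-language text
#   src = torch.tensor([src_tokenizer.encode(sentence).ids], device=device)  # (1, seq_len)
#   src_mask = (src != src_tokenizer.token_to_id('<pad>')).unsqueeze(1).unsqueeze(1).int()
#   out_ids = greedy_decode(model, src, src_mask, src_tokenizer, tgt_tokenizer,
#                           tgt_max_seq_len=350, device=device)
#   print(tgt_tokenizer.decode(out_ids.detach().cpu().tolist()))
# ---------------------------------------------------------------------------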