# File size: 2,204 bytes
# Revision: bc1ada8
from typing import Tuple
import torch
from torch import Tensor
from tokenizers import Tokenizer
from transformer import Transformer
from decode_method import greedy_decode
def translate(
    model: Transformer,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    text: str,
    decode_method: str = 'greedy',
    device=torch.device('cpu'),
    max_len: int = 400,
) -> Tuple[str, Tensor]:
    """
    Translate a single source sentence with ``model``.

    The source text is tokenized with ``src_tokenizer`` and wrapped as
    ``<sos> + tokens + <eos>``; the special-token ids are looked up in
    ``src_tokenizer``. The resulting batch of one is decoded with the chosen
    strategy, and the output ids are detokenized with ``tgt_tokenizer``.

    Args:
        model: the trained Transformer. It is switched to eval mode for
            decoding (no dropout) and restored to its previous train/eval
            mode afterwards.
        src_tokenizer: source-language tokenizer; must define the ``<sos>``,
            ``<eos>`` and ``<pad>`` tokens.
        tgt_tokenizer: target-language tokenizer, forwarded to the decoder
            and used to turn the output ids back into text.
        text: the source sentence.
        decode_method: ``'greedy'`` (supported) or ``'beam-search'`` (not
            implemented yet).
        device: device the encoder input and mask are moved to; presumably
            the device ``model`` lives on — TODO confirm with callers.
        max_len: maximum decoding length forwarded to ``greedy_decode``
            (defaults to the previously hard-coded 400).

    Returns:
        translation (str): the translated string.
        attn (Tensor): the decoder's attention as returned by
            ``greedy_decode`` (for visualization).

    Raises:
        NotImplementedError: if ``decode_method == 'beam-search'``.
        ValueError: if ``decode_method`` is unknown, or a required special
            token is missing from ``src_tokenizer``.
    """
    # Validate the strategy up front so we fail before doing any work.
    if decode_method == 'beam-search':
        raise NotImplementedError("beam-search decoding is not implemented yet")
    if decode_method != 'greedy':
        raise ValueError(f"Unsupported decode method: {decode_method!r}")

    # token_to_id returns None for unknown tokens; turn that into a clear
    # error instead of an opaque failure inside torch.tensor.
    special_ids = {}
    for token in ('<sos>', '<eos>', '<pad>'):
        token_id = src_tokenizer.token_to_id(token)
        if token_id is None:
            raise ValueError(f"Source tokenizer has no {token} token")
        special_ids[token] = token_id

    # <sos> + source_text + <eos> = encoder_input, shape (seq_len,)
    encoder_input = torch.tensor(
        [special_ids['<sos>'], *src_tokenizer.encode(text).ids, special_ids['<eos>']],
        dtype=torch.int64,
    )
    # Padding mask with three leading singleton dims: (1, 1, 1, seq_len).
    encoder_mask = (encoder_input != special_ids['<pad>']).view(1, 1, 1, -1).int()

    # Add the batch dimension -> (1, seq_len) and put both tensors on the
    # requested device (previously `device` was accepted but never applied).
    encoder_input = encoder_input.unsqueeze(0).to(device)
    encoder_mask = encoder_mask.to(device)

    # Inference: disable dropout and autograd, but restore the caller's mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            model_out, attn = greedy_decode(
                model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer,
                max_len, device,
                give_attn=True,
            )
    finally:
        model.train(was_training)

    # Tokenizer.decode expects a sequence of ints.
    translation = tgt_tokenizer.decode(model_out.cpu().tolist())
    return translation, attn
from config import load_config
from load_and_save_model import load_model_tokenizer
if __name__ == '__main__':
    # Smoke test: load the small model and translate a short Lao greeting.
    cfg = load_config(file_name='config_small.yaml')
    loaded = load_model_tokenizer(cfg)
    model, src_tokenizer, tgt_tokenizer = loaded

    sample_text = "ສະບາຍດີ"  # Hello.
    result = translate(model, src_tokenizer, tgt_tokenizer, sample_text)
    translation, attn = result  # attention is returned for visualization only
    print(translation)