File size: 2,204 Bytes
bc1ada8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from typing import Tuple

import torch
from torch import Tensor

from tokenizers import Tokenizer

from transformer import Transformer
from decode_method import greedy_decode

def translate(
    model: Transformer,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    text: str,
    decode_method: str = 'greedy',
    device = torch.device('cpu')
) -> Tuple[str, Tensor]:
    """
    Translate a single source string.

    The text is tokenized with ``src_tokenizer``, wrapped as
    ``<sos> + tokens + <eos>``, and passed to the selected decoder as a
    batch of one. The ids returned by the decoder are turned back into
    text with ``tgt_tokenizer``.

    Input:
        - model (Transformer): the model handed to the decoder.
        - src_tokenizer (Tokenizer): tokenizer for the source text; it must
          define the ``<sos>``, ``<eos>`` and ``<pad>`` tokens.
        - tgt_tokenizer (Tokenizer): tokenizer used to decode the output ids.
        - text (str): the source sentence.
        - decode_method (str): ``'greedy'`` is the only implemented method.
        - device: device on which the encoder tensors are created; it is
          also forwarded to the decoder.

    Output:
        - translation (str): the translated string.
        - attn (Tensor): The decoder's attention (for visualization)

    Raises:
        - NotImplementedError: if ``decode_method`` is ``'beam-search'``.
        - ValueError: if ``decode_method`` is unknown, or if the source
          tokenizer lacks one of the required special tokens.
    """
    # Fail fast on the decode method, before any tokenization work.
    if decode_method == 'beam-search':
        raise NotImplementedError("beam-search decoding is not implemented")
    if decode_method != 'greedy':
        raise ValueError(f"Unsupported decode method: {decode_method!r}")

    # token_to_id returns None for a token that is not in the vocabulary;
    # report that clearly instead of failing later while building a tensor.
    special_ids = {}
    for special in ('<sos>', '<eos>', '<pad>'):
        token_id = src_tokenizer.token_to_id(special)
        if token_id is None:
            raise ValueError(f"Source tokenizer has no {special} token")
        special_ids[special] = token_id

    # <sos> + source_text + <eos> = encoder_input, created directly on `device`
    # so that it lives on the same device that is passed to the decoder.
    source_ids = src_tokenizer.encode(text).ids
    encoder_ids = torch.tensor(
        [special_ids['<sos>'], *source_ids, special_ids['<eos>']],
        dtype=torch.int64,
        device=device,
    )

    # 1 where the position holds a real token, 0 where it holds <pad>.
    encoder_mask = (encoder_ids != special_ids['<pad>']).int().view(1, 1, 1, -1)  # (1, 1, 1, seq_len)
    encoder_input = encoder_ids.unsqueeze(0)  # (1, seq_len): a batch of one

    # NOTE(review): 400 is forwarded positionally to greedy_decode; presumably
    # the maximum decode length -- confirm against decode_method.greedy_decode.
    model_out, attn = greedy_decode(
        model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 400, device,
        give_attn=True,
    )

    # Tokenizer.decode is documented to take a list of ints.
    model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().tolist())
    return model_out_text, attn
    

from config import load_config
from load_and_save_model import load_model_tokenizer
if __name__ == '__main__':
    # Manual smoke test: load the model and tokenizers described by
    # config_small.yaml, translate one sample sentence with the default
    # arguments of translate(), and print the resulting text.
    sample_text = "ສະບາຍດີ"  # Hello.
    demo_model, demo_src_tokenizer, demo_tgt_tokenizer = load_model_tokenizer(
        load_config(file_name='config_small.yaml')
    )
    demo_result = translate(
        demo_model, demo_src_tokenizer, demo_tgt_tokenizer, sample_text
    )
    # translate() returns (translation, attention); only the text is shown.
    print(demo_result[0])