Upload 12 files
- bleu.py +49 -0
- bleu_test.py +34 -0
- dataset.py +99 -0
- decode_method.py +66 -1
- load_and_save_model.py +73 -0
- load_dataset.py +104 -0
- tokenizer.py +36 -0
- train.py +199 -0
- translate.py +42 -32
- utils.py +6 -0
- validate.py +72 -0
bleu.py
ADDED
@@ -0,0 +1,49 @@
import torch
from torch.utils.data import DataLoader
from torchtext.data.metrics import bleu_score

from tqdm import tqdm

from decode_method import beam_search_decode
from transformer import Transformer

from tokenizers import Tokenizer


def calculate_bleu_score(
    model: Transformer,
    bleu_dataloader: DataLoader,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    device = torch.device('cpu'),
    num_samples: int = 9999999,  # effectively "all samples" by default
):
    """Computes the corpus BLEU score (in percent) of the model on a dataloader with batch_size 1."""
    model.eval()

    # inference
    count = 0
    expected = []
    predicted = []

    with torch.no_grad():
        batch_iterator = tqdm(bleu_dataloader)
        for batch in batch_iterator:
            count += 1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)

            assert encoder_input.size(0) == 1, "batch_size = 1 for bleu calculation"

            model_out = beam_search_decode(model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 300, device)

            target_text = batch['tgt_text'][0]
            model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())

            expected.append([target_text.split()])
            predicted.append(model_out_text.split())

            if count == num_samples:
                break

    return bleu_score(predicted, expected) * 100.0
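Note the extra nesting in expected.append([target_text.split()]): torchtext's bleu_score takes, for each candidate token list, a list of reference token lists. A minimal self-contained check of that convention (toy sentence, one reference):

from torchtext.data.metrics import bleu_score

# one candidate token list ...
candidate_corpus = [['the', 'cat', 'sat', 'on', 'the', 'mat']]
# ... and, per candidate, a LIST of reference token lists (a single reference here)
references_corpus = [[['the', 'cat', 'sat', 'on', 'the', 'mat']]]
print(bleu_score(candidate_corpus, references_corpus) * 100.0)  # 100.0 for an exact match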
bleu_test.py
ADDED
@@ -0,0 +1,34 @@
from torch.utils.data import DataLoader
from bleu import calculate_bleu_score
from load_dataset import load_local_bleu_dataset
from dataset import BilingualDataset
from config import load_config
from load_and_save_model import load_model_tokenizer


def get_bleu_of_model(config) -> float:
    model, src_tokenizer, tgt_tokenizer = load_model_tokenizer(config)
    bleu_ds_raw = load_local_bleu_dataset(
        src_dataset_filename='datasets/'+config['dataset']['bleu_dataset']+'.'+config['dataset']['src_lang'],
        tgt_dataset_filename='datasets/'+config['dataset']['bleu_dataset']+'.'+config['dataset']['tgt_lang'],
        src_lang=config['dataset']['src_lang'],
        tgt_lang=config['dataset']['tgt_lang'],
    )
    bleu_ds = BilingualDataset(
        ds=bleu_ds_raw,
        src_tokenizer=src_tokenizer,
        tgt_tokenizer=tgt_tokenizer,
        src_lang=config['dataset']['src_lang'],
        tgt_lang=config['dataset']['tgt_lang'],
        src_max_seq_len=config['dataset']['src_max_seq_len'],
        tgt_max_seq_len=config['dataset']['tgt_max_seq_len'],
    )
    bleu_dataloader = DataLoader(bleu_ds, batch_size=1, shuffle=True)
    return calculate_bleu_score(
        model, bleu_dataloader, src_tokenizer, tgt_tokenizer,
    )

if __name__ == '__main__':
    # a list keeps the evaluation order deterministic (a set would not)
    for file_name in ['config_final.yaml', 'config_huge.yaml', 'config_big.yaml', 'config_small.yaml']:
        config = load_config(file_name)
        print(get_bleu_of_model(config), f" is the BLEU of {file_name}", sep='')
dataset.py
ADDED
@@ -0,0 +1,99 @@
from typing import List, Dict, Any

import torch
from torch.utils.data import Dataset
from torch import Tensor

from tokenizers import Tokenizer

class BilingualDataset(Dataset):
    """
    A Bilingual Dataset that follows the structure of the 'opus_books' dataset.
    """
    def __init__(
        self,
        ds: List[Dict[str, Dict[str,str]]],
        src_tokenizer: Tokenizer,
        tgt_tokenizer: Tokenizer,
        src_lang: str,
        tgt_lang: str,
        src_max_seq_len: int,
        tgt_max_seq_len: int,
    ) -> None:
        super().__init__()

        self.ds = ds
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        self.src_max_seq_len = src_max_seq_len
        self.tgt_max_seq_len = tgt_max_seq_len

        self.sos_token = torch.tensor([src_tokenizer.token_to_id('<sos>')], dtype=torch.int64)
        self.eos_token = torch.tensor([src_tokenizer.token_to_id('<eos>')], dtype=torch.int64)
        self.pad_token = torch.tensor([src_tokenizer.token_to_id('<pad>')], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        src_tgt_pair = self.ds[index]
        src_text = src_tgt_pair['translation'][self.src_lang]
        tgt_text = src_tgt_pair['translation'][self.tgt_lang]

        encoder_input_tokens = self.src_tokenizer.encode(src_text).ids
        decoder_input_tokens = self.tgt_tokenizer.encode(tgt_text).ids

        encoder_num_padding = self.src_max_seq_len - len(encoder_input_tokens) - 2  # <sos> + <eos>
        decoder_num_padding = self.tgt_max_seq_len - len(decoder_input_tokens) - 1  # <sos>

        # <sos> + source_text + <eos> + <pad> = encoder_input
        # (torch.full builds a flat int64 padding vector; .item() avoids nesting a 1-element tensor)
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(encoder_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.full((encoder_num_padding,), self.pad_token.item(), dtype=torch.int64)
            ]
        )

        decoder_input_tokens = torch.tensor(decoder_input_tokens, dtype=torch.int64)
        decoder_padding = torch.full((decoder_num_padding,), self.pad_token.item(), dtype=torch.int64)
        # <sos> + target_text + <pad> = decoder_input
        decoder_input = torch.cat(
            [
                self.sos_token,
                decoder_input_tokens,
                decoder_padding
            ]
        )
        # target_text + <eos> + <pad> = expected decoder_output (label)
        label = torch.cat(
            [
                decoder_input_tokens,
                self.eos_token,
                decoder_padding
            ]
        )

        assert encoder_input.size(0) == self.src_max_seq_len
        assert decoder_input.size(0) == self.tgt_max_seq_len
        assert label.size(0) == self.tgt_max_seq_len

        return {
            'encoder_input': encoder_input,  # (seq_len)
            'decoder_input': decoder_input,  # (seq_len)
            'encoder_mask': (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),  # (1, seq_len, seq_len)
            'label': label,  # (seq_len)
            'src_text': src_text,
            'tgt_text': tgt_text,
        }

def causal_mask(size: int) -> Tensor:
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0
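causal_mask is what keeps the decoder autoregressive: position i may attend only to positions <= i. A quick standalone check of the pattern for size 4:

import torch

def causal_mask(size):
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0

print(causal_mask(4).int())
# tensor([[[1, 0, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 1, 0],
#          [1, 1, 1, 1]]], dtype=torch.int32)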
decode_method.py
CHANGED
@@ -47,4 +47,69 @@ def greedy_decode(
             break
     if give_attn:
         return (decoder_input.squeeze(0), attn)
-    return decoder_input.squeeze(0)
+    return decoder_input.squeeze(0)
+
+def beam_search_decode(
+    model: Transformer,
+    src: Tensor,
+    src_mask: Tensor,
+    src_tokenizer: Tokenizer,
+    tgt_tokenizer: Tokenizer,
+    tgt_max_seq_len: int,
+    device,
+    beam_size: int = 3,
+):
+    sos_idx = src_tokenizer.token_to_id('<sos>')
+    eos_idx = src_tokenizer.token_to_id('<eos>')
+
+    # Precompute the encoder output and reuse it for every step
+    encoder_output = model.encode(src, src_mask)
+    # Initialize the decoder input with the sos token
+    decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(device)
+
+    # Create a candidate list of (sequence, score) pairs
+    candidates = [(decoder_initial_input, 1)]
+
+    while True:
+
+        # If a candidate has reached the maximum length, it means we have run the decoding for at least max_len iterations, so stop the search
+        if any([cand.size(1) == tgt_max_seq_len for cand, _ in candidates]):
+            break
+
+        # Create a new list of candidates
+        new_candidates = []
+
+        for candidate, score in candidates:
+
+            # Do not expand candidates that have reached the eos token,
+            # but keep them in the pool so a finished hypothesis can still compete
+            if candidate[0][-1].item() == eos_idx:
+                new_candidates.append((candidate, score))
+                continue
+
+            # Build the candidate's mask
+            candidate_mask = causal_mask(candidate.size(1)).type_as(src_mask).to(device)
+            # calculate output
+            out, attn = model.decode(encoder_output, src_mask, candidate, candidate_mask)
+            # get next token probabilities
+            prob = model.project(out[:, -1])
+            # get the top k candidates
+            topk_prob, topk_idx = torch.topk(prob, beam_size, dim=1)
+            for i in range(beam_size):
+                # for each of the top k candidates, get the token and its probability
+                token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)
+                token_prob = topk_prob[0][i].item()
+                # create a new candidate by appending the token to the current candidate
+                new_candidate = torch.cat([candidate, token], dim=1)
+                # We sum the log probabilities because the probabilities are in log space
+                new_candidates.append((new_candidate, score + token_prob))
+
+        # Sort the new candidates by their score
+        candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)
+        # Keep only the top k candidates
+        candidates = candidates[:beam_size]
+
+        # If all the candidates have reached the eos token, stop
+        if all([cand[0][-1].item() == eos_idx for cand, _ in candidates]):
+            break
+
+    # Return the best candidate
+    return candidates[0][0].squeeze()
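The heart of beam_search_decode is the expansion step: take the log-probabilities at the last position, keep the beam_size best continuations, and add their scores to the running total. A self-contained sketch of just that step, on a made-up log-probability row (toy numbers, no model involved):

import torch

beam_size = 3
# made-up scores over a 5-token vocabulary for one candidate (batch of 1)
log_prob = torch.log_softmax(torch.tensor([[0.1, 2.0, 0.3, 1.5, 0.0]]), dim=1)

score = 0.0  # running sum of log-probabilities for this candidate
topk_prob, topk_idx = torch.topk(log_prob, beam_size, dim=1)
for i in range(beam_size):
    token = topk_idx[0][i].item()
    print(f"extend with token {token}: new score {score + topk_prob[0][i].item():.3f}")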
load_and_save_model.py
ADDED
@@ -0,0 +1,73 @@
from typing import Tuple

import torch
from transformer import get_model, Transformer
from config import load_config, get_weights_file_path
from train import get_local_dataset_tokenizer
from tokenizer import get_or_build_local_tokenizer

from tokenizers import Tokenizer


def load_train_data_and_save_model(config, model_name):
    """
    Loads a training checkpoint (model, optimizer, scheduler, ...) and saves ONLY the model weights.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device {device}')

    train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer = get_local_dataset_tokenizer(config)
    model = get_model(config, src_tokenizer.get_vocab_size(), tgt_tokenizer.get_vocab_size()).to(device)

    assert config['model']['preload'], "config['model']['preload'] must name the checkpoint to load."

    model_load_filename = get_weights_file_path(config, config['model']['preload'])
    print(f'Preloading model from train data in {model_load_filename}')
    state = torch.load(model_load_filename, map_location=device)

    model.load_state_dict(state['model_state_dict'])

    model_save_filename = get_weights_file_path(config, model_name)
    torch.save(model.state_dict(), model_save_filename)
    print(f'Model saved at {model_save_filename}')

def load_model_tokenizer(
    config,
    device = torch.device('cpu'),
    logs: bool = True,
) -> Tuple[Transformer, Tokenizer, Tokenizer]:
    """
    Loads a local model and tokenizers from a given config
    """
    if config['model']['preload'] is None:
        raise ValueError('Unspecified preload model')

    src_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['src_lang'],
        tokenizer_type=config['dataset']['src_tokenizer']
    )
    tgt_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=None,
        lang=config['dataset']['tgt_lang'],
        tokenizer_type=config['dataset']['tgt_tokenizer']
    )

    model = get_model(
        config,
        src_tokenizer.get_vocab_size(),
        tgt_tokenizer.get_vocab_size(),
    ).to(device)

    model_filename = get_weights_file_path(config, config['model']['preload'])
    model.load_state_dict(
        torch.load(model_filename, map_location=device)
    )
    if logs:
        print('Finished loading model and tokenizers')
    return (model, src_tokenizer, tgt_tokenizer)

if __name__ == '__main__':
    config = load_config(file_name='config_huge.yaml')
    load_train_data_and_save_model(config, 'huge')
load_dataset.py
ADDED
@@ -0,0 +1,104 @@
from typing import List, Dict, Any

from pathlib import Path

from utils import get_full_file_path

# SENTENCE_STOPPERS = {'!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~'}
# VIETNAMESE_SPECIAL_CHARACTERS = {'à', 'á', 'ả', 'ã', 'ạ', 'â', 'ầ', 'ấ', 'ẩ', 'ẫ', 'ậ', 'ă', 'ằ', 'ắ', 'ẳ', 'ẵ', 'ặ', 'è', 'é', 'ẻ', 'ẽ', 'ẹ', 'ê', 'ề', 'ế', 'ể', 'ễ', 'ệ', 'ì', 'í', 'ỉ', 'ĩ', 'ị', 'ò', 'ó', 'ỏ', 'õ', 'ọ', 'ô', 'ồ', 'ố', 'ổ', 'ỗ', 'ộ', 'ơ', 'ờ', 'ớ', 'ở', 'ỡ', 'ợ', 'ù', 'ú', 'ủ', 'ũ', 'ụ', 'ư', 'ừ', 'ứ', 'ử', 'ữ', 'ự', 'ỳ', 'ý', 'ỷ', 'ỹ', 'ỵ'}

# def is_Vietnamese_character(char):
#     return char.isalpha() or char in VIETNAMESE_SPECIAL_CHARACTERS

# def categorize_word(word: str) -> str:
#     """
#     Categorize word into 3 types:
#     - "vi": likely Vietnamese.
#     - "lo": likely Lao.
#     - "num": a number
#     """
#     if any(char.isdigit() for char in word):
#         return "num"

#     for stopper in SENTENCE_STOPPERS:
#         if word.endswith(stopper):
#             word = word[:-1]
#             if len(word) == 0:
#                 break

#     if len(word) > 0 and any(not is_Vietnamese_character(char) for char in word):
#         return "lo"
#     else:
#         return "vi"
#
# def open_dataset(
#     dataset_filename: str,
#     src_lang: str = "lo",
#     tgt_lang: str = "vi"
# ) -> List[Dict[str, Dict[str,str]]]:
#     ds = []
#     file_path = get_full_file_path(dataset_filename)
#     with open(file_path, 'r', encoding='utf-8') as file:
#         lines = file.readlines()

#     for index, line in enumerate(lines):
#         line = line.split(sep=None)

#         lo_positions = [i for i, word in enumerate(line) if categorize_word(word) == "lo"]
#         if len(lo_positions) == 0:
#             # print(line)
#             continue

#         split_index = max(lo_positions)
#         assert split_index is not None, f"Dataset error on line {index+1}."

#         src_text = ' '.join(line[:split_index+1])
#         tgt_text = line[split_index+1:]

#         if index <= 5:
#             print(src_text, tgt_text, sep="\n", end="\n-------")

#         # TODO: post process the tgt_text to split all numbers into single digits.
#         ds.append({'translation':{src_lang:src_text, tgt_lang:tgt_text}})
#     return ds

# open_dataset('datasets/dev_clean.dat')

def load_local_dataset(
    dataset_filename: str,
    src_lang: str = "lo",
    tgt_lang: str = "vi"
) -> List[Dict[str, Dict[str,str]]]:
    ds = []
    file_path = get_full_file_path(dataset_filename)
    with open(file_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    for line in lines:
        src_text, tgt_text = line.split(sep="\t", maxsplit=1)
        ds.append({'translation':{src_lang:src_text, tgt_lang:tgt_text}})
    return ds

def load_local_bleu_dataset(
    src_dataset_filename: str,
    tgt_dataset_filename: str,
    src_lang: str = "lo",
    tgt_lang: str = "vi"
) -> List[Dict[str, Dict[str,str]]]:
    def load_local_monolanguage_dataset(dataset_filename: str):
        mono_ds = []
        file_path = get_full_file_path(dataset_filename)
        with open(file_path, 'r', encoding='utf-8') as file:
            lines = file.readlines()
        for line in lines:
            mono_ds.append(line)
        return mono_ds

    src_texts = load_local_monolanguage_dataset(src_dataset_filename)
    tgt_texts = load_local_monolanguage_dataset(tgt_dataset_filename)

    assert len(src_texts) == len(tgt_texts)
    ds = []
    for i in range(len(src_texts)):
        ds.append({'translation':{src_lang:src_texts[i], tgt_lang:tgt_texts[i]}})
    return ds
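load_local_dataset assumes one tab-separated sentence pair per line. A tiny round-trip with a hypothetical file name and toy sentences shows the format and the record shape it produces (note that the target keeps its trailing newline, exactly as in the loader):

from pathlib import Path

path = Path('toy_pair.dat')  # hypothetical file name
path.write_text("ສະບາຍດີ\tXin chào\n", encoding='utf-8')

with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        src_text, tgt_text = line.split(sep="\t", maxsplit=1)
        print({'translation': {'lo': src_text, 'vi': tgt_text}})
# {'translation': {'lo': 'ສະບາຍດີ', 'vi': 'Xin chào\n'}}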
tokenizer.py
ADDED
@@ -0,0 +1,36 @@
from tokenizers import Tokenizer
from tokenizers.models import WordLevel, BPE
from tokenizers.trainers import WordLevelTrainer, BpeTrainer
from tokenizers.pre_tokenizers import Whitespace, ByteLevel

from pathlib import Path


def get_all_sentences(ds, lang: str):
    for item in ds:
        yield item['translation'][lang]

def get_or_build_local_tokenizer(config, ds, lang: str, tokenizer_type: str, force_build: bool = False) -> Tokenizer:
    tokenizer_path = Path(config['dataset']['tokenizer_file'].format(lang))
    if not Path.exists(tokenizer_path) or force_build:
        if ds is None:
            raise ValueError("Cannot find local tokenizer, dataset given is None")

        if tokenizer_type == "WordLevel":
            tokenizer = Tokenizer(WordLevel(unk_token='<unk>'))
            tokenizer.pre_tokenizer = Whitespace()
            trainer = WordLevelTrainer(special_tokens=['<unk>', '<pad>', '<sos>', '<eos>'], min_frequency=2)
        elif tokenizer_type == "BPE":
            tokenizer = Tokenizer(BPE(unk_token='<unk>'))
            tokenizer.pre_tokenizer = Whitespace()
            trainer = BpeTrainer(special_tokens=['<unk>', '<pad>', '<sos>', '<eos>'], min_frequency=2)
        else:
            raise ValueError("Unsupported Tokenizer type")

        tokenizer.train_from_iterator(
            get_all_sentences(ds, lang), trainer=trainer
        )
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer
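As a minimal sketch of the WordLevel path above, here is a tokenizer trained from an in-memory toy dataset in the same {'translation': {...}} shape (toy sentences; min_frequency=1 so every word survives):

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

ds = [
    {'translation': {'lo': 'ສະບາຍດີ', 'vi': 'xin chào thế giới'}},
    {'translation': {'lo': 'ສະບາຍດີ', 'vi': 'xin chào bạn'}},
]

tokenizer = Tokenizer(WordLevel(unk_token='<unk>'))
tokenizer.pre_tokenizer = Whitespace()
trainer = WordLevelTrainer(special_tokens=['<unk>', '<pad>', '<sos>', '<eos>'], min_frequency=1)
tokenizer.train_from_iterator((item['translation']['vi'] for item in ds), trainer=trainer)

ids = tokenizer.encode('xin chào').ids
print(ids, '->', tokenizer.decode(ids))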
train.py
ADDED
@@ -0,0 +1,199 @@
import os
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
import torchmetrics
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm

# from datasets import load_dataset
from load_dataset import load_local_dataset
from transformer import get_model
from config import load_config, get_weights_file_path
from validate import run_validation
from tokenizer import get_or_build_local_tokenizer

from pathlib import Path

from dataset import BilingualDataset
from bleu import calculate_bleu_score
from decode_method import greedy_decode

def get_local_dataset_tokenizer(config):
    train_ds_raw = load_local_dataset(
        dataset_filename='datasets/'+config['dataset']['train_dataset'],
        src_lang=config['dataset']['src_lang'],
        tgt_lang=config['dataset']['tgt_lang']
    )
    val_ds_raw = load_local_dataset(
        dataset_filename='datasets/'+config['dataset']['validate_dataset'],
        src_lang=config['dataset']['src_lang'],
        tgt_lang=config['dataset']['tgt_lang']
    )

    src_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=train_ds_raw + val_ds_raw,
        lang=config['dataset']['src_lang'],
        tokenizer_type=config['dataset']['src_tokenizer']
    )
    tgt_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=train_ds_raw + val_ds_raw,
        lang=config['dataset']['tgt_lang'],
        tokenizer_type=config['dataset']['tgt_tokenizer']
    )

    train_ds = BilingualDataset(
        ds=train_ds_raw,
        src_tokenizer=src_tokenizer,
        tgt_tokenizer=tgt_tokenizer,
        src_lang=config['dataset']['src_lang'],
        tgt_lang=config['dataset']['tgt_lang'],
        src_max_seq_len=config['dataset']['src_max_seq_len'],
        tgt_max_seq_len=config['dataset']['tgt_max_seq_len'],
    )
    val_ds = BilingualDataset(
        ds=val_ds_raw,
        src_tokenizer=src_tokenizer,
        tgt_tokenizer=tgt_tokenizer,
        src_lang=config['dataset']['src_lang'],
        tgt_lang=config['dataset']['tgt_lang'],
        src_max_seq_len=config['dataset']['src_max_seq_len'],
        tgt_max_seq_len=config['dataset']['tgt_max_seq_len'],
    )

    src_max_seq_len = 0
    tgt_max_seq_len = 0
    for item in (train_ds_raw + val_ds_raw):
        src_ids = src_tokenizer.encode(item['translation'][config['dataset']['src_lang']]).ids
        tgt_ids = tgt_tokenizer.encode(item['translation'][config['dataset']['tgt_lang']]).ids
        src_max_seq_len = max(src_max_seq_len, len(src_ids))
        tgt_max_seq_len = max(tgt_max_seq_len, len(tgt_ids))
    print(f'Max length of source sequence: {src_max_seq_len}')
    print(f'Max length of target sequence: {tgt_max_seq_len}')

    train_dataloader = DataLoader(train_ds, batch_size=config['train']['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer

def train_model(config):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device {device}')

    Path(config['model']['model_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer = get_local_dataset_tokenizer(config)
    model = get_model(config, src_tokenizer.get_vocab_size(), tgt_tokenizer.get_vocab_size()).to(device)

    print(f'{src_tokenizer.get_vocab_size()}, {tgt_tokenizer.get_vocab_size()}')

    # Tensorboard
    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr'], eps=1e-9)

    from transformers import get_linear_schedule_with_warmup
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config['train']['warm_up_steps'],
        num_training_steps=len(train_dataloader) * config['train']['num_epochs'] + 1
    )

    initial_epoch = 0
    global_step = 0
    if config['model']['preload']:
        model_filename = get_weights_file_path(config, config['model']['preload'])
        print(f'Preloading model from {model_filename}')
        state = torch.load(model_filename, map_location=device)

        initial_epoch = state['epoch'] + 1
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        scheduler.load_state_dict(state['scheduler_state_dict'])
        global_step = state['global_step']

    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tgt_tokenizer.token_to_id('<pad>'),  # labels are target tokens, so use the target tokenizer's pad id
        label_smoothing=config['train']['label_smoothing'],
    ).to(device)

    print(f"Training model with {model.count_parameters()} params.")

    patience = config['train']['patience']
    best_state = {
        'model_state_dict': model.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': float('inf')
    }

    for epoch in range(initial_epoch, config['train']['num_epochs']):
        batch_iterator = tqdm(train_dataloader, desc=f'Processing epoch {epoch:02d}')
        for batch in batch_iterator:
            model.train()

            encoder_input = batch['encoder_input'].to(device)  # (batch, seq_len)
            decoder_input = batch['decoder_input'].to(device)  # (batch, seq_len)
            encoder_mask = batch['encoder_mask'].to(device)  # (batch, 1, 1, seq_len)
            decoder_mask = batch['decoder_mask'].to(device)  # (batch, 1, seq_len, seq_len)

            encoder_output = model.encode(encoder_input, encoder_mask)  # (batch, seq_len, d_model)
            decoder_output, attn = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)  # (batch, seq_len, d_model)
            proj_output = model.project(decoder_output)  # (batch, seq_len, tgt_vocab_size)

            label = batch['label'].to(device)  # (batch, seq_len)

            loss = loss_fn(proj_output.view(-1, tgt_tokenizer.get_vocab_size()), label.view(-1))
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            writer.add_scalar('train_loss', loss.item(), global_step)
            writer.flush()

            global_step += 1
            if global_step % patience == 0:
                if loss > best_state['loss']:
                    # loss got worse: roll back to the best state seen so far and skip this step
                    model.load_state_dict(best_state['model_state_dict'])
                    optimizer.load_state_dict(best_state['optimizer_state_dict'])
                    scheduler.load_state_dict(best_state['scheduler_state_dict'])
                    continue
                else:
                    # loss improved: remember this state and the loss it achieved
                    best_state = {
                        'model_state_dict': model.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss.item()
                    }
            loss.backward()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        run_validation(model, val_dataloader, src_tokenizer, tgt_tokenizer, device, lambda msg: batch_iterator.write(msg), global_step, writer)

        model_filename = get_weights_file_path(config, f'{epoch:02d}')
        torch.save({
            'epoch': epoch,
            'model_state_dict': best_state['model_state_dict'],
            'scheduler_state_dict': best_state['scheduler_state_dict'],
            'optimizer_state_dict': best_state['optimizer_state_dict'],
            'global_step': global_step,
        }, model_filename)

        # print(f"Bleu score: {calculate_bleu_score(model, val_dataloader, src_tokenizer, tgt_tokenizer, device)}")

        if config['train']['on_colab']:
            # if (epoch % 5) == 0:
            #     model_zip_filename = f'model_epoch_{epoch}.zip'
            #     os.system(f'zip -r {model_zip_filename} /content/silver-spoon/weights')
            runs_zip_filename = f'runs_epoch_{epoch}.zip'
            os.system(f"zip -r {runs_zip_filename} /content/silver-spoon/{config['experiment_name']}")


if __name__ == '__main__':
    config = load_config(file_name='config.yaml')
    train_model(config)
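get_linear_schedule_with_warmup ramps the learning rate linearly from zero over warm_up_steps, then decays it linearly to zero at the final training step. A standalone sketch with a dummy parameter and toy step counts shows the shape:

import torch
from transformers import get_linear_schedule_with_warmup

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1.0, eps=1e-9)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=4, num_training_steps=10)

for step in range(10):
    optimizer.step()
    scheduler.step()
    print(step, round(scheduler.get_last_lr()[0], 2))
# the lr climbs 0.25, 0.5, 0.75, 1.0, then decays linearly toward 0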
translate.py
CHANGED
@@ -6,7 +6,7 @@ from torch import Tensor
 from tokenizers import Tokenizer
 
 from transformer import Transformer
-from decode_method import greedy_decode
+from decode_method import greedy_decode, beam_search_decode
 
 def translate(
     model: Transformer,
@@ -19,51 +19,61 @@ def translate(
     """
     Translation function.
 
+    Supported `decode_method`: 'greedy' or 'beam-search'.
+
+    'beam-search' does not return attention scores.
+
     Output:
     - translation (str): the translated string.
     - attn (Tensor): The decoder's attention (for visualization)
     """
+    model.eval()
+    with torch.no_grad():
+        sos_token = torch.tensor([src_tokenizer.token_to_id('<sos>')], dtype=torch.int64)
+        eos_token = torch.tensor([src_tokenizer.token_to_id('<eos>')], dtype=torch.int64)
+        pad_token = torch.tensor([src_tokenizer.token_to_id('<pad>')], dtype=torch.int64)
+
+        encoder_input_tokens = src_tokenizer.encode(text).ids
+        # <sos> + source_text + <eos> = encoder_input
+        encoder_input = torch.cat(
+            [
+                sos_token,
+                torch.tensor(encoder_input_tokens, dtype=torch.int64),
+                eos_token,
+            ]
+        )
+        encoder_mask = (encoder_input != pad_token).unsqueeze(0).unsqueeze(0).unsqueeze(0).int()  # (1, 1, 1, seq_len)
 
+        encoder_input = encoder_input.unsqueeze(0)
+        # encoder_mask = torch.tensor(encoder_mask)
+
+        assert encoder_input.size(0) == 1
 
+        if decode_method == 'greedy':
+            model_out, attn = greedy_decode(
+                model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 349, device,
+                give_attn=True,
+            )
+        elif decode_method == 'beam-search':
+            model_out = beam_search_decode(
+                model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 349, device,
+            )
+            attn = None  # Beam search doesn't give attention scores
+        else:
+            raise ValueError("Unsupported decode method")
 
+        model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())
+        return model_out_text, attn
 
 
 from config import load_config
 from load_and_save_model import load_model_tokenizer
 if __name__ == '__main__':
-    config = load_config(file_name='
+    config = load_config(file_name='config_huge.yaml')
     model, src_tokenizer, tgt_tokenizer = load_model_tokenizer(config)
     text = "ສະບາຍດີ"  # Hello.
     translation, attn = translate(
-        model, src_tokenizer, tgt_tokenizer, text
+        model, src_tokenizer, tgt_tokenizer, text,
+        decode_method='beam-search',
     )
     print(translation)
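The three unsqueeze calls give the encoder mask shape (1, 1, 1, seq_len): batch, head, and query dimensions that broadcast against the (batch, heads, seq_len, seq_len) attention scores. A quick shape check with toy ids (pad id 1 is an assumption):

import torch

encoder_input = torch.tensor([5, 9, 9, 1, 1])  # toy token ids, 1 = <pad> (assumed)
pad_token = torch.tensor([1], dtype=torch.int64)

encoder_mask = (encoder_input != pad_token).unsqueeze(0).unsqueeze(0).unsqueeze(0).int()
print(encoder_mask.shape)  # torch.Size([1, 1, 1, 5])
print(encoder_mask)        # tensor([[[[1, 1, 1, 0, 0]]]], dtype=torch.int32)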
utils.py
ADDED
@@ -0,0 +1,6 @@
from pathlib import Path

def get_full_file_path(file_name: str) -> Path:
    # resolve file_name relative to this script's directory
    script_dir = Path(__file__).resolve().parent
    file_path = script_dir / file_name
    return file_path
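get_full_file_path anchors relative names to the script's own directory rather than the current working directory, so data files resolve no matter where the scripts are launched from. A quick sketch (hypothetical file name):

from pathlib import Path

def get_full_file_path(file_name: str) -> Path:
    script_dir = Path(__file__).resolve().parent
    return script_dir / file_name

# prints the same absolute path regardless of the current working directory
print(get_full_file_path('datasets/train.dat'))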
validate.py
ADDED
@@ -0,0 +1,72 @@
import os
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
import torchmetrics
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm

# from datasets import load_dataset
from load_dataset import load_local_dataset
from transformer import get_model, Transformer
from config import load_config, get_weights_file_path

from tokenizers import Tokenizer
from tokenizers.models import WordLevel, BPE
from tokenizers.trainers import WordLevelTrainer, BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from pathlib import Path

from dataset import BilingualDataset
from bleu import calculate_bleu_score
from decode_method import greedy_decode


def run_validation(
    model: Transformer,
    validation_ds: DataLoader,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    device,
    print_msg,
    global_state,
    writer,
    num_examples: int = 2
):
    model.eval()

    # inference
    count = 0
    source_texts = []
    expected = []
    predicted = []

    console_width = 50
    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)

            assert encoder_input.size(0) == 1, "batch_size = 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 300, device)

            source_text = batch['src_text'][0]
            target_text = batch['tgt_text'][0]
            model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            print_msg("-"*console_width)
            print_msg(f"SOURCE: {source_text}")
            print_msg(f"TARGET: {target_text}")
            print_msg(f"PREDICTED: {model_out_text}")

            if count == num_examples:
                break