moiduy04 committed on
Commit
b8a6dde
1 Parent(s): f1aa791

Upload 12 files

Files changed (11)
  1. bleu.py +49 -0
  2. bleu_test.py +34 -0
  3. dataset.py +99 -0
  4. decode_method.py +66 -1
  5. load_and_save_model.py +73 -0
  6. load_dataset.py +104 -0
  7. tokenizer.py +36 -0
  8. train.py +199 -0
  9. translate.py +42 -32
  10. utils.py +6 -0
  11. validate.py +72 -0
bleu.py ADDED
@@ -0,0 +1,49 @@
+ import torch
+ from torch.utils.data import DataLoader
+ from torchtext.data.metrics import bleu_score
+
+ from tqdm import tqdm
+
+ from decode_method import beam_search_decode
+ from transformer import Transformer
+
+ from tokenizers import Tokenizer
+
+
+ def calculate_bleu_score(
+     model: Transformer,
+     bleu_dataloader: DataLoader,
+     src_tokenizer: Tokenizer,
+     tgt_tokenizer: Tokenizer,
+     device = torch.device('cpu'),
+     num_samples: int = 9999999,
+ ):
+     """Compute the corpus BLEU score (in percent) of the model over bleu_dataloader."""
+     model.eval()
+
+     # inference
+     count = 0
+     expected = []
+     predicted = []
+
+     with torch.no_grad():
+         batch_iterator = tqdm(bleu_dataloader)
+         for batch in batch_iterator:
+             count += 1
+             encoder_input = batch['encoder_input'].to(device)
+             encoder_mask = batch['encoder_mask'].to(device)
+
+             assert encoder_input.size(0) == 1, "batch_size = 1 for bleu calculation"
+
+             model_out = beam_search_decode(model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 300, device)
+
+             target_text = batch['tgt_text'][0]
+             model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())
+
+             expected.append([target_text.split()])
+             predicted.append(model_out_text.split())
+
+             if count == num_samples:
+                 break
+
+     return bleu_score(predicted, expected) * 100.0
bleu_test.py ADDED
@@ -0,0 +1,34 @@
+ from torch.utils.data import DataLoader
+ from bleu import calculate_bleu_score
+ from load_dataset import load_local_bleu_dataset
+ from dataset import BilingualDataset
+ from config import load_config
+ from load_and_save_model import load_model_tokenizer
+
+
+ def get_bleu_of_model(config) -> float:
+     model, src_tokenizer, tgt_tokenizer = load_model_tokenizer(config)
+     bleu_ds_raw = load_local_bleu_dataset(
+         src_dataset_filename='datasets/'+config['dataset']['bleu_dataset']+'.'+config['dataset']['src_lang'],
+         tgt_dataset_filename='datasets/'+config['dataset']['bleu_dataset']+'.'+config['dataset']['tgt_lang'],
+         src_lang=config['dataset']['src_lang'],
+         tgt_lang=config['dataset']['tgt_lang'],
+     )
+     bleu_ds = BilingualDataset(
+         ds=bleu_ds_raw,
+         src_tokenizer=src_tokenizer,
+         tgt_tokenizer=tgt_tokenizer,
+         src_lang=config['dataset']['src_lang'],
+         tgt_lang=config['dataset']['tgt_lang'],
+         src_max_seq_len=config['dataset']['src_max_seq_len'],
+         tgt_max_seq_len=config['dataset']['tgt_max_seq_len'],
+     )
+     bleu_dataloader = DataLoader(bleu_ds, batch_size=1, shuffle=True)
+     return calculate_bleu_score(
+         model, bleu_dataloader, src_tokenizer, tgt_tokenizer,
+     )
+
+ if __name__ == '__main__':
+     for file_name in {'config_final.yaml', 'config_huge.yaml', 'config_big.yaml', 'config_small.yaml'}:
+         config = load_config(file_name)
+         print(get_bleu_of_model(config), f" is the BLEU of {file_name}", sep='')
dataset.py ADDED
@@ -0,0 +1,99 @@
+ from typing import List, Dict, Any
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset
+ from torch import Tensor
+
+ from tokenizers import Tokenizer
+
+ class BilingualDataset(Dataset):
+     """
+     A Bilingual Dataset that follows the structure of the 'opus_books' dataset.
+     """
+     def __init__(
+         self,
+         ds: List[Dict[str, Dict[str, str]]],
+         src_tokenizer: Tokenizer,
+         tgt_tokenizer: Tokenizer,
+         src_lang: str,
+         tgt_lang: str,
+         src_max_seq_len: int,
+         tgt_max_seq_len: int,
+     ) -> None:
+         super(BilingualDataset, self).__init__()
+
+         self.ds = ds
+         self.src_tokenizer = src_tokenizer
+         self.tgt_tokenizer = tgt_tokenizer
+         self.src_lang = src_lang
+         self.tgt_lang = tgt_lang
+
+         self.src_max_seq_len = src_max_seq_len
+         self.tgt_max_seq_len = tgt_max_seq_len
+
+         self.sos_token = torch.tensor([src_tokenizer.token_to_id('<sos>')], dtype=torch.int64)
+         self.eos_token = torch.tensor([src_tokenizer.token_to_id('<eos>')], dtype=torch.int64)
+         self.pad_token = torch.tensor([src_tokenizer.token_to_id('<pad>')], dtype=torch.int64)
+
+     def __len__(self):
+         return len(self.ds)
+
+     def __getitem__(self, index: int) -> Dict[str, Any]:
+         src_tgt_pair = self.ds[index]
+         src_text = src_tgt_pair['translation'][self.src_lang]
+         tgt_text = src_tgt_pair['translation'][self.tgt_lang]
+
+         encoder_input_tokens = self.src_tokenizer.encode(src_text).ids
+         decoder_input_tokens = self.tgt_tokenizer.encode(tgt_text).ids
+
+         encoder_num_padding = self.src_max_seq_len - len(encoder_input_tokens) - 2  # <sos> + <eos>
+         decoder_num_padding = self.tgt_max_seq_len - len(decoder_input_tokens) - 1  # <sos>
+
+         # <sos> + source_text + <eos> + <pad> = encoder_input
+         encoder_input = torch.cat(
+             [
+                 self.sos_token,
+                 torch.tensor(encoder_input_tokens, dtype=torch.int64),
+                 self.eos_token,
+                 torch.tensor([self.pad_token] * encoder_num_padding, dtype=torch.int64)
+             ]
+         )
+
+         decoder_input_tokens = torch.tensor(decoder_input_tokens, dtype=torch.int64)
+         decoder_padding = torch.tensor([self.pad_token] * decoder_num_padding, dtype=torch.int64)
+         # <sos> + target_text + <pad> = decoder_input
+         decoder_input = torch.cat(
+             [
+                 self.sos_token,
+                 decoder_input_tokens,
+                 decoder_padding
+             ]
+         )
+         # target_text + <eos> + <pad> = expected decoder_output (label)
+         label = torch.cat(
+             [
+                 decoder_input_tokens,
+                 self.eos_token,
+                 decoder_padding
+             ]
+         )
+
+         assert encoder_input.size(0) == self.src_max_seq_len
+         assert decoder_input.size(0) == self.tgt_max_seq_len
+         assert label.size(0) == self.tgt_max_seq_len
+
+         return {
+             'encoder_input': encoder_input,  # (seq_len)
+             'decoder_input': decoder_input,  # (seq_len)
+             'encoder_mask': (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
+             'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),  # (1, seq_len, seq_len)
+             'label': label,  # (seq_len)
+             'src_text': src_text,
+             'tgt_text': tgt_text,
+         }
+
+ def causal_mask(size: int) -> Tensor:
+     mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
+     return mask == 0
+
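For readers unfamiliar with the masking above: `causal_mask` returns a boolean lower-triangular mask, so each decoder position can attend only to itself and earlier positions. A minimal, self-contained sketch of its behavior (illustrative only, not part of this commit):

import torch

def causal_mask(size: int):
    # 1s strictly above the diagonal, then inverted: True on and below the diagonal
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0

print(causal_mask(3))
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]]])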
decode_method.py CHANGED
@@ -47,4 +47,69 @@ def greedy_decode(
              break
      if give_attn:
          return (decoder_input.squeeze(0), attn)
-     return decoder_input.squeeze(0)
+     return decoder_input.squeeze(0)
+
+ def beam_search_decode(
+     model: Transformer,
+     src: Tensor,
+     src_mask: Tensor,
+     src_tokenizer: Tokenizer,
+     tgt_tokenizer: Tokenizer,
+     tgt_max_seq_len: int,
+     device,
+     beam_size: int = 3,
+ ):
+     sos_idx = src_tokenizer.token_to_id('<sos>')
+     eos_idx = src_tokenizer.token_to_id('<eos>')
+
+     # Precompute the encoder output and reuse it for every decoding step
+     encoder_output = model.encode(src, src_mask)
+     # Initialize the decoder input with the sos token
+     decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(device)
+
+     # Create a candidate list of (sequence, score) pairs
+     candidates = [(decoder_initial_input, 1)]
+
+     while True:
+
+         # If a candidate has reached the maximum length, we have run the decoding for at least max_len iterations, so stop the search
+         if any([cand.size(1) == tgt_max_seq_len for cand, _ in candidates]):
+             break
+
+         # Create a new list of candidates
+         new_candidates = []
+
+         for candidate, score in candidates:
+
+             # Do not expand candidates that have reached the eos token
+             if candidate[0][-1].item() == eos_idx:
+                 continue
+
+             # Build the candidate's causal mask
+             candidate_mask = causal_mask(candidate.size(1)).type_as(src_mask).to(device)
+             # Calculate the decoder output
+             out, attn = model.decode(encoder_output, src_mask, candidate, candidate_mask)
+             # Get next-token probabilities
+             prob = model.project(out[:, -1])
+             # Get the top k expansions
+             topk_prob, topk_idx = torch.topk(prob, beam_size, dim=1)
+             for i in range(beam_size):
+                 # For each of the top k expansions, get the token and its probability
+                 token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)
+                 token_prob = topk_prob[0][i].item()
+                 # Create a new candidate by appending the token to the current candidate
+                 new_candidate = torch.cat([candidate, token], dim=1)
+                 # We sum the log probabilities because the probabilities are in log space
+                 new_candidates.append((new_candidate, score + token_prob))
+
+         # Sort the new candidates by their score
+         candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)
+         # Keep only the top k candidates
+         candidates = candidates[:beam_size]
+
+         # If all the candidates have reached the eos token, stop
+         if all([cand[0][-1].item() == eos_idx for cand, _ in candidates]):
+             break
+
+     # Return the best candidate
+     return candidates[0][0].squeeze()
load_and_save_model.py ADDED
@@ -0,0 +1,73 @@
+ from typing import Tuple
+
+ import torch
+ from transformer import get_model, Transformer
+ from config import load_config, get_weights_file_path
+ from train import get_local_dataset_tokenizer
+ from tokenizer import get_or_build_local_tokenizer
+
+ from tokenizers import Tokenizer
+
+
+ def load_train_data_and_save_model(config, model_name):
+     """
+     Loads a training checkpoint (model, optimizer, scheduler, ...) and saves ONLY the model weights.
+     """
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     print(f'Using device {device}')
+
+     train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer = get_local_dataset_tokenizer(config)
+     model = get_model(config, src_tokenizer.get_vocab_size(), tgt_tokenizer.get_vocab_size()).to(device)
+
+     assert config['model']['preload'], 'config["model"]["preload"] must name the checkpoint to load.'
+
+     model_load_filename = get_weights_file_path(config, config['model']['preload'])
+     print(f'Preloading model from train data in {model_load_filename}')
+     state = torch.load(model_load_filename, map_location=device)
+
+     model.load_state_dict(state['model_state_dict'])
+
+     model_save_filename = get_weights_file_path(config, model_name)
+     torch.save(model.state_dict(), model_save_filename)
+     print(f'Model saved at {model_save_filename}')
+
+ def load_model_tokenizer(
+     config,
+     device = torch.device('cpu'),
+     logs: bool = True,
+ ) -> Tuple[Transformer, Tokenizer, Tokenizer]:
+     """
+     Loads a local model and tokenizers from a given config.
+     """
+     if config['model']['preload'] is None:
+         raise ValueError('Unspecified preload model')
+
+     src_tokenizer = get_or_build_local_tokenizer(
+         config=config,
+         ds=None,
+         lang=config['dataset']['src_lang'],
+         tokenizer_type=config['dataset']['src_tokenizer']
+     )
+     tgt_tokenizer = get_or_build_local_tokenizer(
+         config=config,
+         ds=None,
+         lang=config['dataset']['tgt_lang'],
+         tokenizer_type=config['dataset']['tgt_tokenizer']
+     )
+
+     model = get_model(
+         config,
+         src_tokenizer.get_vocab_size(),
+         tgt_tokenizer.get_vocab_size(),
+     ).to(device)
+
+     model_filename = get_weights_file_path(config, config['model']['preload'])
+     model.load_state_dict(
+         torch.load(model_filename, map_location=device)
+     )
+     print('Finished loading model and tokenizers')
+     return (model, src_tokenizer, tgt_tokenizer)
+
+ if __name__ == '__main__':
+     config = load_config(file_name='config_huge.yaml')
+     load_train_data_and_save_model(config, 'huge')
load_dataset.py ADDED
@@ -0,0 +1,104 @@
+ from typing import List, Dict, Any
+
+ from pathlib import Path
+
+ from utils import get_full_file_path
+
+ # SENTENCE_STOPPERS = {'!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~'}
+ # VIETNAMESE_SPECIAL_CHARACTERS = {'à', 'á', 'ả', 'ã', 'ạ', 'â', 'ầ', 'ấ', 'ẩ', 'ẫ', 'ậ', 'ă', 'ằ', 'ắ', 'ẳ', 'ẵ', 'ặ', 'è', 'é', 'ẻ', 'ẽ', 'ẹ', 'ê', 'ề', 'ế', 'ể', 'ễ', 'ệ', 'ì', 'í', 'ỉ', 'ĩ', 'ị', 'ò', 'ó', 'ỏ', 'õ', 'ọ', 'ô', 'ồ', 'ố', 'ổ', 'ỗ', 'ộ', 'ơ', 'ờ', 'ớ', 'ở', 'ỡ', 'ợ', 'ù', 'ú', 'ủ', 'ũ', 'ụ', 'ư', 'ừ', 'ứ', 'ử', 'ữ', 'ự', 'ỳ', 'ý', 'ỷ', 'ỹ', 'ỵ'}
+
+ # def is_Vietnamese_character(char):
+ #     return char.isalpha() or char in VIETNAMESE_SPECIAL_CHARACTERS
+
+ # def categorize_word(word: str) -> str:
+ #     """
+ #     Categorize a word into 3 types:
+ #     - "vi": likely Vietnamese.
+ #     - "lo": likely Lao.
+ #     - "num": a number.
+ #     """
+ #     if any(char.isdigit() for char in word):
+ #         return "num"
+
+ #     for stopper in SENTENCE_STOPPERS:
+ #         if word.endswith(stopper):
+ #             word = word[:-1]
+ #         if len(word) == 0:
+ #             break
+
+ #     if len(word) > 0 and any(not is_Vietnamese_character(char) for char in word):
+ #         return "lo"
+ #     else:
+ #         return "vi"
+ #
+ # def open_dataset(
+ #     dataset_filename: str,
+ #     src_lang: str = "lo",
+ #     tgt_lang: str = "vi"
+ # ) -> List[Dict[str, Dict[str, str]]]:
+ #     ds = []
+ #     file_path = get_full_file_path(dataset_filename)
+ #     with open(file_path, 'r', encoding='utf-8') as file:
+ #         lines = file.readlines()
+
+ #     for index, line in enumerate(lines):
+ #         line = line.split(sep=None)
+
+ #         lo_positions = [i for i, word in enumerate(line) if categorize_word(word) == "lo"]
+ #         if len(lo_positions) == 0:
+ #             # print(line)
+ #             continue
+
+ #         split_index = max(lo_positions)
+ #         assert split_index is not None, f"Dataset error on line {index+1}."
+
+ #         src_text = ' '.join(line[:split_index+1])
+ #         tgt_text = line[split_index+1:]
+
+ #         if index <= 5:
+ #             print(src_text, tgt_text, sep="\n", end="\n-------")
+
+ #         # TODO: post-process the tgt_text to split all numbers into single digits.
+ #         ds.append({'translation': {src_lang: src_text, tgt_lang: tgt_text}})
+ #     return ds
+
+ # open_dataset('datasets/dev_clean.dat')
+
+ def load_local_dataset(
+     dataset_filename: str,
+     src_lang: str = "lo",
+     tgt_lang: str = "vi"
+ ) -> List[Dict[str, Dict[str, str]]]:
+     ds = []
+     file_path = get_full_file_path(dataset_filename)
+     with open(file_path, 'r', encoding='utf-8') as file:
+         lines = file.readlines()
+
+     for index, line in enumerate(lines):
+         src_text, tgt_text = line.split(sep="\t", maxsplit=1)
+         ds.append({'translation': {src_lang: src_text, tgt_lang: tgt_text}})
+     return ds
+
+ def load_local_bleu_dataset(
+     src_dataset_filename: str,
+     tgt_dataset_filename: str,
+     src_lang: str = "lo",
+     tgt_lang: str = "vi"
+ ) -> List[Dict[str, Dict[str, str]]]:
+     def load_local_monolanguage_dataset(dataset_filename: str):
+         mono_ds = []
+         file_path = get_full_file_path(dataset_filename)
+         with open(file_path, 'r', encoding='utf-8') as file:
+             lines = file.readlines()
+         for line in lines:
+             mono_ds.append(line)
+         return mono_ds
+
+     src_texts = load_local_monolanguage_dataset(src_dataset_filename)
+     tgt_texts = load_local_monolanguage_dataset(tgt_dataset_filename)
+
+     assert len(src_texts) == len(tgt_texts)
+     ds = []
+     for i in range(len(src_texts)):
+         ds.append({'translation': {src_lang: src_texts[i], tgt_lang: tgt_texts[i]}})
+     return ds
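For context, `load_local_dataset` expects one tab-separated source/target pair per line. A small hypothetical example of the format and output (not part of this commit; note the target keeps its trailing newline, since lines are not stripped):

from load_dataset import load_local_dataset

# Suppose datasets/example.lo-vi.dat (a made-up file) contains tab-separated pairs:
#   ສະບາຍດີ<TAB>xin chào
#   ຂອບໃຈ<TAB>cảm ơn
ds = load_local_dataset('datasets/example.lo-vi.dat', src_lang='lo', tgt_lang='vi')
print(ds[0])
# {'translation': {'lo': 'ສະບາຍດີ', 'vi': 'xin chào\n'}}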
tokenizer.py ADDED
@@ -0,0 +1,36 @@
+ from tokenizers import Tokenizer
+ from tokenizers.models import WordLevel, BPE
+ from tokenizers.trainers import WordLevelTrainer, BpeTrainer
+ from tokenizers.pre_tokenizers import Whitespace, ByteLevel
+
+ from pathlib import Path
+
+
+ def get_all_sentences(ds, lang: str):
+     for item in ds:
+         yield item['translation'][lang]
+
+ def get_or_build_local_tokenizer(config, ds, lang: str, tokenizer_type: str, force_build: bool = False) -> Tokenizer:
+     tokenizer_path = Path(config['dataset']['tokenizer_file'].format(lang))
+     if not Path.exists(tokenizer_path) or force_build:
+         if ds is None:
+             raise ValueError("Cannot find a local tokenizer file, and the dataset given is None")
+
+         if tokenizer_type == "WordLevel":
+             tokenizer = Tokenizer(WordLevel(unk_token='<unk>'))
+             tokenizer.pre_tokenizer = Whitespace()
+             trainer = WordLevelTrainer(special_tokens=['<unk>', '<pad>', '<sos>', '<eos>'], min_frequency=2)
+         elif tokenizer_type == "BPE":
+             tokenizer = Tokenizer(BPE(unk_token='<unk>'))
+             tokenizer.pre_tokenizer = Whitespace()
+             trainer = BpeTrainer(special_tokens=['<unk>', '<pad>', '<sos>', '<eos>'], min_frequency=2)
+         else:
+             raise ValueError("Unsupported tokenizer type")
+
+         tokenizer.train_from_iterator(
+             get_all_sentences(ds, lang), trainer=trainer
+         )
+         tokenizer.save(str(tokenizer_path))
+     else:
+         tokenizer = Tokenizer.from_file(str(tokenizer_path))
+     return tokenizer
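A rough usage sketch of `get_or_build_local_tokenizer`; the config dict and the toy sentence below are made up for illustration, and only the keys the function actually reads are shown (not part of this commit):

from tokenizer import get_or_build_local_tokenizer

config = {'dataset': {'tokenizer_file': 'tokenizer_{0}.json'}}  # hypothetical path template
ds = [{'translation': {'lo': 'ສະບາຍດີ', 'vi': 'xin chào'}}]      # toy in-memory dataset

# Trains and saves tokenizer_vi.json on the first call, reloads it on later calls.
vi_tokenizer = get_or_build_local_tokenizer(config, ds, lang='vi', tokenizer_type='WordLevel')
print(vi_tokenizer.encode('xin chào').ids)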
train.py ADDED
@@ -0,0 +1,199 @@
+ import os
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from torch.utils.data import DataLoader, Dataset
+ import torchmetrics
+ from torch.utils.tensorboard import SummaryWriter
+
+ from tqdm import tqdm
+
+ # from datasets import load_dataset
+ from load_dataset import load_local_dataset
+ from transformer import get_model
+ from config import load_config, get_weights_file_path
+ from validate import run_validation
+ from tokenizer import get_or_build_local_tokenizer
+
+ from pathlib import Path
+
+ from dataset import BilingualDataset
+ from bleu import calculate_bleu_score
+ from decode_method import greedy_decode
+
+ def get_local_dataset_tokenizer(config):
+     train_ds_raw = load_local_dataset(
+         dataset_filename='datasets/'+config['dataset']['train_dataset'],
+         src_lang=config['dataset']['src_lang'],
+         tgt_lang=config['dataset']['tgt_lang']
+     )
+     val_ds_raw = load_local_dataset(
+         dataset_filename='datasets/'+config['dataset']['validate_dataset'],
+         src_lang=config['dataset']['src_lang'],
+         tgt_lang=config['dataset']['tgt_lang']
+     )
+
+     src_tokenizer = get_or_build_local_tokenizer(
+         config=config,
+         ds=train_ds_raw + val_ds_raw,
+         lang=config['dataset']['src_lang'],
+         tokenizer_type=config['dataset']['src_tokenizer']
+     )
+     tgt_tokenizer = get_or_build_local_tokenizer(
+         config=config,
+         ds=train_ds_raw + val_ds_raw,
+         lang=config['dataset']['tgt_lang'],
+         tokenizer_type=config['dataset']['tgt_tokenizer']
+     )
+
+     train_ds = BilingualDataset(
+         ds=train_ds_raw,
+         src_tokenizer=src_tokenizer,
+         tgt_tokenizer=tgt_tokenizer,
+         src_lang=config['dataset']['src_lang'],
+         tgt_lang=config['dataset']['tgt_lang'],
+         src_max_seq_len=config['dataset']['src_max_seq_len'],
+         tgt_max_seq_len=config['dataset']['tgt_max_seq_len'],
+     )
+     val_ds = BilingualDataset(
+         ds=val_ds_raw,
+         src_tokenizer=src_tokenizer,
+         tgt_tokenizer=tgt_tokenizer,
+         src_lang=config['dataset']['src_lang'],
+         tgt_lang=config['dataset']['tgt_lang'],
+         src_max_seq_len=config['dataset']['src_max_seq_len'],
+         tgt_max_seq_len=config['dataset']['tgt_max_seq_len'],
+     )
+
+     src_max_seq_len = 0
+     tgt_max_seq_len = 0
+     for item in (train_ds_raw + val_ds_raw):
+         src_ids = src_tokenizer.encode(item['translation'][config['dataset']['src_lang']]).ids
+         tgt_ids = tgt_tokenizer.encode(item['translation'][config['dataset']['tgt_lang']]).ids
+         src_max_seq_len = max(src_max_seq_len, len(src_ids))
+         tgt_max_seq_len = max(tgt_max_seq_len, len(tgt_ids))
+     print(f'Max length of source sequence: {src_max_seq_len}')
+     print(f'Max length of target sequence: {tgt_max_seq_len}')
+
+     train_dataloader = DataLoader(train_ds, batch_size=config['train']['batch_size'], shuffle=True)
+     val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)
+
+     return train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer
+
+ def train_model(config):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     print(f'Using device {device}')
+
+     Path(config['model']['model_folder']).mkdir(parents=True, exist_ok=True)
+
+     train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer = get_local_dataset_tokenizer(config)
+     model = get_model(config, src_tokenizer.get_vocab_size(), tgt_tokenizer.get_vocab_size()).to(device)
+
+     print(f'{src_tokenizer.get_vocab_size()}, {tgt_tokenizer.get_vocab_size()}')
+
+     # Tensorboard
+     writer = SummaryWriter(config['experiment_name'])
+
+     optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr'], eps=1e-9)
+
+     from transformers import get_linear_schedule_with_warmup
+     scheduler = get_linear_schedule_with_warmup(
+         optimizer,
+         num_warmup_steps=config['train']['warm_up_steps'],
+         num_training_steps=len(train_dataloader) * config['train']['num_epochs'] + 1
+     )
+
+     initial_epoch = 0
+     global_step = 0
+     if config['model']['preload']:
+         model_filename = get_weights_file_path(config, config['model']['preload'])
+         print(f'Preloading model from {model_filename}')
+         state = torch.load(model_filename, map_location=device)
+
+         initial_epoch = state['epoch'] + 1
+         model.load_state_dict(state['model_state_dict'])
+         optimizer.load_state_dict(state['optimizer_state_dict'])
+         scheduler.load_state_dict(state['scheduler_state_dict'])
+         global_step = state['global_step']
+
+     loss_fn = nn.CrossEntropyLoss(
+         ignore_index=src_tokenizer.token_to_id('<pad>'),
+         label_smoothing=config['train']['label_smoothing'],
+     ).to(device)
+
+     print(f"Training model with {model.count_parameters()} params.")
+
+     patience = config['train']['patience']
+     best_state = {
+         'model_state_dict': model.state_dict(),
+         'scheduler_state_dict': scheduler.state_dict(),
+         'optimizer_state_dict': optimizer.state_dict(),
+         'loss': 9999999.99
+     }
+
+     for epoch in range(initial_epoch, config['train']['num_epochs']):
+         batch_iterator = tqdm(train_dataloader, desc=f'Processing epoch {epoch:02d}')
+         for batch in batch_iterator:
+             model.train()
+
+             encoder_input = batch['encoder_input'].to(device)  # (batch, seq_len)
+             decoder_input = batch['decoder_input'].to(device)  # (batch, seq_len)
+             encoder_mask = batch['encoder_mask'].to(device)  # (batch, 1, 1, seq_len)
+             decoder_mask = batch['decoder_mask'].to(device)  # (batch, 1, seq_len, seq_len)
+
+             encoder_output = model.encode(encoder_input, encoder_mask)  # (batch, seq_len, d_model)
+             decoder_output, attn = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)  # (batch, seq_len, d_model)
+             proj_output = model.project(decoder_output)  # (batch, seq_len, tgt_vocab_size)
+
+             label = batch['label'].to(device)  # (batch, seq_len)
+
+             loss = loss_fn(proj_output.view(-1, tgt_tokenizer.get_vocab_size()), label.view(-1))
+             batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
+
+             writer.add_scalar('train_loss', loss.item(), global_step)
+             writer.flush()
+
+             global_step += 1
+             if global_step % patience == 0:
+                 # Every `patience` steps, either roll back to the best state seen so far,
+                 # or record the current state (and its loss) as the new best.
+                 if loss > best_state['loss']:
+                     model.load_state_dict(best_state['model_state_dict'])
+                     optimizer.load_state_dict(best_state['optimizer_state_dict'])
+                     scheduler.load_state_dict(best_state['scheduler_state_dict'])
+                     continue
+                 else:
+                     best_state = {
+                         'model_state_dict': model.state_dict(),
+                         'scheduler_state_dict': scheduler.state_dict(),
+                         'optimizer_state_dict': optimizer.state_dict(),
+                         'loss': loss.item()
+                     }
+             loss.backward()
+
+             optimizer.step()
+             scheduler.step()
+             optimizer.zero_grad()
+
+         run_validation(model, val_dataloader, src_tokenizer, tgt_tokenizer, device, lambda msg: batch_iterator.write(msg), global_step, writer)
+
+         model_filename = get_weights_file_path(config, f'{epoch:02d}')
+         torch.save({
+             'epoch': epoch,
+             'model_state_dict': best_state['model_state_dict'],
+             'scheduler_state_dict': best_state['scheduler_state_dict'],
+             'optimizer_state_dict': best_state['optimizer_state_dict'],
+             'global_step': global_step,
+         }, model_filename)
+
+         # print(f"Bleu score: {calculate_bleu_score(model, val_dataloader, src_tokenizer, tgt_tokenizer, device)}")
+
+         if config['train']['on_colab']:
+             # if (epoch % 5) == 0:
+             #     model_zip_filename = f'model_epoch_{epoch}.zip'
+             #     os.system(f'zip -r {model_zip_filename} /content/silver-spoon/weights')
+             runs_zip_filename = f'runs_epoch_{epoch}.zip'
+             os.system(f"zip -r {runs_zip_filename} /content/silver-spoon/{config['experiment_name']}")
+
+
+ if __name__ == '__main__':
+     config = load_config(file_name='config.yaml')
+     train_model(config)
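For orientation, these are the config keys that train.py (together with tokenizer.py and bleu_test.py) reads. The structure below is a hedged sketch with placeholder values, not the repo's actual config files; `get_model` presumably reads further model hyperparameters under 'model' that are not visible in this commit:

config = {
    'experiment_name': 'runs/example',        # TensorBoard log dir (placeholder)
    'dataset': {
        'train_dataset': 'train.dat',          # placeholder file names
        'validate_dataset': 'dev.dat',
        'bleu_dataset': 'test',
        'src_lang': 'lo', 'tgt_lang': 'vi',
        'src_tokenizer': 'BPE', 'tgt_tokenizer': 'WordLevel',
        'src_max_seq_len': 350, 'tgt_max_seq_len': 350,
        'tokenizer_file': 'tokenizer_{0}.json',
    },
    'model': {'model_folder': 'weights', 'preload': None},
    'train': {
        'batch_size': 32, 'lr': 1e-4, 'num_epochs': 10,
        'warm_up_steps': 4000, 'label_smoothing': 0.1,
        'patience': 1000, 'on_colab': False,
    },
}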
translate.py CHANGED
@@ -6,7 +6,7 @@ from torch import Tensor
  from tokenizers import Tokenizer

  from transformer import Transformer
- from decode_method import greedy_decode
+ from decode_method import greedy_decode, beam_search_decode

  def translate(
      model: Transformer,
@@ -19,51 +19,61 @@ def translate(
      """
      Translation function.

+     Supported `decode_method` values: 'greedy' or 'beam-search'.
+
+     'beam-search' does not return attention scores.
+
      Output:
      - translation (str): the translated string.
      - attn (Tensor): the decoder's attention (for visualization).
      """
-     sos_token = torch.tensor([src_tokenizer.token_to_id('<sos>')], dtype=torch.int64)
-     eos_token = torch.tensor([src_tokenizer.token_to_id('<eos>')], dtype=torch.int64)
-     pad_token = torch.tensor([src_tokenizer.token_to_id('<pad>')], dtype=torch.int64)
-
-     encoder_input_tokens = src_tokenizer.encode(text).ids
-     # <sos> + source_text + <eos> = encoder_input
-     encoder_input = torch.cat(
-         [
-             sos_token,
-             torch.tensor(encoder_input_tokens, dtype=torch.int64),
-             eos_token,
-         ]
-     )
-     encoder_mask = (encoder_input != pad_token).unsqueeze(0).unsqueeze(0).unsqueeze(0).int()  # (1, 1, seq_len)
-
-     encoder_input = encoder_input.unsqueeze(0)
-     # encoder_mask = torch.tensor(encoder_mask)
-
-     assert encoder_input.size(0) == 1
-
-     if decode_method == 'greedy':
-         model_out, attn = greedy_decode(
-             model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 400, device,
-             give_attn=True,
-         )
-     elif decode_method == 'beam-search':
-         raise NotImplementedError
-     else:
-         raise ValueError("Unsupported decode method")
-
-     model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())
-     return model_out_text, attn
+     model.eval()
+     with torch.no_grad():
+         sos_token = torch.tensor([src_tokenizer.token_to_id('<sos>')], dtype=torch.int64)
+         eos_token = torch.tensor([src_tokenizer.token_to_id('<eos>')], dtype=torch.int64)
+         pad_token = torch.tensor([src_tokenizer.token_to_id('<pad>')], dtype=torch.int64)
+
+         encoder_input_tokens = src_tokenizer.encode(text).ids
+         # <sos> + source_text + <eos> = encoder_input
+         encoder_input = torch.cat(
+             [
+                 sos_token,
+                 torch.tensor(encoder_input_tokens, dtype=torch.int64),
+                 eos_token,
+             ]
+         )
+         encoder_mask = (encoder_input != pad_token).unsqueeze(0).unsqueeze(0).unsqueeze(0).int()  # (1, 1, seq_len)
+
+         encoder_input = encoder_input.unsqueeze(0)
+         # encoder_mask = torch.tensor(encoder_mask)
+
+         assert encoder_input.size(0) == 1
+
+         if decode_method == 'greedy':
+             model_out, attn = greedy_decode(
+                 model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 349, device,
+                 give_attn=True,
+             )
+         elif decode_method == 'beam-search':
+             model_out = beam_search_decode(
+                 model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 349, device,
+             )
+             attn = None  # Beam search doesn't give attention scores
+         else:
+             raise ValueError("Unsupported decode method")
+
+     model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())
+     return model_out_text, attn


  from config import load_config
  from load_and_save_model import load_model_tokenizer
  if __name__ == '__main__':
-     config = load_config(file_name='config_small.yaml')
+     config = load_config(file_name='config_huge.yaml')
      model, src_tokenizer, tgt_tokenizer = load_model_tokenizer(config)
      text = "ສະບາຍດີ"  # Hello.
      translation, attn = translate(
-         model, src_tokenizer, tgt_tokenizer, text
+         model, src_tokenizer, tgt_tokenizer, text,
+         decode_method='beam-search',
      )
      print(translation)
utils.py ADDED
@@ -0,0 +1,6 @@
+ from pathlib import Path
+
+ def get_full_file_path(file_name: str) -> Path:
+     """Resolve file_name relative to the directory containing this script."""
+     script_dir = Path(__file__).resolve().parent
+     file_path = script_dir / file_name
+     return file_path
validate.py ADDED
@@ -0,0 +1,72 @@
+ import os
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from torch.utils.data import DataLoader, Dataset
+ import torchmetrics
+ from torch.utils.tensorboard import SummaryWriter
+
+ from tqdm import tqdm
+
+ # from datasets import load_dataset
+ from load_dataset import load_local_dataset
+ from transformer import get_model, Transformer
+ from config import load_config, get_weights_file_path
+
+ from tokenizers import Tokenizer
+ from tokenizers.models import WordLevel, BPE
+ from tokenizers.trainers import WordLevelTrainer, BpeTrainer
+ from tokenizers.pre_tokenizers import Whitespace
+
+ from pathlib import Path
+
+ from dataset import BilingualDataset
+ from bleu import calculate_bleu_score
+ from decode_method import greedy_decode
+
+
+ def run_validation(
+     model: Transformer,
+     validation_ds: DataLoader,
+     src_tokenizer: Tokenizer,
+     tgt_tokenizer: Tokenizer,
+     device,
+     print_msg,
+     global_state,
+     writer,
+     num_examples: int = 2
+ ):
+     model.eval()
+
+     # inference
+     count = 0
+     source_texts = []
+     expected = []
+     predicted = []
+
+     console_width = 50
+     with torch.no_grad():
+         for batch in validation_ds:
+             count += 1
+             encoder_input = batch['encoder_input'].to(device)
+             encoder_mask = batch['encoder_mask'].to(device)
+
+             assert encoder_input.size(0) == 1, "batch_size = 1 for validation"
+
+             model_out = greedy_decode(model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 300, device)
+
+             source_text = batch['src_text'][0]
+             target_text = batch['tgt_text'][0]
+             model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())
+
+             source_texts.append(source_text)
+             expected.append(target_text)
+             predicted.append(model_out_text)
+
+             print_msg("-" * console_width)
+             print_msg(f"SOURCE: {source_text}")
+             print_msg(f"TARGET: {target_text}")
+             print_msg(f"PREDICTED: {model_out_text}")
+
+             if count == num_examples:
+                 break