from typing import List, Dict, Any
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch import Tensor
from tokenizers import Tokenizer


class BilingualDataset(Dataset):
    """
    A bilingual dataset that follows the structure of the 'opus_books' dataset:
    each item is a dict of the form {'translation': {src_lang: text, tgt_lang: text}}.
    """

    def __init__(
        self,
        ds: List[Dict[str, Dict[str, str]]],
        src_tokenizer: Tokenizer,
        tgt_tokenizer: Tokenizer,
        src_lang: str,
        tgt_lang: str,
        src_max_seq_len: int,
        tgt_max_seq_len: int,
    ) -> None:
        super().__init__()
        self.ds = ds
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.src_max_seq_len = src_max_seq_len
        self.tgt_max_seq_len = tgt_max_seq_len
        # Special-token ids are read from the source tokenizer; this assumes the source
        # and target tokenizers assign the same ids to <sos>, <eos> and <pad>.
        self.sos_token = torch.tensor([src_tokenizer.token_to_id('<sos>')], dtype=torch.int64)
        self.eos_token = torch.tensor([src_tokenizer.token_to_id('<eos>')], dtype=torch.int64)
        self.pad_token = torch.tensor([src_tokenizer.token_to_id('<pad>')], dtype=torch.int64)

    def __len__(self) -> int:
        return len(self.ds)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        src_tgt_pair = self.ds[index]
        src_text = src_tgt_pair['translation'][self.src_lang]
        tgt_text = src_tgt_pair['translation'][self.tgt_lang]

        # Tokenize both sides of the pair.
        encoder_input_tokens = self.src_tokenizer.encode(src_text).ids
        decoder_input_tokens = self.tgt_tokenizer.encode(tgt_text).ids

        # Number of padding tokens needed to reach the fixed sequence lengths.
        encoder_num_padding = self.src_max_seq_len - len(encoder_input_tokens) - 2  # room for <sos> and <eos>
        decoder_num_padding = self.tgt_max_seq_len - len(decoder_input_tokens) - 1  # room for <sos> only

        # Guard against sentences that do not fit into the fixed sequence length.
        if encoder_num_padding < 0 or decoder_num_padding < 0:
            raise ValueError(f'Sentence at index {index} is longer than the maximum sequence length')

        # <sos> + source_text + <eos> + <pad>... = encoder_input
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(encoder_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.full((encoder_num_padding,), self.pad_token.item(), dtype=torch.int64),
            ]
        )

        decoder_input_tokens = torch.tensor(decoder_input_tokens, dtype=torch.int64)
        decoder_padding = torch.full((decoder_num_padding,), self.pad_token.item(), dtype=torch.int64)

        # <sos> + target_text + <pad>... = decoder_input
        decoder_input = torch.cat(
            [
                self.sos_token,
                decoder_input_tokens,
                decoder_padding,
            ]
        )

        # target_text + <eos> + <pad>... = expected decoder output (label),
        # i.e. the decoder input shifted left by one position.
        label = torch.cat(
            [
                decoder_input_tokens,
                self.eos_token,
                decoder_padding,
            ]
        )

        assert encoder_input.size(0) == self.src_max_seq_len
        assert decoder_input.size(0) == self.tgt_max_seq_len
        assert label.size(0) == self.tgt_max_seq_len

        return {
            'encoder_input': encoder_input,  # (seq_len)
            'decoder_input': decoder_input,  # (seq_len)
            # The encoder mask hides only padding positions.
            'encoder_mask': (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            # The decoder mask hides padding positions and, via the causal mask, all future positions.
            'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),  # (1, seq_len, seq_len)
            'label': label,  # (seq_len)
            'src_text': src_text,
            'tgt_text': tgt_text,
        }


def causal_mask(size: int) -> Tensor:
    # Upper-triangular ones above the diagonal mark future positions; comparing with 0
    # inverts the matrix so that allowed (current and past) positions are True.
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0
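

# Minimal usage sketch (illustrative, not part of the original module): builds tiny
# word-level tokenizers with the <sos>/<eos>/<pad>/<unk> special tokens assumed above
# and wraps a single opus_books-style record. The language codes and sample sentences
# are hypothetical placeholders.
if __name__ == '__main__':
    from tokenizers.models import WordLevel
    from tokenizers.pre_tokenizers import Whitespace
    from tokenizers.trainers import WordLevelTrainer

    def _build_tokenizer(sentences: List[str]) -> Tokenizer:
        # Word-level tokenizer; the trainer assigns the special tokens the first ids,
        # so both tokenizers agree on the ids of <sos>, <eos> and <pad>.
        tokenizer = Tokenizer(WordLevel(unk_token='<unk>'))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=['<unk>', '<pad>', '<sos>', '<eos>'])
        tokenizer.train_from_iterator(sentences, trainer=trainer)
        return tokenizer

    ds = [{'translation': {'en': 'the cat sat on the mat', 'it': 'il gatto si sedette sul tappeto'}}]
    src_tok = _build_tokenizer([pair['translation']['en'] for pair in ds])
    tgt_tok = _build_tokenizer([pair['translation']['it'] for pair in ds])

    dataset = BilingualDataset(ds, src_tok, tgt_tok, 'en', 'it', src_max_seq_len=16, tgt_max_seq_len=16)
    sample = dataset[0]
    print(sample['encoder_input'].shape)  # torch.Size([16])
    print(sample['decoder_mask'].shape)   # torch.Size([1, 16, 16])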