|
from typing import List, Dict, Any |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import Dataset |
|
from torch import Tensor |
|
|
|
from tokenizers import Tokenizer |
|
|
|
class BilingualDataset(Dataset):
    """
    A Bilingual Dataset that follows the structure of the 'opus_books' dataset.

    Each raw example is a dict of the form
    ``{'translation': {src_lang: <src sentence>, tgt_lang: <tgt sentence>}}``.
    ``__getitem__`` tokenizes both sides and frames them to fixed lengths:

    * ``encoder_input``: ``<sos> tokens <eos> <pad>...`` of length ``src_max_seq_len``
    * ``decoder_input``: ``<sos> tokens <pad>...``       of length ``tgt_max_seq_len``
    * ``label``:         ``tokens <eos> <pad>...``       of length ``tgt_max_seq_len``
    """

    def __init__(
        self,
        ds: List[Dict[str, Dict[str, str]]],
        src_tokenizer: "Tokenizer",
        tgt_tokenizer: "Tokenizer",
        src_lang: str,
        tgt_lang: str,
        src_max_seq_len: int,
        tgt_max_seq_len: int,
    ) -> None:
        """
        Args:
            ds: Raw examples, each a dict with a 'translation' mapping of
                language code -> sentence.
            src_tokenizer: Tokenizer used for the source language.
            tgt_tokenizer: Tokenizer used for the target language.
            src_lang: Key of the source language inside 'translation'.
            tgt_lang: Key of the target language inside 'translation'.
            src_max_seq_len: Fixed length of every encoder sequence.
            tgt_max_seq_len: Fixed length of every decoder/label sequence.
        """
        super().__init__()

        self.ds = ds
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        self.src_max_seq_len = src_max_seq_len
        self.tgt_max_seq_len = tgt_max_seq_len

        # NOTE(review): special-token ids are looked up on the *source*
        # tokenizer only, yet they are also used to frame decoder sequences.
        # This assumes both tokenizers assign the same ids to
        # '<sos>'/'<eos>'/'<pad>' -- confirm against how they are trained.
        self.sos_token = torch.tensor([src_tokenizer.token_to_id('<sos>')], dtype=torch.int64)
        self.eos_token = torch.tensor([src_tokenizer.token_to_id('<eos>')], dtype=torch.int64)
        self.pad_token = torch.tensor([src_tokenizer.token_to_id('<pad>')], dtype=torch.int64)

    def __len__(self) -> int:
        """Return the number of sentence pairs in the dataset."""
        return len(self.ds)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """
        Tokenize, frame and pad the pair at ``index``.

        Returns a dict with fixed-length int64 tensors ``encoder_input``,
        ``decoder_input`` and ``label``, the broadcastable attention masks
        ``encoder_mask`` (1, 1, src_max_seq_len) and ``decoder_mask``
        (1, tgt_max_seq_len, tgt_max_seq_len), plus the raw ``src_text`` /
        ``tgt_text`` strings.

        Raises:
            ValueError: If either sentence is too long to fit its
                configured max sequence length.
        """
        src_tgt_pair = self.ds[index]
        src_text = src_tgt_pair['translation'][self.src_lang]
        tgt_text = src_tgt_pair['translation'][self.tgt_lang]

        encoder_input_tokens = self.src_tokenizer.encode(src_text).ids
        decoder_input_tokens = self.tgt_tokenizer.encode(tgt_text).ids

        # The encoder sequence gains <sos> and <eos> (2 extra tokens); the
        # decoder input gains only <sos> (its label gains only <eos>), hence
        # the -2 / -1.
        encoder_num_padding = self.src_max_seq_len - len(encoder_input_tokens) - 2
        decoder_num_padding = self.tgt_max_seq_len - len(decoder_input_tokens) - 1

        # Fail loudly instead of silently emitting under-length tensors:
        # `[pad] * negative` would produce an empty padding list, and the
        # original trailing asserts are stripped under `python -O`.
        if encoder_num_padding < 0 or decoder_num_padding < 0:
            raise ValueError(
                f'Sentence pair at index {index} is too long for the '
                f'configured max sequence lengths'
            )

        # Build the padding runs with torch.full rather than
        # torch.tensor([self.pad_token] * n), which constructs a tensor from
        # a list of one-element tensors (deprecated copy-construct pattern).
        pad_id = int(self.pad_token.item())
        encoder_padding = torch.full((encoder_num_padding,), pad_id, dtype=torch.int64)
        decoder_padding = torch.full((decoder_num_padding,), pad_id, dtype=torch.int64)

        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(encoder_input_tokens, dtype=torch.int64),
                self.eos_token,
                encoder_padding,
            ]
        )

        decoder_input_tokens = torch.tensor(decoder_input_tokens, dtype=torch.int64)

        decoder_input = torch.cat(
            [
                self.sos_token,
                decoder_input_tokens,
                decoder_padding,
            ]
        )

        # The label is the decoder input shifted left by one position:
        # no <sos>, and <eos> appended after the real tokens.
        label = torch.cat(
            [
                decoder_input_tokens,
                self.eos_token,
                decoder_padding,
            ]
        )

        return {
            'encoder_input': encoder_input,    # (src_max_seq_len,)
            'decoder_input': decoder_input,    # (tgt_max_seq_len,)
            # 1 for real tokens, 0 for padding.
            'encoder_mask': (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
            # Padding mask combined with the causal (lower-triangular) mask.
            'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            'label': label,                    # (tgt_max_seq_len,)
            'src_text': src_text,
            'tgt_text': tgt_text,
        }
|
|
|
def causal_mask(size: int) -> Tensor:
    """Build a (1, size, size) boolean mask for autoregressive attention.

    Entry [0, i, j] is True when position i may attend to position j,
    i.e. on and below the main diagonal (j <= i).
    """
    lower_triangular = torch.tril(torch.ones(size, size, dtype=torch.bool))
    return lower_triangular.unsqueeze(0)
|
|