import torch
from torch.utils.data import Dataset
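
# BilingualDataset wraps a parallel corpus of {"translation": {language: text}} records and
# turns each pair into fixed-length encoder/decoder token tensors, padding/causal attention
# masks, and a label tensor for training a sequence-to-sequence transformer.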

class BilingualDataset(Dataset):

    def __init__(self, dataset, source_tokenizer, target_tokenizer, source_language, target_language, sequence_length):
        super().__init__()
        self.dataset = dataset
        self.source_tokenizer = source_tokenizer
        self.target_tokenizer = target_tokenizer
        self.source_language = source_language
        self.target_language = target_language
        self.sequence_length = sequence_length

        # Special token ids are looked up in the target tokenizer's vocabulary.
        self.SOS_token = torch.tensor([target_tokenizer.token_to_id("[SOS]")], dtype=torch.int64)
        self.PAD_token = torch.tensor([target_tokenizer.token_to_id("[PAD]")], dtype=torch.int64)
        self.EOS_token = torch.tensor([target_tokenizer.token_to_id("[EOS]")], dtype=torch.int64)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        source_target_pair = self.dataset[index]
        source_text = source_target_pair['translation'][self.source_language]
        target_text = source_target_pair['translation'][self.target_language]

        # Tokenize both sentences into lists of token ids.
        source_token_ids = self.source_tokenizer.encode(source_text).ids
        target_token_ids = self.target_tokenizer.encode(target_text).ids

        # The encoder input receives both [SOS] and [EOS]; the decoder input only [SOS],
        # so it needs one fewer token subtracted when computing the padding length.
        source_padding_count = self.sequence_length - len(source_token_ids) - 2
        target_padding_count = self.sequence_length - len(target_token_ids) - 1

        if source_padding_count < 0 or target_padding_count < 0:
            raise ValueError("Sentence is too long for the configured sequence_length")

        # Encoder input: [SOS] + source tokens + [EOS] + padding, always sequence_length long.
        encoder_input = torch.cat(
            [
                self.SOS_token,
                torch.tensor(source_token_ids, dtype=torch.int64),
                self.EOS_token,
                torch.full((source_padding_count,), self.PAD_token.item(), dtype=torch.int64)
            ]
        )

        # Decoder input: [SOS] + target tokens + padding (no [EOS]; the model must learn to predict it).
        decoder_input = torch.cat(
            [
                self.SOS_token,
                torch.tensor(target_token_ids, dtype=torch.int64),
                torch.full((target_padding_count,), self.PAD_token.item(), dtype=torch.int64)
            ]
        )

        # Label: target tokens + [EOS] + padding, i.e. the decoder input shifted left by one position.
        target = torch.cat(
            [
                torch.tensor(target_token_ids, dtype=torch.int64),
                self.EOS_token,
                torch.full((target_padding_count,), self.PAD_token.item(), dtype=torch.int64)
            ]
        )

        assert encoder_input.size(0) == self.sequence_length
        assert decoder_input.size(0) == self.sequence_length
        assert target.size(0) == self.sequence_length

        return {
            "encoder_input": encoder_input,    # (sequence_length)
            "decoder_input": decoder_input,    # (sequence_length)
            # (1, 1, sequence_length): hides padding positions in the source.
            "encoder_input_mask": (encoder_input != self.PAD_token).unsqueeze(0).unsqueeze(0).int(),
            # (1, sequence_length, sequence_length): hides padding and future positions in the target.
            "decoder_input_mask": (decoder_input != self.PAD_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            "Target": target,
            "source_text": source_text,
            "target_text": target_text
        }


def causal_mask(size):
    # Ones above the diagonal mark "future" positions; returning mask == 0
    # yields True wherever a position is allowed to attend.
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0
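

# --- Usage sketch ---
# A minimal, hedged example of how this dataset might be wired up. It assumes the
# Hugging Face `tokenizers` library and uses a tiny in-memory list of
# {"translation": {...}} records in place of a real corpus (e.g. one loaded with
# `datasets.load_dataset`); `build_tokenizer` is an illustrative helper, not part
# of this module.
if __name__ == "__main__":
    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel
    from tokenizers.pre_tokenizers import Whitespace
    from tokenizers.trainers import WordLevelTrainer

    def build_tokenizer(sentences):
        # Train a small word-level tokenizer containing the special tokens the dataset expects.
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"])
        tokenizer.train_from_iterator(sentences, trainer=trainer)
        return tokenizer

    raw_dataset = [
        {"translation": {"en": "the cat sleeps", "it": "il gatto dorme"}},
        {"translation": {"en": "the dog runs", "it": "il cane corre"}},
    ]
    source_tokenizer = build_tokenizer([item["translation"]["en"] for item in raw_dataset])
    target_tokenizer = build_tokenizer([item["translation"]["it"] for item in raw_dataset])

    dataset = BilingualDataset(raw_dataset, source_tokenizer, target_tokenizer, "en", "it", sequence_length=16)
    sample = dataset[0]
    print(sample["encoder_input"].shape)       # torch.Size([16])
    print(sample["decoder_input_mask"].shape)  # torch.Size([1, 16, 16])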