# homemade_lo_vi/dataset.py
# Bilingual parallel-text dataset for sequence-to-sequence (translation) training.
from typing import List, Dict, Any
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch import Tensor
from tokenizers import Tokenizer
class BilingualDataset(Dataset):
    """
    A bilingual dataset following the structure of the 'opus_books' dataset.

    Each element of ``ds`` is expected to look like
    ``{'translation': {src_lang: <source text>, tgt_lang: <target text>}}``.
    ``__getitem__`` tokenizes one pair and returns fixed-length, padded
    encoder/decoder input tensors, the label tensor, attention masks, and the
    raw texts.
    """
    def __init__(
        self,
        ds: List[Dict[str, Dict[str, str]]],
        src_tokenizer: "Tokenizer",
        tgt_tokenizer: "Tokenizer",
        src_lang: str,
        tgt_lang: str,
        src_max_seq_len: int,
        tgt_max_seq_len: int,
    ) -> None:
        """
        Args:
            ds: parallel corpus records in 'opus_books' translation format.
            src_tokenizer: trained tokenizer exposing ``token_to_id``/``encode``.
            tgt_tokenizer: trained tokenizer exposing ``token_to_id``/``encode``.
            src_lang: key into each record's ``'translation'`` dict (source side).
            tgt_lang: key into each record's ``'translation'`` dict (target side).
            src_max_seq_len: fixed encoder sequence length after padding.
            tgt_max_seq_len: fixed decoder sequence length after padding.
        """
        super().__init__()
        self.ds = ds
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.src_max_seq_len = src_max_seq_len
        self.tgt_max_seq_len = tgt_max_seq_len
        # Special-token ids come from the *source* tokenizer only; this assumes
        # both tokenizers assign the same ids to <sos>/<eos>/<pad> -- TODO confirm.
        self.sos_token = torch.tensor([src_tokenizer.token_to_id('<sos>')], dtype=torch.int64)
        self.eos_token = torch.tensor([src_tokenizer.token_to_id('<eos>')], dtype=torch.int64)
        self.pad_token = torch.tensor([src_tokenizer.token_to_id('<pad>')], dtype=torch.int64)

    def __len__(self) -> int:
        return len(self.ds)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        src_tgt_pair = self.ds[index]
        src_text = src_tgt_pair['translation'][self.src_lang]
        tgt_text = src_tgt_pair['translation'][self.tgt_lang]
        encoder_input_tokens = self.src_tokenizer.encode(src_text).ids
        decoder_input_tokens = self.tgt_tokenizer.encode(tgt_text).ids
        # Room left after the special tokens: the encoder adds <sos> and <eos>,
        # the decoder input adds only <sos> (the label adds <eos> instead).
        encoder_num_padding = self.src_max_seq_len - len(encoder_input_tokens) - 2
        decoder_num_padding = self.tgt_max_seq_len - len(decoder_input_tokens) - 1
        # Fail loudly on over-long sentences instead of emitting malformed
        # tensors (a bare `assert` would be stripped under `python -O`).
        if encoder_num_padding < 0 or decoder_num_padding < 0:
            raise ValueError(f'Sentence pair at index {index} exceeds the maximum sequence length.')
        pad_id = int(self.pad_token.item())
        # torch.full replaces the `[self.pad_token] * n` pattern, which builds
        # a Python list of 1-element tensors and only works via torch.tensor()
        # with a copy-construct warning.
        encoder_padding = torch.full((encoder_num_padding,), pad_id, dtype=torch.int64)
        decoder_padding = torch.full((decoder_num_padding,), pad_id, dtype=torch.int64)
        decoder_input_tokens = torch.tensor(decoder_input_tokens, dtype=torch.int64)
        # <sos> + source_text + <eos> + <pad> = encoder_input
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(encoder_input_tokens, dtype=torch.int64),
                self.eos_token,
                encoder_padding,
            ]
        )
        # <sos> + target_text + <pad> = decoder_input
        decoder_input = torch.cat(
            [
                self.sos_token,
                decoder_input_tokens,
                decoder_padding,
            ]
        )
        # target_text + <eos> + <pad> = expected decoder_output (label),
        # i.e. the decoder input shifted left by one position.
        label = torch.cat(
            [
                decoder_input_tokens,
                self.eos_token,
                decoder_padding,
            ]
        )
        # Sanity checks: the padding arithmetic above guarantees these lengths.
        assert encoder_input.size(0) == self.src_max_seq_len
        assert decoder_input.size(0) == self.tgt_max_seq_len
        assert label.size(0) == self.tgt_max_seq_len
        return {
            'encoder_input': encoder_input,  # (seq_len)
            'decoder_input': decoder_input,  # (seq_len)
            # Padding positions are masked out for the encoder; the decoder
            # additionally gets a causal (no-lookahead) mask.
            'encoder_mask': (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),  # (1, seq_len, seq_len)
            'label': label,  # (seq_len)
            'src_text': src_text,
            'tgt_text': tgt_text,
        }
def causal_mask(size: int) -> Tensor:
    """Return a (1, size, size) boolean causal mask.

    Entries on and below the main diagonal are ``True``: position *i* may
    attend to positions ``0..i`` only, never to later (future) positions.
    """
    lower_triangle = torch.tril(torch.ones(1, size, size, dtype=torch.int))
    return lower_triangle.bool()