import unittest

import tests.utils as test_utils
import torch
from fairseq.data import (
    BacktranslationDataset,
    LanguagePairDataset,
    TransformEosDataset,
)
from fairseq.sequence_generator import SequenceGenerator


class TestBacktranslationDataset(unittest.TestCase):
    def setUp(self):
        (
            self.tgt_dict,
            self.w1,
            self.w2,
            self.src_tokens,
            self.src_lengths,
            self.model,
        ) = test_utils.sequence_generator_setup()

        dummy_src_samples = self.src_tokens

        self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)
        self.cuda = torch.cuda.is_available()

    def _backtranslation_dataset_helper(
        self,
        remove_eos_from_input_src,
        remove_eos_from_output_src,
    ):
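        # Round-trip the dummy target data through BacktranslationDataset:
        # generate synthetic sources with a beam-2 SequenceGenerator, then
        # check the collated (generated src, original tgt) batch.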
        tgt_dataset = LanguagePairDataset(
            src=self.tgt_dataset,
            src_sizes=self.tgt_dataset.sizes,
            src_dict=self.tgt_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )

        generator = SequenceGenerator(
            [self.model],
            tgt_dict=self.tgt_dict,
            max_len_a=0,
            max_len_b=200,
            beam_size=2,
            unk_penalty=0,
        )
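
        # Wrap the target dataset so eos handling can be toggled on both the
        # backtranslation input and the collated output.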
        backtranslation_dataset = BacktranslationDataset(
            tgt_dataset=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # optionally strip eos from the src fed to the generator
                remove_eos_from_src=remove_eos_from_input_src,
            ),
            src_dict=self.tgt_dict,
            backtranslation_fn=(
                lambda sample: generator.generate([self.model], sample)
            ),
            output_collater=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # if eos was stripped from the input src, add it back to the
                # output tgt; optionally strip eos from the output src
                append_eos_to_tgt=remove_eos_from_input_src,
                remove_eos_from_src=remove_eos_from_output_src,
            ).collater,
            cuda=self.cuda,
        )
        dataloader = torch.utils.data.DataLoader(
            backtranslation_dataset,
            batch_size=2,
            collate_fn=backtranslation_dataset.collater,
        )
        backtranslation_batch_result = next(iter(dataloader))

        eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2

        # samples are sorted by descending src length and left-padded, so the
        # shorter generated source appears second with leading pad tokens
        expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
        if remove_eos_from_output_src:
            expected_src = expected_src[:, :-1]
        expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
        generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
        tgt_tokens = backtranslation_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def test_backtranslation_dataset_no_eos_in_output_src(self):
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=False,
            remove_eos_from_output_src=True,
        )

    def test_backtranslation_dataset_with_eos_in_output_src(self):
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=False,
            remove_eos_from_output_src=False,
        )

    def test_backtranslation_dataset_no_eos_in_input_src(self):
        self._backtranslation_dataset_helper(
            remove_eos_from_input_src=True,
            remove_eos_from_output_src=False,
        )

    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)


if __name__ == "__main__":
    unittest.main()