|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
import torch |
|
from fairseq.data import LanguagePairDataset, TokenBlockDataset |
|
from fairseq.data.concat_dataset import ConcatDataset |
|
from tests.test_train import mock_dict |
|
|
|
|
|
class TestConcatDataset(unittest.TestCase):
    """Exercise ``ConcatDataset`` over two one-example language-pair datasets.

    The fixtures each hold a single example whose only source token is ``1``
    (``dataset_1``) or ``2`` (``dataset_2``), so the first source token of a
    concatenated item identifies which underlying dataset it came from.
    """

    def setUp(self):
        """Create ``self.dataset_1`` and ``self.dataset_2`` sharing one mock dictionary."""
        dictionary = mock_dict()
        self.dataset_1 = self._single_token_pair_dataset(1, dictionary)
        self.dataset_2 = self._single_token_pair_dataset(2, dictionary)

    @staticmethod
    def _single_token_pair_dataset(token, dictionary):
        """Return a ``LanguagePairDataset`` wrapping one block that holds ``token``.

        Args:
            token: integer placed in the 1x1 ``LongTensor`` backing the dataset.
            dictionary: dictionary object handed to ``LanguagePairDataset``.
        """
        tokens = torch.LongTensor([token]).view(1, -1)
        blocks = TokenBlockDataset(
            tokens,
            sizes=[tokens.size(-1)],
            block_size=1,
            pad=0,
            eos=1,
            include_targets=False,
        )
        return LanguagePairDataset(blocks, blocks.sizes, dictionary, shuffle=False)

    def _assert_first_source_tokens(self, concat, expected):
        """Assert ``concat`` has ``len(expected)`` items with the given first source tokens.

        Uses ``unittest`` assertion methods rather than bare ``assert`` so the
        checks still run under ``python -O`` and report both values on failure.
        """
        self.assertEqual(len(concat), len(expected))
        for index, token in enumerate(expected):
            with self.subTest(index=index):
                # int() unwraps the scalar element so a failure prints plain numbers.
                self.assertEqual(int(concat[index]["source"][0]), token)

    def test_concat_dataset_basics(self):
        """Check item order for plain concatenation and for two ``sample_ratios`` settings."""
        # No ratios: one item from each dataset, in order.
        plain = ConcatDataset([self.dataset_1, self.dataset_2])
        self._assert_first_source_tokens(plain, [1, 2])

        # Ratio 2 on the second dataset: its single example is expected twice.
        second_doubled = ConcatDataset(
            [self.dataset_1, self.dataset_2], sample_ratios=[1, 2]
        )
        self._assert_first_source_tokens(second_doubled, [1, 2, 2])

        # Ratio 2 on the first dataset: its single example is expected twice.
        first_doubled = ConcatDataset(
            [self.dataset_1, self.dataset_2], sample_ratios=[2, 1]
        )
        self._assert_first_source_tokens(first_doubled, [1, 1, 2])
|
|