import unittest

import torch

from fairseq.data import MonolingualDataset
from fairseq.tasks.language_modeling import LanguageModelingTask, LanguageModelingConfig
from tests import utils as test_utils

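# With context_window=2, eval_lm_dataloader prepends up to 2 trailing tokens of the
# previous sentence to each sample and sets those positions to pad in the target,
# so the extra context is conditioned on but not scored.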
class TestLMContextWindow(unittest.TestCase):
    def test_eval_dataloader(self):
        dictionary = test_utils.dummy_dictionary(10)
        assert len(dictionary) == 14  # 10 dummy symbols + 4 special symbols
        assert dictionary.pad() == 1

        dataset = test_utils.TestDataset([
            torch.tensor([4, 5, 6, 7], dtype=torch.long),
            torch.tensor([8, 9, 10, 11], dtype=torch.long),
            torch.tensor([12, 13], dtype=torch.long),
        ])
        dataset = MonolingualDataset(dataset, sizes=[4, 4, 2], src_vocab=dictionary)

        config = LanguageModelingConfig(tokens_per_sample=4)
        task = LanguageModelingTask(config, dictionary)

        eval_dataloader = task.eval_lm_dataloader(
            dataset=dataset,
            batch_size=1,
            context_window=2,
        )

        # First sentence: no preceding context, so src and target are
        # right-padded to tokens_per_sample + context_window = 6.
        batch = next(eval_dataloader)
        assert batch["net_input"]["src_tokens"][0].tolist() == [4, 5, 6, 7, 1, 1]
        assert batch["target"][0].tolist() == [4, 5, 6, 7, 1, 1]

        # Second sentence: prefixed with the last 2 tokens of the first
        # sentence (6, 7); the context positions are pad in the target.
        batch = next(eval_dataloader)
        assert batch["net_input"]["src_tokens"][0].tolist() == [6, 7, 8, 9, 10, 11]
        assert batch["target"][0].tolist() == [1, 1, 8, 9, 10, 11]

        # Third sentence: prefixed with the last 2 tokens of the second
        # sentence (10, 11), again masked to pad in the target.
        batch = next(eval_dataloader)
        assert batch["net_input"]["src_tokens"][0].tolist() == [10, 11, 12, 13]
        assert batch["target"][0].tolist() == [1, 1, 12, 13]


if __name__ == "__main__":
    unittest.main()