File size: 4,894 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import os
import shutil
import tempfile
import unittest
from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.data.data_utils import raise_if_valid_subsets_unintentionally_ignored
from .utils import create_dummy_data, preprocess_lm_data, train_language_model
def make_lm_config(
    data_dir=None,
    extra_flags=None,
    task="language_modeling",
    arch="transformer_lm_gpt2_tiny",
):
    """Build a training config for a small language-model run.

    A fixed command line is assembled, parsed with fairseq's training parser
    and converted from an argparse namespace into an OmegaConf config.

    Args:
        data_dir: Optional data directory. When given, it is placed right
            after the task name on the command line. It is also used as the
            ``--save-dir`` value.
        extra_flags: Optional list of extra command-line tokens appended
            after the fixed flags.
        task: Value passed to ``--task``.
        arch: Value passed to ``--arch``.

    Returns:
        The config returned by ``convert_namespace_to_omegaconf``.
    """
    # The data directory, when present, follows the task name.
    task_tokens = [task] if data_dir is None else [task, data_dir]

    parser = options.get_training_parser()

    # NOTE(review): data_dir is forwarded as the --save-dir value even when it
    # is None (callers that omit data_dir hit this path) -- confirm intended.
    fixed_flags = [
        "--task",
        *task_tokens,
        "--arch",
        arch,
        "--optimizer",
        "adam",
        "--lr",
        "0.0001",
        "--max-tokens",
        "500",
        "--tokens-per-sample",
        "500",
        "--save-dir",
        data_dir,
        "--max-epoch",
        "1",
    ]
    argv = fixed_flags + (extra_flags or [])

    namespace = options.parse_args_and_arch(parser, argv)
    return convert_namespace_to_omegaconf(namespace)
def write_empty_file(path):
    """Create an empty file at ``path``, truncating any existing content."""
    # Opening in write mode creates or truncates the file; closing it
    # immediately leaves it empty.
    handle = open(path, "w")
    handle.close()
    assert os.path.exists(path)
class TestValidSubsetsErrors(unittest.TestCase):
    """Test filesystem / command-line combinations and check that
    ``raise_if_valid_subsets_unintentionally_ignored`` raises as expected.

    Each case creates empty ``<subset>.bin`` files in a temporary directory,
    builds a config from the given flags and runs the check against it.
    """

    def _test_case(self, paths, extra_flags):
        """Run the valid-subset check against a throwaway data directory.

        Args:
            paths: Subset names; an empty ``<name>.bin`` file is created for
                each of them, plus one for ``train``.
            extra_flags: Extra command-line tokens forwarded to
                ``make_lm_config``.
        """
        with tempfile.TemporaryDirectory() as data_dir:
            # Plain loop: the files are created purely for their side effect.
            for subset in [*paths, "train"]:
                write_empty_file(os.path.join(data_dir, f"{subset}.bin"))
            cfg = make_lm_config(data_dir, extra_flags=extra_flags)
            # Must run while the temporary directory still exists.
            raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_default_raises(self):
        """Extra valid files that are not requested make the check raise."""
        with self.assertRaises(ValueError):
            self._test_case(["valid", "valid1"], [])
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )

    def test_partially_specified_valid_subsets(self):
        """Listing only some subsets raises unless unused ones are ignored.

        The method name carries the ``test`` prefix so that unittest and
        pytest actually collect and run it.
        """
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )
        # Same layout, fixed by explicitly ignoring the unused subset.
        self._test_case(
            ["valid", "valid1", "valid2"],
            ["--valid-subset", "valid,valid1", "--ignore-unused-valid-subsets"],
        )

    def test_legal_configs(self):
        """Configurations that are expected to pass the check without raising."""
        self._test_case(["valid"], [])
        self._test_case(["valid", "valid1"], ["--ignore-unused-valid-subsets"])
        self._test_case(["valid", "valid1"], ["--combine-val"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid,valid1"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid1"])
        self._test_case(
            ["valid", "valid1"], ["--combine-val", "--ignore-unused-valid-subsets"]
        )
        # Only valid1.bin is created here and it is the requested subset, so
        # there is no valid.bin that would need to be ignored.
        self._test_case(["valid1"], ["--valid-subset", "valid1"])

    def test_disable_validation(self):
        """With ``--disable-validation`` the check passes with or without valid files."""
        self._test_case([], ["--disable-validation"])
        self._test_case(["valid", "valid1"], ["--disable-validation"])

    def test_dummy_task(self):
        """The check does not raise for ``dummy_lm`` built without a data directory."""
        cfg = make_lm_config(task="dummy_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_masked_dummy_task(self):
        """The check does not raise for ``dummy_masked_lm`` built without a data directory."""
        cfg = make_lm_config(task="dummy_masked_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)
class TestCombineValidSubsets(unittest.TestCase):
    """Train briefly with two validation subsets and inspect the captured logs."""

    def _train(self, extra_flags):
        """Run a short language-model training job and return its log messages.

        Dummy data is created and preprocessed in a temporary directory, the
        ``valid`` split files are copied to ``valid1`` so that a second
        validation subset exists, and training is launched with
        ``--max-update 0`` and JSON log format.

        Args:
            extra_flags: List of extra command-line tokens appended to the
                training flags.

        Returns:
            The ``message`` of every log record captured by ``assertLogs``.
        """
        with self.assertLogs() as logs:
            # The first positional parameter of TemporaryDirectory is the
            # suffix; it is spelled out here to make that explicit.
            with tempfile.TemporaryDirectory(suffix="test_transformer_lm") as data_dir:
                create_dummy_data(data_dir, num_examples=20)
                preprocess_lm_data(data_dir)
                # Duplicate the valid split under the name valid1.
                for ext in ("bin", "idx"):
                    shutil.copyfile(
                        os.path.join(data_dir, f"valid.{ext}"),
                        os.path.join(data_dir, f"valid1.{ext}"),
                    )
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--max-update", "0", "--log-format", "json"] + extra_flags,
                    run_validation=False,
                )
        return [record.message for record in logs.records]

    def test_combined(self):
        """With combined subsets, valid1 is mentioned but has no own metric."""
        logs = self._train(["--combine-valid-subsets"])
        # valid1 shows up somewhere in the logs (e.g. when it is loaded) ...
        self.assertTrue(any("valid1" in msg for msg in logs))
        # ... but no separate valid1 perplexity is reported: metrics are combined.
        self.assertFalse(any("valid1_ppl" in msg for msg in logs))

    def test_subsets(self):
        """With explicit subsets, each subset reports its own perplexity."""
        logs = self._train(["--valid-subset", "valid,valid1"])
        # Metrics are reported separately for each requested subset.
        self.assertTrue(any("valid_ppl" in msg for msg in logs))
        self.assertTrue(any("valid1_ppl" in msg for msg in logs))
|