File size: 4,894 Bytes
d5175d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import shutil
import tempfile
import unittest

from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.data.data_utils import raise_if_valid_subsets_unintentionally_ignored
from .utils import create_dummy_data, preprocess_lm_data, train_language_model


def make_lm_config(
    data_dir=None,
    extra_flags=None,
    task="language_modeling",
    arch="transformer_lm_gpt2_tiny",
):
    """Build a training config by parsing a small fixed set of CLI flags.

    Args:
        data_dir: Data directory. When given, it is passed as the positional
            argument following ``--task`` and is also used as ``--save-dir``.
            When ``None``, neither is passed.
        extra_flags: Optional iterable of additional command-line flags that
            are appended after the fixed ones.
        task: Value for ``--task``.
        arch: Value for ``--arch``.

    Returns:
        The result of ``convert_namespace_to_omegaconf`` applied to the
        parsed training arguments.
    """
    has_data_dir = data_dir is not None

    argv = ["--task", task]
    if has_data_dir:
        argv.append(data_dir)
    argv += [
        "--arch",
        arch,
        "--optimizer",
        "adam",
        "--lr",
        "0.0001",
        "--max-tokens",
        "500",
        "--tokens-per-sample",
        "500",
    ]
    if has_data_dir:
        # Only forward a real path: argv entries must be strings, so a
        # ``None`` data_dir must not be handed to the parser as a value.
        argv += ["--save-dir", data_dir]
    argv += ["--max-epoch", "1"]
    argv.extend(extra_flags or [])

    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(train_parser, argv)
    return convert_namespace_to_omegaconf(train_args)


def write_empty_file(path):
    """Create ``path`` as a zero-length file (truncating any existing content)."""
    # Opening in write mode and closing straight away leaves an empty file.
    open(path, "w").close()
    assert os.path.exists(path)


class TestValidSubsetsErrors(unittest.TestCase):
    """Check when ``raise_if_valid_subsets_unintentionally_ignored`` raises.

    Each case combines a set of ``<name>.bin`` files on disk with a set of
    command-line flags and asserts whether the check raises ``ValueError``.
    """

    def _test_case(self, paths, extra_flags):
        """Run the check against a temporary data directory.

        Args:
            paths: Subset names; an empty ``<name>.bin`` file is created for
                each of them, plus one for ``train``.
            extra_flags: Additional command-line flags used to build the
                config.
        """
        with tempfile.TemporaryDirectory() as data_dir:
            for name in [*paths, "train"]:
                write_empty_file(os.path.join(data_dir, f"{name}.bin"))
            cfg = make_lm_config(data_dir, extra_flags=extra_flags)
            raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_default_raises(self):
        """Extra valid files that are not requested make the check raise."""
        with self.assertRaises(ValueError):
            self._test_case(["valid", "valid1"], [])
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )

    # The ``test_`` prefix is required for unittest to collect this method;
    # without it the assertions below never run.
    def test_partially_specified_valid_subsets(self):
        """Listing only some subsets raises unless the rest are ignored."""
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )
        # Fix with ignore unused
        self._test_case(
            ["valid", "valid1", "valid2"],
            ["--valid-subset", "valid,valid1", "--ignore-unused-valid-subsets"],
        )

    def test_legal_configs(self):
        """Flag/file combinations for which the check must not raise."""
        self._test_case(["valid"], [])
        self._test_case(["valid", "valid1"], ["--ignore-unused-valid-subsets"])
        self._test_case(["valid", "valid1"], ["--combine-val"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid,valid1"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid1"])
        self._test_case(
            ["valid", "valid1"], ["--combine-val", "--ignore-unused-valid-subsets"]
        )
        self._test_case(
            ["valid1"], ["--valid-subset", "valid1"]
        )  # valid.bin doesn't need to be ignored.

    def test_disable_validation(self):
        """With ``--disable-validation`` the check does not raise."""
        self._test_case([], ["--disable-validation"])
        self._test_case(["valid", "valid1"], ["--disable-validation"])

    def test_dummy_task(self):
        """The check does not raise for ``dummy_lm`` built without a data dir."""
        cfg = make_lm_config(task="dummy_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_masked_dummy_task(self):
        """The check does not raise for ``dummy_masked_lm`` without a data dir."""
        cfg = make_lm_config(task="dummy_masked_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)


class TestCombineValidSubsets(unittest.TestCase):
    """Train briefly with two validation subsets and inspect the log output."""

    def _train(self, extra_flags):
        """Train on dummy data with a duplicated valid subset.

        ``valid.bin``/``valid.idx`` are copied to ``valid1.bin``/``valid1.idx``
        so that a second validation subset exists on disk.

        Args:
            extra_flags: List of extra command-line flags for training.

        Returns:
            The messages of all log records captured during the run.
        """
        with self.assertLogs() as logs:
            # The string is the directory-name *suffix* (first positional
            # parameter of TemporaryDirectory); spelled out for clarity.
            with tempfile.TemporaryDirectory(suffix="test_transformer_lm") as data_dir:
                create_dummy_data(data_dir, num_examples=20)
                preprocess_lm_data(data_dir)

                shutil.copyfile(f"{data_dir}/valid.bin", f"{data_dir}/valid1.bin")
                shutil.copyfile(f"{data_dir}/valid.idx", f"{data_dir}/valid1.idx")
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--max-update", "0", "--log-format", "json"] + extra_flags,
                    run_validation=False,
                )
        return [x.message for x in logs.records]

    def test_combined(self):
        """With combined subsets, valid1 is mentioned but has no own metric."""
        flags = ["--combine-valid-subsets"]
        logs = self._train(flags)
        # valid1 shows up somewhere in the logs ...
        self.assertTrue(
            any("valid1" in x for x in logs),
            "expected 'valid1' to appear in the training logs",
        )
        # ... but metrics are combined, so there is no separate valid1_ppl.
        self.assertFalse(
            any("valid1_ppl" in x for x in logs),
            "did not expect a separate 'valid1_ppl' metric",
        )

    def test_subsets(self):
        """With explicit subsets, each subset reports its own perplexity."""
        flags = ["--valid-subset", "valid,valid1"]
        logs = self._train(flags)
        # Metrics are reported per subset: one for valid, one for valid1.
        self.assertTrue(
            any("valid_ppl" in x for x in logs),
            "expected a 'valid_ppl' metric in the training logs",
        )
        self.assertTrue(
            any("valid1_ppl" in x for x in logs),
            "expected a 'valid1_ppl' metric in the training logs",
        )