File size: 3,464 Bytes
d5175d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import io
import os
import tempfile
import unittest

import torch
from fairseq.data import Dictionary


class TestDictionary(unittest.TestCase):
    """Unit tests for ``fairseq.data.Dictionary``.

    Covers line encoding before and after ``finalize()``, a save/load round
    trip, the ``#fairseq:overwrite`` flag in dictionary files, and a symbol
    that is itself a space character.
    """

    def test_finalize(self):
        """Check encoded ids before/after ``finalize()`` and after save/load."""
        txt = [
            "A B C D",
            "B C D",
            "C D",
            "D",
        ]
        # Before finalize(): symbols receive indices in first-seen order
        # (A=4, B=5, C=6, D=7); indices below 4 are already taken when the
        # Dictionary is created. encode_line appends index 2 to every line.
        ref_ids1 = [
            torch.IntTensor(ids)
            for ids in [
                [4, 5, 6, 7, 2],
                [5, 6, 7, 2],
                [6, 7, 2],
                [7, 2],
            ]
        ]
        # After finalize(): symbols are re-indexed by descending frequency
        # in txt: D (4 occurrences)=4, C (3)=5, B (2)=6, A (1)=7.
        ref_ids2 = [
            torch.IntTensor(ids)
            for ids in [
                [7, 6, 5, 4, 2],
                [6, 5, 4, 2],
                [5, 4, 2],
                [4, 2],
            ]
        ]

        # Build the dictionary by encoding each line with insertion enabled.
        d = Dictionary()
        for line in txt:
            d.encode_line(line, add_if_not_exist=True)

        def get_ids(dictionary):
            # Encode every line of txt without adding new symbols.
            return [
                dictionary.encode_line(line, add_if_not_exist=False)
                for line in txt
            ]

        def assert_match(ids, ref_ids):
            # zip() stops at the shorter input, so compare the number of
            # encoded lines first; otherwise a missing or extra line would
            # go unnoticed.
            self.assertEqual(len(ids), len(ref_ids))
            for toks, ref_toks in zip(ids, ref_ids):
                self.assertEqual(toks.size(), ref_toks.size())
                self.assertEqual(0, (toks != ref_toks).sum().item())

        assert_match(get_ids(d), ref_ids1)

        # Check the finalized dictionary.
        d.finalize()
        finalized_ids = get_ids(d)
        assert_match(finalized_ids, ref_ids2)

        # Write to disk and reload. A path inside a TemporaryDirectory is
        # used rather than NamedTemporaryFile, because re-opening a
        # still-open NamedTemporaryFile by name is not possible on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dict_path = os.path.join(tmp_dir, "dict.txt")
            d.save(dict_path)
            reloaded = Dictionary.load(dict_path)
            reload_ids = get_ids(reloaded)
            assert_match(reload_ids, ref_ids2)
            assert_match(finalized_ids, reload_ids)

    def test_overwrite(self):
        """Entries flagged ``#fairseq:overwrite`` may re-declare built-ins."""
        # For example, Camembert overwrites <unk>, <s> and </s>.
        dict_file = io.StringIO(
            "<unk> 999 #fairseq:overwrite\n"
            "<s> 999 #fairseq:overwrite\n"
            "</s> 999 #fairseq:overwrite\n"
            ", 999\n"
            "▁de 999\n"
        )
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index("<pad>"), 1)
        # An unknown symbol still maps to index 3, even though the "<unk>"
        # string itself now resolves to the file's entry at index 4.
        self.assertEqual(d.index("foo"), 3)
        # File entries follow the reserved indices, in file order.
        self.assertEqual(d.index("<unk>"), 4)
        self.assertEqual(d.index("<s>"), 5)
        self.assertEqual(d.index("</s>"), 6)
        self.assertEqual(d.index(","), 7)
        self.assertEqual(d.index("▁de"), 8)

    def test_no_overwrite(self):
        """Re-declaring built-in symbols without the flag is an error."""
        # Without "#fairseq:overwrite", the <unk>, <s> and </s> lines clash
        # with symbols the Dictionary already holds, so loading must raise a
        # RuntimeError whose message matches "Duplicate".
        dict_file = io.StringIO(
            "<unk> 999\n"
            "<s> 999\n"
            "</s> 999\n"
            ", 999\n"
            "▁de 999\n"
        )
        d = Dictionary()
        with self.assertRaisesRegex(RuntimeError, "Duplicate"):
            d.add_from_file(dict_file)

    def test_space(self):
        """A single space character can be a dictionary symbol."""
        # For example, character models treat space as a symbol. The first
        # file line is a space followed by " 999".
        dict_file = io.StringIO("  999\n" "a 999\n" "b 999\n")
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index(" "), 4)
        self.assertEqual(d.index("a"), 5)
        self.assertEqual(d.index("b"), 6)


if __name__ == "__main__":
    # Executing this file directly runs every TestCase defined in it;
    # module="__main__" is unittest.main's default, spelled out here.
    unittest.main(module="__main__")