# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import io
import tempfile
import unittest

import torch
from fairseq.data import Dictionary


class TestDictionary(unittest.TestCase):
    """Tests for ``fairseq.data.Dictionary`` encoding, finalizing and loading."""

    def test_finalize(self):
        """Ids must change as expected after ``finalize()`` and survive a reload."""
        txt = [
            "A B C D",
            "B C D",
            "C D",
            "D",
        ]
        # Expected ids while symbols are still in insertion order (A, B, C, D).
        # Every encoded line ends with id 2 (presumably the end-of-sentence
        # symbol appended by encode_line -- confirm against Dictionary).
        ref_ids1 = [
            torch.IntTensor(row)
            for row in [
                [4, 5, 6, 7, 2],
                [5, 6, 7, 2],
                [6, 7, 2],
                [7, 2],
            ]
        ]
        # Expected ids after finalize(): the order is reversed, which is
        # consistent with sorting by descending frequency in ``txt``
        # (D occurs 4 times, A only once).
        ref_ids2 = [
            torch.IntTensor(row)
            for row in [
                [7, 6, 5, 4, 2],
                [6, 5, 4, 2],
                [5, 4, 2],
                [4, 2],
            ]
        ]

        # build dictionary
        d = Dictionary()
        for line in txt:
            d.encode_line(line, add_if_not_exist=True)

        def get_ids(dictionary):
            """Encode every line of ``txt`` without growing the dictionary."""
            return [
                dictionary.encode_line(line, add_if_not_exist=False)
                for line in txt
            ]

        def assert_match(ids, ref_ids):
            """Assert both tensor lists are pairwise equal in shape and content."""
            # zip() silently truncates, so compare the lengths explicitly.
            self.assertEqual(len(ids), len(ref_ids))
            for toks, ref_toks in zip(ids, ref_ids):
                self.assertEqual(toks.size(), ref_toks.size())
                self.assertEqual(0, (toks != ref_toks).sum().item())

        ids = get_ids(d)
        assert_match(ids, ref_ids1)

        # check finalized dictionary
        d.finalize()
        finalized_ids = get_ids(d)
        assert_match(finalized_ids, ref_ids2)

        # write to disk and reload
        with tempfile.NamedTemporaryFile(mode="w") as tmp_dict:
            d.save(tmp_dict.name)
            d = Dictionary.load(tmp_dict.name)
            reload_ids = get_ids(d)
            assert_match(reload_ids, ref_ids2)
            assert_match(finalized_ids, reload_ids)

    def test_overwrite(self):
        """Entries flagged ``#fairseq:overwrite`` may redefine special symbols."""
        # for example, Camembert overwrites <unk>, <s> and </s>
        dict_file = io.StringIO(
            "<unk> 999 #fairseq:overwrite\n"
            "<s> 999 #fairseq:overwrite\n"
            "</s> 999 #fairseq:overwrite\n"
            ", 999\n"
            "▁de 999\n"
        )
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index("<pad>"), 1)
        # An unknown word still maps to index 3, not to the overwritten entry.
        self.assertEqual(d.index("foo"), 3)
        # The overwritten symbols take the next free indices, in file order.
        self.assertEqual(d.index("<unk>"), 4)
        self.assertEqual(d.index("<s>"), 5)
        self.assertEqual(d.index("</s>"), 6)
        self.assertEqual(d.index(","), 7)
        self.assertEqual(d.index("▁de"), 8)

    def test_no_overwrite(self):
        """Redefining special symbols without the overwrite flag must fail."""
        # same entries as the Camembert example, but without "#fairseq:overwrite"
        dict_file = io.StringIO(
            "<unk> 999\n"
            "<s> 999\n"
            "</s> 999\n"
            ", 999\n"
            "▁de 999\n"
        )
        d = Dictionary()
        with self.assertRaisesRegex(RuntimeError, "Duplicate"):
            d.add_from_file(dict_file)

    def test_space(self):
        """A single space must be accepted as a dictionary symbol."""
        # for example, character models treat space as a symbol
        # The first entry is the space symbol followed by the space delimiter,
        # hence two consecutive spaces before the count.
        dict_file = io.StringIO(
            "  999\n"
            "a 999\n"
            "b 999\n"
        )
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index(" "), 4)
        self.assertEqual(d.index("a"), 5)
        self.assertEqual(d.index("b"), 6)


if __name__ == "__main__":
    unittest.main()