# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import io
import os
import string
import tempfile
import unittest

import torch

from fairseq import tokenizer
from fairseq.data import Dictionary
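

# These tests cover Dictionary construction, frequency-based finalization,
# save/load round-trips, special-symbol overwriting, and multi-worker
# token counting.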
class TestDictionary(unittest.TestCase):
    def test_finalize(self):
        txt = [
            "A B C D",
            "B C D",
            "C D",
            "D",
        ]
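        # ids 0-3 are reserved for the specials (<s>, <pad>, </s>, <unk>), so
        # new symbols get ids 4+ in insertion order (A=4 ... D=7), and
        # encode_line appends eos (id 2) to every line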
        ref_ids1 = list(
            map(
                torch.IntTensor,
                [
                    [4, 5, 6, 7, 2],
                    [5, 6, 7, 2],
                    [6, 7, 2],
                    [7, 2],
                ],
            )
        )
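        # finalize() reorders symbols by descending frequency, so the most
        # frequent token D moves to id 4 and the least frequent A to id 7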
        ref_ids2 = list(
            map(
                torch.IntTensor,
                [
                    [7, 6, 5, 4, 2],
                    [6, 5, 4, 2],
                    [5, 4, 2],
                    [4, 2],
                ],
            )
        )

        # build dictionary
        d = Dictionary()
        for line in txt:
            d.encode_line(line, add_if_not_exist=True)

        def get_ids(dictionary):
            ids = []
            for line in txt:
                ids.append(dictionary.encode_line(line, add_if_not_exist=False))
            return ids

        def assertMatch(ids, ref_ids):
            for toks, ref_toks in zip(ids, ref_ids):
                self.assertEqual(toks.size(), ref_toks.size())
                self.assertEqual(0, (toks != ref_toks).sum().item())

        ids = get_ids(d)
        assertMatch(ids, ref_ids1)

        # check finalized dictionary
        d.finalize()
        finalized_ids = get_ids(d)
        assertMatch(finalized_ids, ref_ids2)

        # write to disk and reload
        with tempfile.NamedTemporaryFile(mode="w") as tmp_dict:
            d.save(tmp_dict.name)
            d = Dictionary.load(tmp_dict.name)
            reload_ids = get_ids(d)
            assertMatch(reload_ids, ref_ids2)
            assertMatch(finalized_ids, reload_ids)

    def test_overwrite(self):
        # for example, Camembert overwrites <unk>, <s> and </s>
        dict_file = io.StringIO(
            "<unk> 999 #fairseq:overwrite\n"
            "<s> 999 #fairseq:overwrite\n"
            "</s> 999 #fairseq:overwrite\n"
            ", 999\n"
            "▁de 999\n"
        )
        d = Dictionary()
        d.add_from_file(dict_file)
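        # specials occupy ids 0-3; "foo" is out-of-vocabulary, so index()
        # falls back to the unk id (3); the #fairseq:overwrite entries are
        # appended as fresh symbols, so <unk>, <s> and </s> now resolve to 4-6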
        self.assertEqual(d.index("<pad>"), 1)
        self.assertEqual(d.index("foo"), 3)
        self.assertEqual(d.index("<unk>"), 4)
        self.assertEqual(d.index("<s>"), 5)
        self.assertEqual(d.index("</s>"), 6)
        self.assertEqual(d.index(","), 7)
        self.assertEqual(d.index("▁de"), 8)

    def test_no_overwrite(self):
        # without the #fairseq:overwrite flag, re-adding existing special
        # symbols such as <unk>, <s> and </s> raises a duplicate-word error
        dict_file = io.StringIO(
            "<unk> 999\n" "<s> 999\n" "</s> 999\n" ", 999\n" "▁de 999\n"
        )
        d = Dictionary()
        with self.assertRaisesRegex(RuntimeError, "Duplicate"):
            d.add_from_file(dict_file)

    def test_space(self):
        # for example, character models treat space as a symbol
        dict_file = io.StringIO(" 999\n" "a 999\n" "b 999\n")
        d = Dictionary()
        d.add_from_file(dict_file)
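        # the space symbol takes the first free id (4) after the four
        # built-in specials (<s>, <pad>, </s>, <unk>)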
        self.assertEqual(d.index(" "), 4)
        self.assertEqual(d.index("a"), 5)
        self.assertEqual(d.index("b"), 6)

    def test_add_file_to_dict(self):
        counts = {}
        num_lines = 100
        per_line = 10
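        # write a dummy corpus where each ASCII letter appears per_line times
        # per line across num_lines lines; per_line grows by 5 for each new
        # letter, so every letter has a distinct expected count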
        with tempfile.TemporaryDirectory("test_sampling") as data_dir:
            filename = os.path.join(data_dir, "dummy.txt")
            with open(filename, "w", encoding="utf-8") as data:
                for c in string.ascii_letters:
                    line = f"{c} " * per_line
                    for _ in range(num_lines):
                        data.write(f"{line}\n")
                    counts[c] = per_line * num_lines
                    per_line += 5
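
            # count tokens with 10 worker processes; multi-worker counting
            # should produce the same totals as a single pass over the file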
            d = Dictionary()
            Dictionary.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, 10
            )
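            # padding_factor=8 pads the vocab size up to a multiple of 8 with
            # filler symbols; threshold=0 and nwords=-1 keep every real token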
            d.finalize(threshold=0, nwords=-1, padding_factor=8)

            for c in string.ascii_letters:
                count = d.get_count(d.index(c))
                self.assertEqual(
                    counts[c], count, f"{c} count is {count} but should be {counts[c]}"
                )


if __name__ == "__main__":
    unittest.main()