import json
from tokenizers import Tokenizer
from data_sample.oov_base import jd_vocab_tokens
from zhon.hanzi import punctuation as zh_punc

def load_base_tokenizer(tokenizer_path):
    """Load a tokenizer json file; return both the raw json data and the Tokenizer."""
    print("loading", tokenizer_path)
    with open(tokenizer_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    tokenizer = Tokenizer.from_file(tokenizer_path)
    print("vocab_size with added_tokens:", tokenizer.get_vocab_size(with_added_tokens=True))
    return data, tokenizer


def append_token(word_list, base_tokenizer, unused_ids=None):
    """
    Append new tokens to the end of the vocab.
    """
    new_vocab = set()
    new_merges = set()

    data, base_tokenizer = base_tokenizer
    vocab = data["model"]["vocab"]
    merges = data["model"]["merges"]
    # Next free id = current vocab size, counting added tokens.
    vocab_size = base_tokenizer.get_vocab_size(with_added_tokens=True)

    for word in word_list:
        encoding = base_tokenizer.encode(word)
        # Already a single token: nothing to do.
        if len(encoding.ids) == 1:
            continue

        if len(encoding.ids) >= 4:
            print("[ERROR]: encoding must have fewer than 4 tokens", word, encoding)

        # Inspect the byte-level pieces the word currently splits into.
        tokens = [base_tokenizer.id_to_token(token_id) for token_id in encoding.ids]
        if "\u00e6\u00a5\u0143" in tokens:
            print(word)

# Candidate words: every multi-character line in oov.add.txt.
with open("oov.add.txt", "r", encoding="utf-8") as f:
    add_tokens = [line.strip() for line in f]
add_words = [token for token in add_tokens if len(token) > 1]
new_tokenizer = load_base_tokenizer("20B_tokenizer.1.json")

append_token(add_words, new_tokenizer)
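

# A minimal sketch of one possible follow-up step: actually placing the new
# words at the end of the vocab and writing the edited json back out. The
# function name `append_and_save` and its output path are hypothetical; the
# sketch assumes the raw `data` dict returned by load_base_tokenizer is the
# thing to edit and dump.
def append_and_save(word_list, base_tokenizer, out_path):
    data, tokenizer = base_tokenizer
    vocab = data["model"]["vocab"]
    # Start ids right after the existing vocab, including added tokens.
    next_id = tokenizer.get_vocab_size(with_added_tokens=True)
    for word in word_list:
        if word in vocab:
            continue
        vocab[word] = next_id  # new token goes at the end of the vocab
        next_id += 1
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    # Reload to check that the edited file still parses as a tokenizer.
    return Tokenizer.from_file(out_path)

# Example usage (output filename is hypothetical):
# append_and_save(add_words, new_tokenizer, "20B_tokenizer.appended.json")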