File size: 5,398 Bytes
d10ecd7
 
6f9d07b
d10ecd7
 
 
79b95c3
9495a4f
 
d10ecd7
 
9495a4f
 
d10ecd7
 
9495a4f
d10ecd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9495a4f
79b95c3
 
d10ecd7
 
9495a4f
d10ecd7
 
 
9495a4f
d10ecd7
 
 
 
 
 
 
 
 
9495a4f
d10ecd7
 
 
 
 
9495a4f
d10ecd7
 
9495a4f
d10ecd7
 
9495a4f
 
 
 
 
b15345c
d10ecd7
 
 
 
 
 
 
 
 
 
9495a4f
d10ecd7
 
 
9495a4f
 
 
 
 
 
 
 
 
 
 
 
 
d10ecd7
9495a4f
 
d10ecd7
 
 
9495a4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d10ecd7
 
 
 
 
 
9495a4f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import json
import socket
import pandas as pd
from vocab import load_tokener
from utils.zh_util import iter_vocab
from utils.log_util import logger
from functools import lru_cache
from urllib.parse import urlparse, parse_qs


@lru_cache
def tokenize(text, tokenizer_type, color_num=5):
    """Tokenize ``text`` with the tokenizer named ``tokenizer_type``.

    Args:
        text: Input string to encode.
        tokenizer_type: Tokenizer name passed to ``load_tokener``.
        color_num: Number of highlight classes; the token at position ``idx``
            is labelled ``str(idx % color_num)``.

    Returns:
        A ``(highlight_update, table_df)`` pair:
        ``highlight_update`` is a ``gr.update`` whose value is a list of
        ``(decoded_text, label)`` tuples and whose label shows the token count;
        ``table_df`` is a DataFrame with one row per token and the columns
        ``TokenID``, ``Token``, ``Text`` and ``UTF8 Bytes``.

    Note:
        Results are memoized by ``lru_cache``; repeated calls with the same
        arguments return the same (shared) objects.
    """
    logger.info("param=" + json.dumps({"text": text, "tokenizer_type": tokenizer_type}, ensure_ascii=False))
    pos_tokens = []
    tokenizer = load_tokener(tokenizer_type)
    encoding = tokenizer.encode(text)

    table = []

    for idx, token_id in enumerate(encoding):
        # Decoding a single id of a special token or a partial multi-byte
        # character comes back as the replacement character "\ufffd" (�).
        decode_text = tokenizer.decode([token_id])
        pos_tokens.append((decode_text, str(idx % color_num)))

        # Raw vocabulary entry; depending on the tokenizer it is bytes or str.
        token = tokenizer.convert_ids_to_tokens([token_id])[0]
        if isinstance(token, bytes):
            try:
                token_str = token.decode("utf-8")
            except UnicodeDecodeError:
                # Byte-level token holding an incomplete UTF-8 sequence:
                # keep whatever decodes and record the failure.
                token_str = token.decode("utf-8", errors="ignore")
                logger.error("decode_error: " + json.dumps(
                    {"tokenizer_type": tokenizer_type, "token": str(token), "token_str": token_str},
                    ensure_ascii=False))
            token_bytes = token
        elif isinstance(token, str):
            token_str = token
            token_bytes = token_str.encode("utf-8")
        else:
            # Unexpected entry type (e.g. None): fall back to its string form
            # instead of aborting, so callers always get the two-value result.
            token_str = str(token)
            token_bytes = token_str.encode("utf-8")
            logger.error("unexpected_token_type: " + json.dumps(
                {"tokenizer_type": tokenizer_type, "token": str(token), "token_type": type(token).__name__},
                ensure_ascii=False))

        table.append(
            {"TokenID": token_id,
             # String form of the vocab entry. Per the original author's note,
             # some tokenizers (e.g. llama) show byte tokens like "<0xE7>" here.
             "Token": token_str,
             "Text": decode_text,
             # Raw bytes would be rendered as decoded text by the Gradio
             # frontend (b'\xe4\xb8\xad' shows as "中"), so show their repr.
             "UTF8 Bytes": str(token_bytes),
             }
        )

    table_df = pd.DataFrame(table)
    logger.info(f"Tokens={table[:2]}")

    return gr.update(value=pos_tokens, label=f"Tokens: {len(encoding)}"), table_df


@lru_cache
def tokenize_pair(text, tokenizer_type_1, tokenizer_type_2):
    """Tokenize the same ``text`` with two tokenizers for side-by-side display.

    The original note ties this to the ``input_text.change`` event; the
    event wiring lives elsewhere (confirm there).

    Returns:
        ``(pos_tokens_1, table_df_1, pos_tokens_2, table_df_2)``: the two
        ``tokenize`` results flattened into a single 4-tuple.
    """
    flattened = []
    for which_tokenizer in (tokenizer_type_1, tokenizer_type_2):
        highlight, frame = tokenize(text, which_tokenizer)
        flattened.extend((highlight, frame))
    return tuple(flattened)


def basic_count(tokenizer_type):
    """Return basic vocabulary statistics for ``tokenizer_type``.

    Returns:
        ``(vocab_size, "<single>/<multi>")``. The second element joins the
        ``中文单字`` and ``中文多字`` counts (Chinese single-character and
        multi-character tokens, judging by the key names) from the
        ``中文汉字数`` entry of ``iter_vocab``'s result.
    """
    tok = load_tokener(tokenizer_type)
    vocab_stats = iter_vocab(tok, tokenizer_type)
    size = tok.vocab_size
    zh_counts = vocab_stats["中文汉字数"]
    return size, "{}/{}".format(zh_counts["中文单字"], zh_counts["中文多字"])


@lru_cache
def get_overlap_token_size(tokenizer_type_1, tokenizer_type_2):
    """Count the vocabulary entries shared by two tokenizers.

    If one vocabulary is keyed by ``str`` and the other by ``bytes`` (judged
    from the first key of each), ``str`` keys are UTF-8 encoded so the two
    sets can be compared.

    Args:
        tokenizer_type_1: Name of the first tokenizer (passed to ``load_tokener``).
        tokenizer_type_2: Name of the second tokenizer.

    Returns:
        ``(overlap_token_size, overlap_token_size)``: the same count twice
        (presumably to feed two UI outputs; TODO confirm against the caller).
    """
    from itertools import islice

    tokenizer1 = load_tokener(tokenizer_type_1)
    tokenizer2 = load_tokener(tokenizer_type_2)

    vocab_set_1 = tokenizer1.get_vocab().keys()
    vocab_set_2 = tokenizer2.get_vocab().keys()

    # Peek at one key of each vocab to detect str-vs-bytes keys; the None
    # default keeps an empty vocabulary from raising StopIteration.
    token1 = next(iter(vocab_set_1), None)
    token2 = next(iter(vocab_set_2), None)
    if type(token1) is not type(token2):  # bytes vs str
        # Encode per token so a vocab with mixed key types cannot crash on
        # calling .encode on a bytes entry.
        if isinstance(token1, str):
            vocab_set_1 = {t.encode("utf-8") if isinstance(t, str) else t for t in vocab_set_1}
        if isinstance(token2, str):
            vocab_set_2 = {t.encode("utf-8") if isinstance(t, str) else t for t in vocab_set_2}

    overlap_tokens = vocab_set_1 & vocab_set_2
    overlap_token_size = len(overlap_tokens)
    # islice avoids materializing the whole overlap set just to log 10 items.
    logger.info(
        f"{overlap_token_size} OverlapTokens of {tokenizer_type_1} {tokenizer_type_2}: "
        f"{list(islice(overlap_tokens, 10))}")
    return overlap_token_size, overlap_token_size


default_user_input = """Replace this text in the input field to see how tokenization works
华为发布Mate60手机
ラグビーワールドカップ2023フランス"""
default_tokenizer_type_1 = "llama"
# default_tokenizer_type_2 = "internlm_chat_7b"
default_tokenizer_type_2 = "gpt_35_turbo"


def on_load(request: gr.Request):
    """Resolve the initial text and tokenizer pair when the page loads.

    Values come from the query string of the request's ``Referer`` header
    (``?tokenizer1=...&tokenizer2=...&text=...``). Any parameter that is
    missing, or a missing ``request``, falls back to the module defaults.

    Args:
        request: The incoming Gradio request; may be ``None``.

    Returns:
        ``(text, tokenizer_type_1, tokenizer_type_2)``.
    """
    text = default_user_input
    tokenizer_type_1 = default_tokenizer_type_1
    tokenizer_type_2 = default_tokenizer_type_2
    if request:
        # Guard in case the request carries no client info (client may be None).
        # NOTE: this is the direct peer address; behind a reverse proxy the
        # real client would be in X-Forwarded-For, which is not consulted here.
        client = getattr(request, "client", None)
        client_ip = getattr(client, "host", None)
        query_params = {}
        if "referer" in request.headers:
            parsed = parse_qs(urlparse(request.headers["referer"]).query)
            # parse_qs maps each key to a list; keep only the first value.
            query_params = {k: v[0] for k, v in parsed.items() if v}
        tokenizer_type_1 = query_params.get("tokenizer1", tokenizer_type_1)
        tokenizer_type_2 = query_params.get("tokenizer2", tokenizer_type_2)
        text = query_params.get("text", text)
        logger.info(f"client_ip: {client_ip}; params: {query_params}")
    return text, tokenizer_type_1, tokenizer_type_2


def test_coding():
    """Print the repr of the UTF-8 encoding of "中" (b'\\xe4\\xb8\\xad')."""
    zhong_utf8 = "中".encode("utf-8")
    print(zhong_utf8)


if __name__ == "__main__":
    print(get_overlap_token_size("gpt_35_turbo", "gpt_4"))
    # print(basic_count("internlm_chat_7b"))