# tokenizer-arena / playground_app.py
# Last commit: 6ef6bf4 ("update") by xu-song; file size 10.1 kB.
# coding=utf-8
# author: xusong
# time: 2022/8/23 16:06
import gradio as gr
from vocab import tokenizer_factory
from playground_examples import example_types, example_fn
from playground_util import (tokenize,
tokenize_pair, basic_count,
get_overlap_token_size, on_load)
# Browser-side snippet handed to ``demo.load(js=...)``. It builds an object
# from the page's URL query string (``window.location.search``) and returns it
# serialised with ``JSON.stringify``. The incoming ``url_params`` argument is
# overwritten before it is ever read, so whatever value Gradio passes in is
# ignored.
get_window_url_params = (
    "\n"
    "function(url_params) {\n"
    "const params = new URLSearchParams(window.location.search);\n"
    "url_params = JSON.stringify(Object.fromEntries(params));\n"
    "return url_params;\n"
    "}\n"
)
# Dropdown choices shared by both tokenizer selectors. Each entry is a
# (name_display, name_or_path) tuple; Gradio shows the first element and
# passes the second to event handlers.
all_tokenizer_name = [
    (tokenizer_config.name_display, tokenizer_config.name_or_path)
    for tokenizer_config in tokenizer_factory.all_tokenizer_configs
]
# Playground UI: an input text box (with an examples picker), two tokenizer
# panels side by side, their highlighted tokenizations and tables, and the
# event listeners that keep them in sync.
# NOTE(review): gr.Blocks() is created without css=..., so the elem_classes used
# below ("example-style", "statistics", "space-show") must be styled elsewhere
# or they have no effect — confirm.
with gr.Blocks() as demo:
    # links: https://www.coderstool.com/utf8-encoding-decoding
    # Purpose: enter some text and tokenize it.
    # Tokenizers: there are several common kinds of tokenizers,
    # Motivation: make it easy to tokenize text, inspect token granularity,
    # and compare tokenizers.
    with gr.Row():
        gr.Markdown("## Input Text")
        # type="index": example_fn receives the position of the selected entry
        # in example_types, not its text. "Examples" is presumably a placeholder
        # label outside example_types (allow_custom_value=True permits that) —
        # confirm against playground_examples.
        dropdown_examples = gr.Dropdown(
            example_types,
            value="Examples",
            type="index",
            allow_custom_value=True,
            show_label=False,
            container=False,
            scale=0,
            elem_classes="example-style"
        )
    user_input = gr.Textbox(
        # value=default_user_input,
        label="Input Text",
        lines=5,
        show_label=False,
    )
    gr.Markdown("## Tokenization")
    # compress rate setting TODO: move this section further down
    # with gr.Accordion("Compress Rate Setting", open=True):
    #     gr.Markdown(
    #         "Please select corpus and unit of compress rate, get more details at [github](https://github.com/xu-song/tokenizer-arena/). ")
    #     with gr.Row():
    #         compress_rate_corpus = gr.CheckboxGroup(
    #             common_corpuses,  # , "code"
    #             value=["cc100-en", "cc100-zh-Hans"],
    #             label="corpus",
    #             # info=""
    #         )
    #         compress_rate_unit = gr.Radio(
    #             common_units,
    #             value="b_tokens/g_bytes",
    #             label="unit",
    #         )
    # TODO: Token Setting
    # with gr.Accordion("Token Filter Setting", open=False):
    #     gr.Markdown(
    #         "Get total number of tokens which contain the following character)")
    #     gr.Radio(
    #         ["zh-Hans", "", "number", "space"],
    #         value="zh",
    #     )

    # Side-by-side comparison: tokenizer 1 panel | "VS" icon | tokenizer 2 panel.
    with gr.Row():
        with gr.Column(scale=6):
            with gr.Group():
                # Choices are the (display name, name_or_path) tuples in
                # all_tokenizer_name.
                tokenizer_name_1 = gr.Dropdown(
                    all_tokenizer_name,
                    label="Tokenizer 1",
                    # value=default_tokenizer_name_1,
                )
            with gr.Group():
                with gr.Row():
                    # organization_1 / stats_vocab_size_1 are filled by
                    # basic_count (see the listeners below).
                    organization_1 = gr.TextArea(
                        label="Organization",
                        lines=1,
                        elem_classes="statistics",
                    )
                    stats_vocab_size_1 = gr.TextArea(
                        label="Vocab Size",
                        lines=1,
                        elem_classes="statistics"
                    )
                    # stats_zh_token_size_1 = gr.TextArea(
                    #     label="ZH char/word",
                    #     lines=1,
                    #     elem_classes="statistics",
                    # )
                    # stats_compress_rate_1 = gr.TextArea(
                    #     label="Compress Rate",
                    #     lines=1,
                    #     elem_classes="statistics",
                    # )
                    # Filled by get_overlap_token_size together with
                    # stats_overlap_token_size_2.
                    stats_overlap_token_size_1 = gr.TextArea(
                        # value=default_stats_overlap_token_size,
                        label="Overlap Tokens",
                        lines=1,
                        elem_classes="statistics"
                    )
                    # stats_3 = gr.TextArea(
                    #     label="Compress Rate",
                    #     lines=1,
                    #     elem_classes="statistics"
                    # )
        # "VS" separator icon, source: https://www.onlinewebfonts.com/icon/418591
        # The path is relative — presumably resolved against the process's
        # working directory, so launch from the repo root (confirm).
        gr.Image("images/VS.svg", scale=1, show_label=False,
                 show_download_button=False, container=False,
                 show_share_button=False)
        with gr.Column(scale=6):
            with gr.Group():
                tokenizer_name_2 = gr.Dropdown(
                    all_tokenizer_name,
                    label="Tokenizer 2",
                    # value=default_tokenizer_name_2
                )
            with gr.Group():
                with gr.Row():
                    organization_2 = gr.TextArea(
                        label="Organization",
                        lines=1,
                        elem_classes="statistics",
                    )
                    stats_vocab_size_2 = gr.TextArea(
                        label="Vocab Size",
                        lines=1,
                        elem_classes="statistics"
                    )
                    # stats_zh_token_size_2 = gr.TextArea(
                    #     label="ZH char/word",  # Chinese characters/words
                    #     lines=1,
                    #     elem_classes="statistics",
                    # )
                    # stats_compress_rate_2 = gr.TextArea(
                    #     label="Compress Rate",
                    #     lines=1,
                    #     elem_classes="statistics"
                    # )
                    # Hidden (visible=False) and not an input or output of any
                    # listener in this file.
                    stats_filtered_token_2 = gr.TextArea(
                        label="filtered tokens",
                        lines=1,
                        elem_classes="statistics",
                        visible=False
                    )
                    stats_overlap_token_size_2 = gr.TextArea(
                        label="Overlap Tokens",
                        lines=1,
                        elem_classes="statistics"
                    )
    # TODO: charts / tables for the compression rate
    # Tokenization results: highlighted tokens for tokenizer 1 (left) and
    # tokenizer 2 (right), then one table per tokenizer.
    with gr.Row():
        # dynamic change label
        with gr.Column():
            output_text_1 = gr.Highlightedtext(
                show_legend=False,
                show_inline_category=False,
                elem_classes="space-show"
            )
        with gr.Column():
            output_text_2 = gr.Highlightedtext(
                show_legend=False,
                show_inline_category=False,
                elem_classes="space-show"
            )
    with gr.Row():
        output_table_1 = gr.Dataframe()
        output_table_2 = gr.Dataframe()

    # ---- Event listeners ----
    # setting
    # compress_rate_unit.change(compress_rate_unit_change, [compress_rate_unit],
    #                           [stats_compress_rate_1, stats_compress_rate_2])

    # Tokenizer 1 changed: re-tokenize the input with it, refresh its vocab
    # size / organization, and recompute the overlap statistics of both panels.
    # NOTE(review): unlike every other listener here, these three omit
    # show_api=False, so they stay listed on the app's API page — confirm that
    # is intentional.
    tokenizer_name_1.change(tokenize, [user_input, tokenizer_name_1],
                            [output_text_1, output_table_1])
    tokenizer_name_1.change(basic_count, [tokenizer_name_1], [stats_vocab_size_1, organization_1])
    tokenizer_name_1.change(get_overlap_token_size, [tokenizer_name_1, tokenizer_name_2],
                            [stats_overlap_token_size_1, stats_overlap_token_size_2])
    # tokenizer_type_1.change(get_compress_rate, [tokenizer_type_1, compress_rate_corpus, compress_rate_unit],
    #                         [stats_compress_rate_1])

    # TODO: every=3
    # Input text changed: re-tokenize with both tokenizers in a single call.
    user_input.change(tokenize_pair,
                      [user_input, tokenizer_name_1, tokenizer_name_2],
                      [output_text_1, output_table_1, output_text_2, output_table_2], show_api=False)  # , pass_request=1
    # Tokenizer 2 changed: mirror of the tokenizer 1 listeners above.
    tokenizer_name_2.change(tokenize, [user_input, tokenizer_name_2],
                            [output_text_2, output_table_2], show_api=False)
    tokenizer_name_2.change(basic_count, [tokenizer_name_2], [stats_vocab_size_2, organization_2], show_api=False)
    tokenizer_name_2.change(get_overlap_token_size, [tokenizer_name_1, tokenizer_name_2],
                            [stats_overlap_token_size_1, stats_overlap_token_size_2], show_api=False)
    # tokenizer_type_2.change(get_compress_rate,
    #                         [tokenizer_type_2, compress_rate_corpus, compress_rate_unit],
    #                         [stats_compress_rate_2])
    #
    # compress_rate_unit.change(get_compress_rate,
    #                           [tokenizer_type_1, compress_rate_corpus, compress_rate_unit],
    #                           [stats_compress_rate_1])
    # compress_rate_unit.change(get_compress_rate,
    #                           [tokenizer_type_2, compress_rate_corpus, compress_rate_unit],
    #                           [stats_compress_rate_2])
    # compress_rate_corpus.change(get_compress_rate,
    #                             [tokenizer_type_1, compress_rate_corpus, compress_rate_unit],
    #                             [stats_compress_rate_1])
    # compress_rate_corpus.change(get_compress_rate,
    #                             [tokenizer_type_2, compress_rate_corpus, compress_rate_unit],
    #                             [stats_compress_rate_2])

    # Example picked: example_fn receives the example's index (type="index")
    # and sets the input text and both tokenizer selections. Gradio fires
    # .change on programmatic updates too, so the listeners above then re-run.
    dropdown_examples.change(
        example_fn,
        dropdown_examples,
        [user_input, tokenizer_name_1, tokenizer_name_2],
        show_api=False
    )
    # Page load: the JS snippet ignores its argument and returns the URL query
    # parameters as a JSON string. Presumably Gradio forwards that string to
    # on_load in place of user_input's value, and on_load uses it to initialise
    # the input text and both tokenizer selections — confirm against
    # playground_util.on_load.
    demo.load(
        fn=on_load,
        inputs=[user_input],  # only a placeholder (empty) object needs to be passed here.
        outputs=[user_input, tokenizer_name_1, tokenizer_name_2],
        js=get_window_url_params,
        show_api=False
    )
if __name__ == "__main__":
    # Start the Gradio server with default launch settings. Variants used
    # during development: demo.queue(max_size=20).launch() and
    # demo.launch(share=True).
    launch_options = {}
    demo.launch(**launch_options)