"""Streamlit demo for exploring how pretrained tokenizers split text.

The page lets the user pick a tokenizer in the sidebar and then either
tokenizes free text (showing token IDs, decoded tokens and the token count)
or de-tokenizes a list of integer IDs. An optional comparison mode shows two
inputs side by side and reports whether their token counts match.
"""
import html
import io
import time

import numpy as np
import pandas as pd
import streamlit as st
import torch


@st.cache(show_spinner=True, allow_output_mutation=True)
def load_model(model_name):
    """Load the pretrained tokenizer that matches ``model_name``.

    The tokenizer class is selected from the prefix of ``model_name``
    (``bert``, ``gpt2``, ``roberta`` or ``albert``) and imported lazily so
    only the class that is actually needed gets imported.

    Args:
        model_name: Name passed to the tokenizer's ``from_pretrained``.

    Returns:
        The tokenizer returned by ``from_pretrained(model_name)``.

    Raises:
        ValueError: If ``model_name`` starts with none of the supported
            prefixes. (Previously this path failed with an
            ``UnboundLocalError`` on the ``return`` statement.)
    """
    if model_name.startswith('bert'):
        from transformers import BertTokenizer
        return BertTokenizer.from_pretrained(model_name)
    if model_name.startswith('gpt2'):
        from transformers import GPT2Tokenizer
        return GPT2Tokenizer.from_pretrained(model_name)
    if model_name.startswith('roberta'):
        from transformers import RobertaTokenizer
        return RobertaTokenizer.from_pretrained(model_name)
    if model_name.startswith('albert'):
        from transformers import AlbertTokenizer
        return AlbertTokenizer.from_pretrained(model_name)
    raise ValueError(
        f"Unsupported tokenizer '{model_name}': expected a name starting "
        "with 'bert', 'gpt2', 'roberta' or 'albert'."
    )


def generate_markdown(text, color='black', font='Arial', size=20):
    """Wrap ``text`` in an HTML paragraph styled with the given appearance.

    The result is meant to be rendered with
    ``st.markdown(..., unsafe_allow_html=True)``, as every caller in this
    file does.

    Args:
        text: Content to display. It is HTML-escaped, so characters such as
            ``<`` and ``>`` inside decoded tokens are displayed literally
            instead of being interpreted as markup.
        color: CSS colour value for the text.
        font: CSS font family.
        size: Font size in pixels.

    Returns:
        An HTML string containing the escaped text.
    """
    escaped_text = html.escape(str(text))
    return (
        f"<p style='color:{color}; font-family:{font}; font-size:{size}px;'>"
        f"{escaped_text}</p>"
    )


def _decode_each(token_ids):
    """Decode every ID on its own so the token boundaries stay visible.

    Uses the module-level ``tokenizer`` that the script section assigns.
    """
    return [tokenizer.decode([token_id]) for token_id in token_ids]


def TokenizeText(sentence, tokenizer_name):
    """Tokenize ``sentence`` and render the IDs, tokens and token count.

    Uses the module-level ``tokenizer`` that the script section assigns.

    Args:
        sentence: Text to tokenize.
        tokenizer_name: Name of the selected tokenizer; only used to decide
            whether the first and last IDs are dropped.

    Returns:
        The number of tokens displayed, or ``None`` when ``sentence`` is
        empty and nothing was rendered.
    """
    if not sentence:
        return None

    token_ids = tokenizer(sentence)['input_ids']
    if not tokenizer_name.startswith('gpt2'):
        # For every tokenizer except GPT-2 the first and last IDs are dropped
        # (presumably the special tokens those tokenizers wrap around a
        # sequence -- confirm against the tokenizer configuration).
        token_ids = token_ids[1:-1]

    id_strings = [str(token_id) for token_id in token_ids]
    decoded_tokens = _decode_each(token_ids)
    num_tokens = len(decoded_tokens)

    st.markdown(generate_markdown(' '.join(id_strings), size=16),
                unsafe_allow_html=True)
    st.markdown(generate_markdown(' '.join(decoded_tokens), size=16),
                unsafe_allow_html=True)
    st.markdown(generate_markdown(f'{num_tokens} tokens'),
                unsafe_allow_html=True)
    return num_tokens


def DeTokenizeText(input_str):
    """Decode whitespace-separated integer token IDs and render the tokens.

    Uses the module-level ``tokenizer`` that the script section assigns.

    Args:
        input_str: Token IDs written as integers separated by whitespace.
            Any amount of whitespace between IDs is accepted.

    Returns:
        The number of tokens displayed, or ``None`` when the input is empty,
        contains only whitespace, or contains a non-integer entry (in which
        case an error message is shown instead of the tokens).
    """
    # split() without an argument tolerates repeated, leading and trailing
    # whitespace; split(' ') produced empty strings that int() rejected.
    fields = input_str.split()
    if not fields:
        return None

    try:
        token_ids = [int(field) for field in fields]
    except ValueError:
        st.error('Token IDs must be integers separated by spaces.')
        return None

    decoded_tokens = _decode_each(token_ids)
    st.markdown(generate_markdown(' '.join(decoded_tokens)),
                unsafe_allow_html=True)
    return len(decoded_tokens)


if __name__ == '__main__':
    # NOTE: this section intentionally stays at module level. TokenizeText and
    # DeTokenizeText read the global name `tokenizer` assigned below.

    # Config
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2

    # NOTE(review): both style snippets are effectively empty and the layout
    # values above are never interpolated -- it looks like the <style> markup
    # that belonged here was lost. Confirm and restore the intended CSS.
    define_margins = f""" """
    hide_table_row_index = """ """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

    # Title
    st.markdown(generate_markdown('Tokenizer Demo:', size=32),
                unsafe_allow_html=True)
    st.markdown(
        generate_markdown('quick and easy way to explore how tokenizers work',
                          size=24),
        unsafe_allow_html=True,
    )

    # Select and load the tokenizer
    tokenizer_name = st.sidebar.selectbox(
        'Choose the tokenizer from below',
        ('bert-base-uncased', 'bert-large-cased',
         'gpt2', 'gpt2-large',
         'roberta-base', 'roberta-large',
         'albert-base-v2', 'albert-xxlarge-v2'),
        index=7,
    )
    tokenizer = load_model(tokenizer_name)

    comparison_mode = st.sidebar.checkbox('Compare two texts')
    detokenize = st.sidebar.checkbox(
        'de-tokenize (make sure to type in integers separated by single spaces)'
    )

    if comparison_mode:
        token_counts = []
        for sent_id, sent_col in enumerate(st.columns(2), start=1):
            with sent_col:
                if detokenize:
                    sentence = st.text_input(f'Tokenized IDs {sent_id}')
                    token_counts.append(DeTokenizeText(sentence))
                else:
                    sentence = st.text_input(f'Text {sent_id}')
                    token_counts.append(TokenizeText(sentence, tokenizer_name))

        # Only report a result once both inputs actually produced a count;
        # empty or invalid inputs yield None and must not count as a match.
        if all(count is not None for count in token_counts):
            st.markdown(generate_markdown('Result: ', size=16),
                        unsafe_allow_html=True)
            if token_counts[0] == token_counts[1]:
                st.markdown(
                    generate_markdown('Matched! ', color='MediumAquamarine'),
                    unsafe_allow_html=True,
                )
            else:
                st.markdown(
                    generate_markdown('Not Matched... ', color='Salmon'),
                    unsafe_allow_html=True,
                )
    else:
        if detokenize:
            sentence = st.text_input('Tokenized IDs')
            num_tokens = DeTokenizeText(sentence)
        else:
            sentence = st.text_input('Text')
            num_tokens = TokenizeText(sentence, tokenizer_name)