import ast
import functools
import html
import random
import re
from collections import Counter

import anthropic
import gradio as gr
import plotly.graph_objs as go
import tiktoken
from transformers import AutoTokenizer

# Directory containing the local Hugging Face tokenizer folders; the folder for
# a model is model_path + model_name.
model_path = "models/"

# Available models
MODELS = ["Meta-Llama-3.1-8B", "gemma-2b", "gpt-3.5-turbo", "gpt-4", "gpt-4o"]
openai_models = ["gpt-3.5-turbo", "gpt-4", "gpt-4o"]
# Models whose tokenizers are loaded from model_path via transformers.
HF_MODELS = ["Meta-Llama-3.1-8B", "gemma-2b"]
# Supported by process_text/process_ids but not offered in MODELS (needs an API key).
CLAUDE_MODEL = "Claude-3-Sonnet"

# Token strings counted and colored as "space" tokens ("▁" is the SentencePiece space marker).
SPACE_TOKENS = ("▁", " ")

# Color palette visible on both light and dark themes
COLOR_PALETTE = [
    "#e6194B", "#3cb44b", "#ffe119", "#4363d8", "#f58231",
    "#911eb4", "#42d4f4", "#f032e6", "#bfef45", "#fabed4",
    "#469990", "#dcbeff", "#9A6324", "#fffac8", "#800000",
    "#aaffc3", "#808000", "#ffd8b1", "#000075", "#a9a9a9",
]

_WORD_RE = re.compile(r"\b\w+\b")
_SPECIAL_CHAR_RE = re.compile(r"[^\w\s]")
_NUMBER_RE = re.compile(r"\d+")


def create_vertical_histogram(data, title):
    """Build a Plotly bar chart from (label, count) pairs.

    Args:
        data: Sequence of (label, count) tuples, e.g. from Counter.most_common().
        title: Chart title.

    Returns:
        A plotly Figure; the chart is empty when ``data`` is empty.
    """
    labels, values = zip(*data) if data else ([], [])
    fig = go.Figure(go.Bar(x=labels, y=values))
    fig.update_layout(
        title=title,
        xaxis_title="Item",
        yaxis_title="Count",
        height=400,
        xaxis=dict(tickangle=-45),
    )
    return fig


def validate_input(input_type, input_value):
    """Check that ``input_value`` is well-formed for ``input_type``.

    Args:
        input_type: "Text" or "Token IDs".
        input_value: Raw value from the input textbox.

    Returns:
        ``(is_valid, error_message)``; ``error_message`` is "" when valid.
    """
    if input_type == "Text":
        if not isinstance(input_value, str):
            return False, "Input must be a string for Text input type."
    elif input_type == "Token IDs":
        try:
            # literal_eval only evaluates Python literals, so it is safe on user input
            # (unlike eval); these are the exceptions it documents for malformed input.
            token_ids = ast.literal_eval(input_value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return False, "Invalid Token IDs format. Please provide a valid list of integers."
        # bool is a subclass of int, so True/False must be rejected explicitly.
        if not isinstance(token_ids, list) or not all(
            isinstance(token_id, int) and not isinstance(token_id, bool)
            for token_id in token_ids
        ):
            return False, "Token IDs must be a list of integers."
    else:
        # Previously unknown types passed validation and crashed later in process_input.
        return False, f"Unsupported input type: {input_type}"
    return True, ""


@functools.lru_cache(maxsize=len(HF_MODELS))
def _load_hf_tokenizer(model_name):
    """Load the local Hugging Face tokenizer for model_name once and reuse it."""
    return AutoTokenizer.from_pretrained(model_path + model_name)


def _get_claude_tokenizer(api_key):
    """Return the Anthropic client tokenizer; raise ValueError if api_key is empty."""
    if not api_key:
        raise ValueError("API key is required for Claude models")
    client = anthropic.Anthropic(api_key=api_key)
    # NOTE(review): get_tokenizer() only exists in older anthropic SDK releases —
    # confirm it is available in the pinned SDK version.
    return client.get_tokenizer()


def process_text(text: str, model_name: str, api_key: str = None):
    """Tokenize ``text`` with the tokenizer for ``model_name``.

    Args:
        text: Text to tokenize.
        model_name: One of HF_MODELS, openai_models, or CLAUDE_MODEL.
        api_key: Anthropic API key; only used for CLAUDE_MODEL.

    Returns:
        ``(text, tokens, token_ids)``: the input text, per-token strings, and token IDs.

    Raises:
        ValueError: For an unsupported model or a missing Claude API key.
    """
    if model_name in HF_MODELS:
        tokenizer = _load_hf_tokenizer(model_name)
        token_ids = tokenizer.encode(text, add_special_tokens=True)
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
    elif model_name in openai_models:
        encoding = tiktoken.encoding_for_model(model_name=model_name)
        token_ids = encoding.encode(text)
        tokens = [encoding.decode([token_id]) for token_id in token_ids]
    elif model_name == CLAUDE_MODEL:
        tokenizer = _get_claude_tokenizer(api_key)
        token_ids = tokenizer.encode(text).ids
        tokens = [tokenizer.decode([token_id]) for token_id in token_ids]
    else:
        raise ValueError(f"Unsupported model: {model_name}")
    return text, tokens, token_ids


def process_ids(ids: str, model_name: str, api_key: str = None):
    """Decode a string holding a list of token IDs with the tokenizer for ``model_name``.

    Args:
        ids: String form of a Python list of ints, e.g. "[1, 2, 3]".
        model_name: One of HF_MODELS, openai_models, or CLAUDE_MODEL.
        api_key: Anthropic API key; only used for CLAUDE_MODEL.

    Returns:
        ``(text, tokens, token_ids)``: the decoded text, per-token strings, and parsed IDs.

    Raises:
        ValueError: For an unsupported model or a missing Claude API key
            (also raised by literal_eval for malformed ``ids``).
    """
    token_ids = ast.literal_eval(ids)
    if model_name in HF_MODELS:
        tokenizer = _load_hf_tokenizer(model_name)
        text = tokenizer.decode(token_ids)
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
    elif model_name in openai_models:  # was `==`, which never matched a list
        encoding = tiktoken.encoding_for_model(model_name=model_name)
        text = encoding.decode(token_ids)
        tokens = [encoding.decode([token_id]) for token_id in token_ids]
    elif model_name == CLAUDE_MODEL:
        tokenizer = _get_claude_tokenizer(api_key)
        text = tokenizer.decode(token_ids)
        tokens = [tokenizer.decode([token_id]) for token_id in token_ids]
    else:
        raise ValueError(f"Unsupported model: {model_name}")
    return text, tokens, token_ids


def get_token_color(token, token_colors):
    """Pick a display color for ``token``.

    Tokens of the form <...> are cyan, space tokens are green, and other
    non-alphanumeric tokens (including "") are magenta. Any remaining token gets
    a random palette color, memoized in ``token_colors`` so repeats share a color.

    Args:
        token: Token string.
        token_colors: Mutable dict of token -> color, updated in place.

    Returns:
        A "#rrggbb" color string.
    """
    if token.startswith('<') and token.endswith('>'):
        return "#42d4f4"  # Cyan for special tokens
    elif token in SPACE_TOKENS:
        return "#3cb44b"  # Green for space tokens
    elif not token.isalnum():
        return "#f032e6"  # Magenta for special characters
    if token not in token_colors:
        token_colors[token] = random.choice(COLOR_PALETTE)
    return token_colors[token]


def _contrast_text_color(hex_color):
    """Return black or white, whichever is more legible on a "#rrggbb" background."""
    red, green, blue = (int(hex_color[i:i + 2], 16) for i in (1, 3, 5))
    # ITU-R BT.601 luma approximation of perceived brightness.
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    return "#000000" if luma > 140 else "#ffffff"


def create_html_tokens(tokens):
    """Render ``tokens`` as HTML, each token in a span colored by get_token_color.

    Token text is HTML-escaped. ``white-space: pre`` keeps whitespace-only tokens visible.
    """
    token_colors = {}
    spans = []
    for token in tokens:
        color = get_token_color(token, token_colors)
        spans.append(
            f'<span style="background-color: {color}; color: {_contrast_text_color(color)}; '
            'padding: 1px 3px; margin: 1px; border-radius: 3px; '
            'display: inline-block; white-space: pre;">'
            f'{html.escape(token)}</span>'
        )
    return (
        '<div style="line-height: 1.8; font-family: monospace;">'
        + "".join(spans)
        + "</div>"
    )


def process_input(input_type, input_value, model_name, api_key=None):
    """Gradio handler: tokenize or decode the input and build the analysis outputs.

    Args:
        input_type: "Text" or "Token IDs".
        input_value: Text, or a string holding a list of token IDs.
        model_name: Model whose tokenizer to use.
        api_key: Anthropic API key; only needed for CLAUDE_MODEL. Defaults to None
            so the UI can call this with three inputs.

    Returns:
        ``(analysis, text, html_tokens, token_ids_str, words_hist,
        special_chars_hist, numbers_hist)``.

    Raises:
        gr.Error: For invalid input, an unsupported model, or tokenizer ValueErrors.
    """
    is_valid, error_message = validate_input(input_type, input_value)
    if not is_valid:
        raise gr.Error(error_message)

    try:
        if input_type == "Text":
            text, tokens, token_ids = process_text(
                text=input_value, model_name=model_name, api_key=api_key
            )
        else:  # "Token IDs"; validate_input rejects every other type
            text, tokens, token_ids = process_ids(
                ids=input_value, model_name=model_name, api_key=api_key
            )
    except ValueError as err:
        # Show tokenizer/model errors in the UI instead of a raw traceback.
        raise gr.Error(str(err)) from err

    space_count = sum(1 for token in tokens if token in SPACE_TOKENS)
    special_char_count = sum(
        1 for token in tokens if not token.isalnum() and token not in SPACE_TOKENS
    )

    words_hist = create_vertical_histogram(
        Counter(_WORD_RE.findall(text.lower())).most_common(10), "Most Common Words"
    )
    special_chars_hist = create_vertical_histogram(
        Counter(_SPECIAL_CHAR_RE.findall(text)).most_common(10),
        "Most Common Special Characters",
    )
    numbers_hist = create_vertical_histogram(
        Counter(_NUMBER_RE.findall(text)).most_common(10), "Most Common Numbers"
    )

    analysis = "\n".join([
        f"Token count: {len(tokens)}",
        f"Character count: {len(text)}",
        f"Word count: {len(text.split())}",
        f"Space tokens: {space_count}",
        f"Special character tokens: {special_char_count}",
        f"Other tokens: {len(tokens) - space_count - special_char_count}",
    ])

    html_tokens = create_html_tokens(tokens)
    return analysis, text, html_tokens, str(token_ids), words_hist, special_chars_hist, numbers_hist


def text_example():
    """Return sample text for the "Text" input mode."""
    return "Hello, world! This is an example text input for tokenization."
def token_ids_example():
    """Return a sample token-ID list, as a string, for the "Token IDs" input mode.

    NOTE(review): 128000 looks like the Llama 3 begin-of-text id, so these IDs
    presumably target Meta-Llama-3.1-8B — confirm.
    """
    return "[128000, 9906, 11, 1917, 0, 1115, 374, 459, 3187, 1495, 1988, 369, 4037, 2065, 13]"


with gr.Blocks() as iface:
    gr.Markdown("# LLM Tokenization - Convert Text to tokens and vice versa!")
    gr.Markdown("Enter text or token IDs and select a model to see the results, including word count, token analysis, and histograms of most common elements.")

    with gr.Row():
        input_type = gr.Radio(["Text", "Token IDs"], label="Input Type", value="Text")
        model_name = gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0])

    # process_input takes an API key as its fourth argument, but no model in MODELS
    # needs one. A hidden State feeds it None, so the click handler always gets all
    # four arguments (before this, every Process click raised TypeError).
    api_key = gr.State(None)

    input_text = gr.Textbox(lines=5, label="Input")

    with gr.Row():
        text_example_button = gr.Button("Load Text Example")
        token_ids_example_button = gr.Button("Load Token IDs Example")

    submit_button = gr.Button("Process")

    analysis_output = gr.Textbox(label="Analysis", lines=6)
    text_output = gr.Textbox(label="Text", lines=6)
    tokens_output = gr.HTML(label="Tokens")
    token_ids_output = gr.Textbox(label="Token IDs", lines=2)

    with gr.Row():
        words_plot = gr.Plot(label="Most Common Words")
        special_chars_plot = gr.Plot(label="Most Common Special Characters")
        numbers_plot = gr.Plot(label="Most Common Numbers")

    # Example buttons fill the input box and switch the input type to match.
    text_example_button.click(
        lambda: (text_example(), "Text"),
        outputs=[input_text, input_type],
    )
    token_ids_example_button.click(
        lambda: (token_ids_example(), "Token IDs"),
        outputs=[input_text, input_type],
    )
    submit_button.click(
        process_input,
        inputs=[input_type, input_text, model_name, api_key],
        outputs=[
            analysis_output,
            text_output,
            tokens_output,
            token_ids_output,
            words_plot,
            special_chars_plot,
            numbers_plot,
        ],
    )

if __name__ == "__main__":
    iface.launch()