# NOTE(review): `spaces` is kept as the first import, as in the original; on
# ZeroGPU it presumably has to load before torch/transformers touch CUDA —
# confirm before reordering.
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import gradio as gr
import logging
from huggingface_hub import login
import os
import traceback
from threading import Thread
from random import shuffle

logging.basicConfig(level=logging.DEBUG)

SPACER = '\n' + '*' * 40 + '\n'


def _log_error(e):
    """Log exception ``e`` together with the active traceback, framed by SPACER."""
    logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')


HF_TOKEN = os.environ.get("HF_TOKEN", None)
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    # login(token=None) falls back to an interactive prompt, which nobody can
    # answer in a headless Space; skip it instead of blocking start-up.
    logging.warning('HF_TOKEN is not set; skipping Hugging Face Hub login.')

# System prompt offered for each language in the UI dropdown.
system_prompts = {
    "English": "You are a helpful chatbot that answers user input in a concise and witty way.",
    "German": "Du bist ein hilfreicher Chatbot, der Usereingaben knapp und originell beantwortet.",
    "French": "Tu es un chatbot utile qui répond aux questions des utilisateurs de manière concise et originale.",
    "Spanish": "Eres un chatbot servicial que responde a las entradas de los usuarios de forma concisa y original.",
}

# The two contestants. The order is shuffled once per process so that the
# "Model A"/"Model B" labels in the UI do not reveal which model is which.
model_info = [
    {"id": "NousResearch/Meta-Llama-3.1-8B-Instruct", "name": "Meta Llama 3.1 8B Instruct"},
    {"id": "mistralai/Mistral-7B-Instruct-v0.3", "name": "Mistral 7B Instruct v0.3"},
]
shuffle(model_info)
logging.debug('Models shuffled')

# Device the prompt token ids are moved to before generation.
device = "cuda"


def _load(model_id):
    """Load the tokenizer and the float16 causal LM for ``model_id``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        # NOTE(review): Llama and Mistral ship with transformers itself, so
        # executing remote repository code is probably unnecessary here.
        trust_remote_code=True,
    )
    return tokenizer, model


try:
    tokenizer_a, model_a = _load(model_info[0]['id'])
    tokenizer_b, model_b = _load(model_info[1]['id'])
except Exception as e:
    _log_error(e)
    # The arena cannot work without both models; fail fast here rather than
    # with a NameError on the first generation request.
    raise


def apply_chat_template(messages, add_generation_prompt=False):
    """Render ``messages`` with the Pharia header-token chat template.

    Not used by the arena below (each model's own tokenizer template is used).

    Args:
        messages: list of dicts with "role" and "content" keys. Roles other
            than "system", "user" and "assistant" are rendered without a header.
        add_generation_prompt: if True, extend the template with an assistant
            header so generation continues as the assistant.

    Returns:
        The rendered prompt string.
    """
    role_map = {
        "system": "<|start_header_id|>system<|end_header_id|>\n",
        "user": "<|start_header_id|>user<|end_header_id|>\n",
        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n",
    }
    pharia_template = "<|begin_of_text|>"
    for message in messages:
        pharia_template += role_map.get(message["role"], "") + message["content"] + "<|eot_id|>\n"
    if add_generation_prompt:
        pharia_template += role_map["assistant"]
    return pharia_template


def _history_to_messages(chatbot):
    """Convert Gradio ``[user, assistant]`` pairs into chat-template messages."""
    messages = []
    for user, assistant in chatbot:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    return messages


def _generation_kwargs(input_ids, streamer, tokenizer, max_new_tokens,
                       temperature, top_p, repetition_penalty):
    """Build ``model.generate`` kwargs from the UI generation settings.

    Slider values are coerced to int/float because some transformers logits
    processors reject non-float values. A temperature <= 0 selects greedy
    decoding, since sampling requires a strictly positive temperature.
    """
    temperature = float(temperature)
    kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=float(repetition_penalty),
    )
    if temperature > 0:
        kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
    else:
        kwargs.update(do_sample=False)
    return kwargs


def _run_generation(model, streamer, generation_kwargs):
    """Thread target: run ``model.generate`` and make sure ``streamer`` closes.

    If generation raises, the error is logged and ``streamer.end()`` is called
    so the consumer's ``next(streamer)`` does not block forever.
    """
    try:
        model.generate(**generation_kwargs)
    except Exception as e:
        _log_error(e)
        streamer.end()


def _next_chunk(streamer, tokenizer):
    """Fetch the next decoded chunk from ``streamer``.

    Returns:
        ``(text, finished)``. ``text`` is None when no new text arrived (the
        stream ended or failed). Any text from the tokenizer's EOS token onward
        is cut off, and that case also reports ``finished=True``.
    """
    try:
        text = next(streamer)
    except StopIteration:
        return None, True
    except Exception as e:
        # Stop polling a broken stream instead of logging the same error forever.
        _log_error(e)
        return None, True
    eos = tokenizer.eos_token
    if eos and eos in text:
        return text[:text.find(eos)], True
    return text, False


@spaces.GPU()
def generate_both(system_prompt, input_text, chatbot_a, chatbot_b, max_new_tokens=2048, temperature=0.2, top_p=0.9, repetition_penalty=1.1):
    """Stream answers from both anonymous models to the same prompt.

    Args:
        system_prompt: optional system message; omitted when empty.
        input_text: the new user message.
        chatbot_a, chatbot_b: Gradio chat histories as ``[user, assistant]``
            pairs; a new pair is appended to each and filled while streaming.
        max_new_tokens, temperature, top_p, repetition_penalty: generation
            settings; temperature <= 0 means greedy decoding.

    Yields:
        ``(chatbot_a, chatbot_b)`` after every streamed chunk.

    Raises:
        gr.Error: if tokenisation or starting the generation threads fails.
    """
    try:
        # skip_special_tokens keeps end-of-turn markers out of the chat text.
        text_streamer_a = TextIteratorStreamer(tokenizer_a, skip_prompt=True, skip_special_tokens=True)
        text_streamer_b = TextIteratorStreamer(tokenizer_b, skip_prompt=True, skip_special_tokens=True)

        system_prompt_list = [{"role": "system", "content": system_prompt}] if system_prompt else []
        input_text_list = [{"role": "user", "content": input_text}]
        new_messages_a = system_prompt_list + _history_to_messages(chatbot_a) + input_text_list
        new_messages_b = system_prompt_list + _history_to_messages(chatbot_b) + input_text_list

        input_ids_a = tokenizer_a.apply_chat_template(
            new_messages_a, add_generation_prompt=True, return_tensors="pt"
        ).to(device)
        input_ids_b = tokenizer_b.apply_chat_template(
            new_messages_b, add_generation_prompt=True, return_tensors="pt"
        ).to(device)
        logging.debug(f'model_a.device: {model_a.device}, model_b.device: {model_b.device}')

        generation_kwargs_a = _generation_kwargs(
            input_ids_a, text_streamer_a, tokenizer_a,
            max_new_tokens, temperature, top_p, repetition_penalty,
        )
        generation_kwargs_b = _generation_kwargs(
            input_ids_b, text_streamer_b, tokenizer_b,
            max_new_tokens, temperature, top_p, repetition_penalty,
        )

        thread_a = Thread(target=_run_generation, args=(model_a, text_streamer_a, generation_kwargs_a))
        thread_b = Thread(target=_run_generation, args=(model_b, text_streamer_b, generation_kwargs_b))
        thread_a.start()
        thread_b.start()
    except Exception as e:
        _log_error(e)
        # Previously execution fell through to the loop below with undefined state.
        raise gr.Error("Generation could not be started, please try again.")

    chatbot_a.append([input_text, ""])
    chatbot_b.append([input_text, ""])
    finished_a = False
    finished_b = False

    # Poll the two streams alternately so both chat windows fill in parallel.
    while not (finished_a and finished_b):
        if not finished_a:
            text_a, finished_a = _next_chunk(text_streamer_a, tokenizer_a)
            if text_a is not None:
                chatbot_a[-1][-1] += text_a
                yield chatbot_a, chatbot_b
        if not finished_b:
            text_b, finished_b = _next_chunk(text_streamer_b, tokenizer_b)
            if text_b is not None:
                chatbot_b[-1][-1] += text_b
                yield chatbot_a, chatbot_b

    # Do not let GPU work outlive the @spaces.GPU-decorated call.
    thread_a.join()
    thread_b.join()
    return chatbot_a, chatbot_b


def clear():
    """Reset both chat windows."""
    return [], []


def reveal_bot(selection, chatbot_a, chatbot_b):
    """Append a reaction to each chat that reveals which model produced it.

    Args:
        selection: the rating chosen in the "Rate the output!" radio.
        chatbot_a, chatbot_b: chat histories; each gets one new pair.

    Returns:
        The updated ``(chatbot_a, chatbot_b)``.
    """
    if selection == "Bot A kicks ass!":
        chatbot_a.append(["🏆", f"Thanks, man. I am {model_info[0]['name']}"])
        chatbot_b.append(["💩", f"Pffff … I am {model_info[1]['name']}"])
    elif selection == "Bot B crushes it!":
        chatbot_a.append(["🤡", f"Rigged … I am {model_info[0]['name']}"])
        chatbot_b.append(["🥇", f"Well deserved! I am {model_info[1]['name']}"])
    else:
        chatbot_a.append(["🤝", f"Lame … I am {model_info[0]['name']}"])
        chatbot_b.append(["🤝", f"Dunno. I am {model_info[1]['name']}"])
    return chatbot_a, chatbot_b


arena_notes = """## Important Notes:
- Sometimes an error may occur when generating the response, in this case, please try again.
"""

with gr.Blocks() as demo:
    try:
        with gr.Column():
            # NOTE(review): the original header markup was lost when the file
            # was flattened; reconstructed as a centred heading.
            gr.HTML("<center><h1>🤖le Royale</h1></center>")
            gr.Markdown(arena_notes)

            with gr.Row(variant="panel"):
                with gr.Column(scale=1):
                    language_dropdown = gr.Dropdown(
                        choices=["English", "German", "French", "Spanish"],
                        label="Select Language for System Prompt",
                        value="English",
                    )
                with gr.Column():
                    system_prompt = gr.Textbox(
                        lines=1,
                        label="System Prompt",
                        value=system_prompts["English"],
                        show_copy_button=True,
                    )

            with gr.Row(variant="panel"):
                with gr.Column(scale=1):
                    submit_btn = gr.Button(value="Generate", variant="primary")
                    clear_btn = gr.Button(value="Clear", variant="secondary")
                input_text = gr.Textbox(
                    lines=1,
                    label="Prompt",
                    value="Write a Nike style ad headline about the shame of being second best.",
                    scale=3,
                    show_copy_button=True,
                )

            with gr.Row(variant="panel"):
                with gr.Column():
                    chatbot_a = gr.Chatbot(label="Model A", show_copy_button=True, height=500)
                with gr.Column():
                    chatbot_b = gr.Chatbot(label="Model B", show_copy_button=True, height=500)

            with gr.Row(variant="panel"):
                better_bot = gr.Radio(
                    ["Bot A kicks ass!", "Bot B crushes it!", "It's a draw."],
                    label="Rate the output!",
                )

            with gr.Accordion(label="Generation Configurations", open=False):
                max_new_tokens = gr.Slider(minimum=128, maximum=4096, value=2048, label="Max New Tokens", step=128)
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature", step=0.01)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, label="Top-p", step=0.01)
                repetition_penalty = gr.Slider(minimum=0.1, maximum=2.0, value=1.1, label="Repetition Penalty", step=0.1)

        language_dropdown.change(
            lambda lang: system_prompts[lang],
            inputs=[language_dropdown],
            outputs=[system_prompt],
        )
        better_bot.select(reveal_bot, inputs=[better_bot, chatbot_a, chatbot_b], outputs=[chatbot_a, chatbot_b])

        # Enter in the prompt box and the Generate button trigger the same run.
        generate_inputs = [system_prompt, input_text, chatbot_a, chatbot_b,
                           max_new_tokens, temperature, top_p, repetition_penalty]
        input_text.submit(generate_both, inputs=generate_inputs, outputs=[chatbot_a, chatbot_b])
        submit_btn.click(generate_both, inputs=generate_inputs, outputs=[chatbot_a, chatbot_b])
        clear_btn.click(clear, outputs=[chatbot_a, chatbot_b])
    except Exception as e:
        _log_error(e)
        # A half-built UI is not usable; surface the failure instead of launching it.
        raise

if __name__ == "__main__":
    demo.queue().launch()