# Source: Hugging Face Space (status when captured: sleeping).
import spaces | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
import torch | |
import gradio as gr | |
import logging | |
from huggingface_hub import login | |
import os | |
import traceback | |
from threading import Thread | |
logging.basicConfig(level=logging.DEBUG)

# Visual separator that makes error entries easy to spot in the log output.
SPACER = '\n' + '*' * 40 + '\n'

# Hub access token, taken from the environment (None when the variable is unset).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    # login() without a token falls back to an interactive prompt, which
    # cannot be answered in a headless deployment, so skip it instead.
    logging.warning("HF_TOKEN is not set; skipping Hugging Face Hub login.")

# Hub repository id and display name of the two models being compared.
model_a_info = {"id": "NousResearch/Meta-Llama-3.1-8B-Instruct",
                "name": "Meta Llama 3.1 8B Instruct"}
model_b_info = {"id": "mistralai/Mistral-7B-Instruct-v0.3",
                "name": "Mistral 7B Instruct v0.3"}

device = "cuda"

# Load both tokenizer/model pairs in half precision. A failure is logged and
# swallowed (best effort), which leaves the tokenizer_*/model_* names that
# were not reached undefined for the rest of the module.
try:
    tokenizer_a = AutoTokenizer.from_pretrained(model_a_info['id'])
    model_a = AutoModelForCausalLM.from_pretrained(
        model_a_info['id'],
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer_b = AutoTokenizer.from_pretrained(model_b_info['id'])
    model_b = AutoModelForCausalLM.from_pretrained(
        model_b_info['id'],
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    # Only model B has its weights re-tied after loading.
    model_b.tie_weights()
except Exception as e:
    logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
def apply_chat_template(messages, add_generation_prompt=False):
    """
    Render chat messages into one prompt string using header/eot markers.

    messages: iterable of dicts, each with a 'role' and a 'content' entry.
        Roles other than system/user/assistant get no header line; their
        content is still emitted, followed by the end-of-turn marker.
    add_generation_prompt: when true, finish with an open assistant header
        so that a model continues with the assistant's reply.

    Returns the assembled prompt, which always starts with <|begin_of_text|>.
    """
    header_for = {
        "system": "<|start_header_id|>system<|end_header_id|>\n",
        "user": "<|start_header_id|>user<|end_header_id|>\n",
        "assistant": "<|start_header_id|>assistant<|end_header_id|>\n",
    }

    # Collect the fragments first and join them once at the end.
    pieces = ["<|begin_of_text|>"]
    for entry in messages:
        speaker = entry["role"]
        body = entry["content"]
        pieces.append(header_for.get(speaker, "") + body + "<|eot_id|>\n")

    if add_generation_prompt:
        pieces.append("<|start_header_id|>assistant<|end_header_id|>\n")

    return "".join(pieces)
def _history_to_messages(history):
    """Convert Gradio ``[user, assistant]`` pairs into chat-template message dicts."""
    messages = []
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    return messages


def _next_chunk(streamer, eos_token):
    """Fetch the next piece of streamed text.

    Returns a ``(text, finished)`` pair. ``text`` is ``None`` when there is
    nothing to append (the stream is exhausted or reading it failed). When
    ``eos_token`` occurs in the chunk, the text is cut just before it and the
    stream is reported as finished.
    """
    try:
        text = next(streamer)
    except StopIteration:
        return None, True
    except Exception as e:
        # Give up on this stream instead of retrying forever on a broken one.
        logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
        return None, True
    if eos_token and eos_token in text:
        return text[:text.find(eos_token)], True
    return text, False


def generate_both(system_prompt, input_text, chatbot_a, chatbot_b, max_new_tokens=2048, temperature=0.2, top_p=0.9, repetition_penalty=1.1):
    """Stream answers from both models to the same prompt, side by side.

    Generator used as a Gradio event handler: each yield is the pair
    ``(chatbot_a, chatbot_b)`` with the newest turn extended by the text
    received so far.

    system_prompt: optional system message; omitted when empty.
    input_text: the new user message, sent to both models.
    chatbot_a / chatbot_b: chat histories as lists of ``[user, assistant]``
        pairs; the new turn is appended to them in place.
    max_new_tokens, temperature, top_p, repetition_penalty: forwarded to
        ``generate``. A temperature of 0 switches to greedy decoding.

    If preparing or starting the generation fails, the error is logged and
    the unchanged histories are yielded once.
    """
    chatbot_a = [] if chatbot_a is None else chatbot_a
    chatbot_b = [] if chatbot_b is None else chatbot_b

    try:
        text_streamer_a = TextIteratorStreamer(tokenizer_a, skip_prompt=True)
        text_streamer_b = TextIteratorStreamer(tokenizer_b, skip_prompt=True)

        system_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
        user_message = [{"role": "user", "content": input_text}]
        new_messages_a = system_messages + _history_to_messages(chatbot_a) + user_message
        new_messages_b = system_messages + _history_to_messages(chatbot_b) + user_message

        # Token ids are integers, so no float dtype is requested here; the
        # tensors are placed on whatever device each model reports.
        input_ids_a = tokenizer_a.apply_chat_template(
            new_messages_a,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model_a.device)
        input_ids_b = tokenizer_b.apply_chat_template(
            new_messages_b,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model_b.device)

        logging.debug(f'model_a.device: {model_a.device}, model_b.device: {model_b.device}')

        # Sampling needs a strictly positive temperature; at 0 fall back to
        # greedy decoding so the worker thread does not fail on validation.
        if temperature > 0:
            sampling_kwargs = dict(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            sampling_kwargs = dict(do_sample=False)

        generation_kwargs_a = dict(
            input_ids=input_ids_a,
            streamer=text_streamer_a,
            max_new_tokens=int(max_new_tokens),
            pad_token_id=tokenizer_a.eos_token_id,
            repetition_penalty=repetition_penalty,
            **sampling_kwargs,
        )
        generation_kwargs_b = dict(
            input_ids=input_ids_b,
            streamer=text_streamer_b,
            max_new_tokens=int(max_new_tokens),
            pad_token_id=tokenizer_b.eos_token_id,
            repetition_penalty=repetition_penalty,
            **sampling_kwargs,
        )

        # generate() blocks, so each model runs in its own thread while this
        # generator drains the two streamers.
        thread_a = Thread(target=model_a.generate, kwargs=generation_kwargs_a)
        thread_b = Thread(target=model_b.generate, kwargs=generation_kwargs_b)
        thread_a.start()
        thread_b.start()

        chatbot_a.append([input_text, ""])
        chatbot_b.append([input_text, ""])
    except Exception as e:
        logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
        # Nothing is streaming; hand back the histories as they are rather
        # than entering the loop below with its state undefined.
        yield chatbot_a, chatbot_b
        return chatbot_a, chatbot_b

    finished_a = False
    finished_b = False
    # Alternate between the two streams, one chunk each, until both are done.
    while not (finished_a and finished_b):
        if not finished_a:
            text_a, finished_a = _next_chunk(text_streamer_a, tokenizer_a.eos_token)
            if text_a is not None:
                chatbot_a[-1][-1] += text_a
                yield chatbot_a, chatbot_b
        if not finished_b:
            text_b, finished_b = _next_chunk(text_streamer_b, tokenizer_b.eos_token)
            if text_b is not None:
                chatbot_b[-1][-1] += text_b
                yield chatbot_a, chatbot_b
    return chatbot_a, chatbot_b
def clear():
    """Return a fresh, empty history for each of the two chat panes."""
    history_a, history_b = [], []
    return history_a, history_b
def reveal_bot(selection, chatbot_a, chatbot_b):
    """Append a closing turn to each chat in which the bot names its model.

    selection: the chosen rating label. The two "winner" labels each get
        their own pair of reactions; any other value is treated as a draw.
    chatbot_a / chatbot_b: chat histories, extended in place by one
        [user, assistant] pair each.

    Returns the two (mutated) histories as a tuple.
    """
    if selection == "Bot A kicks ass!":
        turn_a = ["π", f"Thanks, man. I am {model_a_info['name']}"]
        turn_b = ["π©", f"Pffff β¦ I am {model_b_info['name']}"]
    elif selection == "Bot B crushes it!":
        turn_a = ["π€‘", f"Rigged β¦ I am {model_a_info['name']}"]
        turn_b = ["π₯", f"Well deserved! I am {model_b_info['name']}"]
    else:
        turn_a = ["π€", f"Lame β¦ I am {model_a_info['name']}"]
        turn_b = ["π€", f"Dunno. I am {model_b_info['name']}"]

    chatbot_a.append(turn_a)
    chatbot_b.append(turn_b)
    return chatbot_a, chatbot_b
# Markdown shown under the page title.
arena_notes = """## Important Notes:
- Sometimes an error may occur when generating the response, in this case, please try again.
"""

# Build the page: title and notes, system prompt, prompt row with its buttons,
# the two chat panes, the rating radio and the generation settings.
# NOTE(review): the nesting of the layout containers below is reconstructed
# from the statement order -- confirm against the deployed app.
with gr.Blocks() as demo:
    try:
        with gr.Column():
            gr.HTML("<center><h1>π€le Royale</h1></center>")
            gr.Markdown(arena_notes)
            system_prompt = gr.Textbox(lines=1, label="System Prompt", value="You are a helpful chatbot that adheres to the prompted request.", show_copy_button=True)
            with gr.Row(variant="panel"):
                with gr.Column(scale=1):
                    submit_btn = gr.Button(value="Generate", variant="primary")
                    clear_btn = gr.Button(value="Clear", variant="secondary")
                input_text = gr.Textbox(lines=1, label="Prompt", value="Write a Nike style ad headline about the shame of being second best.", scale=3, show_copy_button=True)
            with gr.Row(variant="panel"):
                with gr.Column():
                    chatbot_a = gr.Chatbot(label="Model A", show_copy_button=True, height=500)
                with gr.Column():
                    chatbot_b = gr.Chatbot(label="Model B", show_copy_button=True, height=500)
            with gr.Row(variant="panel"):
                better_bot = gr.Radio(["Bot A kicks ass!", "Bot B crushes it!", "It's a draw."], label="Rate the output!")
            with gr.Accordion(label="Generation Configurations", open=False):
                max_new_tokens = gr.Slider(minimum=128, maximum=4096, value=2048, label="Max New Tokens", step=128)
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature", step=0.01)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, label="Top-p", step=0.01)
                repetition_penalty = gr.Slider(minimum=0.1, maximum=2.0, value=1.1, label="Repetition Penalty", step=0.1)
        # Picking a rating appends the "reveal" turn to both chat panes.
        better_bot.select(reveal_bot, inputs=[better_bot, chatbot_a, chatbot_b], outputs=[chatbot_a, chatbot_b])
        # Pressing Enter in the prompt box and clicking "Generate" both start
        # the same generation, with identical inputs and outputs.
        input_text.submit(generate_both, inputs=[system_prompt, input_text, chatbot_a, chatbot_b, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[chatbot_a, chatbot_b])
        submit_btn.click(generate_both, inputs=[system_prompt, input_text, chatbot_a, chatbot_b, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[chatbot_a, chatbot_b])
        # "Clear" empties both chat panes.
        clear_btn.click(clear, outputs=[chatbot_a, chatbot_b])
    except Exception as e:
        # Errors while building the page are logged, not raised.
        logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')

# Start the app only when the file is run as a script; queue() is enabled
# before launching.
if __name__ == "__main__":
    demo.queue().launch()