Spaces:
Runtime error
Runtime error
import gradio as gr | |
import requests | |
import os | |
import json | |
import random | |
from elo import update_elo_ratings # Custom function for ELO ratings | |
# Update that makes a button clickable again; user_ask returns it for both
# vote buttons. gr.update() is used rather than gr.Button.update() because the
# per-component .update() classmethods were removed in Gradio 4, whereas
# gr.update() is available on both 3.x and 4.x.
enable_btn = gr.update(interactive=True)

# Load chatbot URLs and model names from a JSON file.
# The file is read as a mapping of model name -> endpoint URL (see
# chat_with_bots, which looks URLs up by name).
with open('chatbot_urls.json', 'r', encoding='utf-8') as file:
    chatbots = json.load(file)
def clear_chat(state):
    """Reset the session so a new conversation can start.

    Args:
        state: The current per-session state. It is discarded; the parameter
            is kept because the reset button passes it as an input.

    Returns:
        A 5-tuple matching the reset button's outputs: a fresh empty state
        dict, ``None`` as the new value of each of the two chat panels, and
        two updates that disable the vote buttons.
    """
    # Always hand back a fresh dict. Returning the incoming value when it was
    # None would leave the session state as None and make the next user_ask
    # call fail on ``state[...] = ...``.
    return (
        {},
        None,
        None,
        gr.update(interactive=False),
        gr.update(interactive=False),
    )
# Get the ELO ratings stored in the user's session state
def get_user_elo_ratings(state):
    """Return the ELO ratings held under the ``'elo_ratings'`` key of ``state``.

    Raises:
        KeyError: if ``state`` has no ``'elo_ratings'`` entry.
    """
    session_ratings = state['elo_ratings']
    return session_ratings
# Read and write ELO ratings to file (NOTE: no locking is used, so concurrent read-modify-write cycles are not thread-safe)
def read_elo_ratings():
    """Load the shared ELO table from ``elo_ratings.json``.

    Every model configured in ``chatbots`` gets an entry: models that are
    missing from the file (or all of them, when the file does not exist yet)
    start at the default rating of 1200. Ratings stored in the file always
    take precedence over the default.

    Returns:
        dict mapping model name to its ELO rating.

    Raises:
        json.JSONDecodeError: if the ratings file exists but is not valid
            JSON (deliberately not swallowed, so a damaged file is never
            silently replaced by defaults on the next write).
    """
    # Seed with defaults first so that a model added to the configuration
    # after the ratings file was created still has a rating
    # (presumably update_elo_ratings needs both names present - confirm).
    ratings = dict.fromkeys(chatbots, 1200)
    try:
        with open('elo_ratings.json', 'r', encoding='utf-8') as file:
            ratings.update(json.load(file))
    except FileNotFoundError:
        # First run: nothing has been saved yet, the defaults stand.
        pass
    return ratings
def write_elo_ratings(elo_ratings):
    """Persist ``elo_ratings`` to ``elo_ratings.json`` as indented JSON.

    The data is first written to a temporary file in the same directory and
    then moved over the target with ``os.replace``. A reader therefore sees
    either the previous complete file or the new complete file, never a
    truncated one, and a failure while serialising leaves the existing file
    untouched.

    Args:
        elo_ratings: JSON-serialisable mapping of model name to rating.
    """
    import tempfile  # local import: only needed by this function

    target = 'elo_ratings.json'
    # The temp file must live on the same filesystem as the target for the
    # final rename to be a single atomic step.
    target_dir = os.path.dirname(os.path.abspath(target))
    fd, tmp_path = tempfile.mkstemp(
        prefix=target + '.', suffix='.tmp', dir=target_dir
    )
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as file:
            json.dump(elo_ratings, file, indent=4)
        os.replace(tmp_path, target)
    except BaseException:
        # Do not leave a stray temp file behind; the original error is what
        # the caller needs to see, so a failed cleanup is ignored.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
# Functions to build the Alpaca-style prompts and to get a bot response
_ALPACA_PREAMBLE = "Below is an instruction that describes a task. Write a response that appropriately completes the request."


def _render_alpaca_prompt(messages):
    """Render one list of ``{'role', 'content'}`` messages as a prompt string."""
    # NOTE(review): the preamble is joined directly to the first section
    # header with no separator. The usual Alpaca template puts a blank line
    # there - confirm against the format the served models expect before
    # changing it; the output is kept as-is here.
    parts = [_ALPACA_PREAMBLE]
    for message in messages:
        # Anything that is not a user turn is rendered as a model response.
        if message['role'] == 'user':
            header = "### Instruction:\n"
        else:
            header = "### Response:\n"
        parts.append(header + message['content'] + "\n\n")
    # Open a response section for the model to complete.
    parts.append("### Response:\n")
    return "".join(parts)


def format_alpaca_prompt(state):
    """Build the prompt for each of the two bots from its own history.

    Args:
        state: Session state whose ``"history"`` entry holds two message
            lists (index 0 and index 1), each a list of dicts with ``'role'``
            and ``'content'`` keys.

    Returns:
        A two-element list ``[prompt_for_history_0, prompt_for_history_1]``.
        Each prompt is the preamble, then every message under an
        ``### Instruction:`` (role ``'user'``) or ``### Response:`` (any other
        role) header, then a trailing open ``### Response:`` header.

    Raises:
        KeyError: if ``state`` has no ``"history"`` entry.
        IndexError: if ``state["history"]`` has fewer than two lists.
    """
    first_history = state["history"][0]
    second_history = state["history"][1]
    return [
        _render_alpaca_prompt(first_history),
        _render_alpaca_prompt(second_history),
    ]
def get_bot_response(url, prompt, state, bot_index, timeout=120):
    """Ask one bot endpoint to continue its conversation and return the reply.

    Args:
        url: Endpoint the JSON payload is POSTed to.
        prompt: Unused; the prompt is rebuilt from ``state`` instead. Kept so
            existing callers keep working.
        state: Session state; passed to ``format_alpaca_prompt`` to build the
            prompts for both bots.
        bot_index: Which of the two generated prompts to send (0 or 1).
        timeout: Seconds to wait for the endpoint before giving up. Without a
            timeout ``requests`` waits indefinitely, which would hang the UI
            handler on an unresponsive endpoint.

    Returns:
        The ``'output'`` field of the JSON response, cut off at the first
        occurrence of ``'### Instruction'``.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status.
        requests.Timeout: if the endpoint does not answer within ``timeout``.
        KeyError: if the JSON response has no ``'output'`` field.
    """
    alpaca_prompt = format_alpaca_prompt(state)
    payload = {
        "input": {
            "prompt": alpaca_prompt[bot_index],
            "sampling_params": {
                "max_new_tokens": 50,
                "temperature": 0.7,
                "top_p": 0.95,
            },
        }
    }
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        # Credential is taken from the environment; None when unset.
        "authorization": os.environ.get("RUNPOD_TOKEN"),
    }
    response = requests.post(url, json=payload, headers=headers, timeout=timeout)
    # Surface HTTP failures (bad token, endpoint down) as an HTTP error
    # instead of a confusing KeyError/JSON error further down.
    response.raise_for_status()
    # Drop anything generated after the model starts a new instruction section.
    return response.json()['output'].split('### Instruction')[0]
def chat_with_bots(user_input, state):
    """Pick two distinct bots at random and get a reply from each.

    The chosen names are stored in ``state['last_bots']`` (index 0 is the bot
    answering from history 0, index 1 the bot answering from history 1) so a
    later vote can be attributed to them.

    Args:
        user_input: The user's message; forwarded to ``get_bot_response``.
        state: Session state; updated in place with ``'last_bots'``.

    Returns:
        Tuple ``(bot1_response, bot2_response)``.

    Raises:
        ValueError: if fewer than two chatbots are configured.
    """
    if len(chatbots) < 2:
        raise ValueError(
            "At least two chatbots must be configured in chatbot_urls.json"
        )
    # NOTE(review): a new pair is drawn on every call, i.e. on every user
    # message, so in a multi-turn chat each history can mix replies from
    # different models and only the latest pair receives the vote. Confirm
    # whether the pair should instead be fixed for the whole conversation.
    first_name, second_name = random.sample(list(chatbots), 2)
    # Update the state with the names of the last bots
    state['last_bots'] = [first_name, second_name]
    bot1_response = get_bot_response(chatbots[first_name], user_input, state, 0)
    bot2_response = get_bot_response(chatbots[second_name], user_input, state, 1)
    return bot1_response, bot2_response
def update_ratings(state, winner_index):
    """Record the outcome of a vote in the shared ELO ratings file.

    Args:
        state: Session state; ``state['last_bots']`` holds the names of the
            two bots that produced the latest replies.
        winner_index: Position (0 or 1) of the winning bot in
            ``state['last_bots']``; the other position is the loser.

    Returns:
        A two-element list of chat rows:
        ``[('Winner: ', winner_name), ('Loser: ', loser_name)]``.

    Raises:
        KeyError: if ``state`` has no ``'last_bots'`` entry.
    """
    winner = state['last_bots'][winner_index]
    loser = state['last_bots'][1 - winner_index]
    # Re-read the ratings at vote time instead of using the snapshot stored in
    # the session when the user last sent a message: writing that stale copy
    # back would overwrite every vote other users cast in the meantime.
    elo_ratings = read_elo_ratings()
    elo_ratings = update_elo_ratings(elo_ratings, winner, loser)
    write_elo_ratings(elo_ratings)
    return [('Winner: ', winner), ('Loser: ', loser)]
def _record_vote(state, chatbot, chatbot2, winner_index):
    """Apply a vote for panel ``winner_index`` and build the UI updates."""
    winner_row, loser_row = update_ratings(state, winner_index)
    panels = (chatbot, chatbot2)
    # The winning panel shows the 'Winner: ' row, the other the 'Loser: ' row.
    panels[winner_index].append(winner_row)
    panels[1 - winner_index].append(loser_row)
    # gr.update() works for any component type and on both Gradio 3.x and
    # 4.x; the per-component .update() classmethods were removed in Gradio 4.
    return (
        chatbot,
        chatbot2,
        gr.update(interactive=False),  # vote button A
        gr.update(interactive=False),  # vote button B
        gr.update(interactive=False),  # textbox
        gr.update(interactive=False),  # submit button
    )


def vote_up_model(state, chatbot, chatbot2):
    """Handle a vote for the first panel (``chatbot``).

    Args:
        state: Session state; must contain ``'last_bots'``.
        chatbot: Chat rows of the first panel; a result row is appended.
        chatbot2: Chat rows of the second panel; a result row is appended.

    Returns:
        6-tuple: both chat row lists followed by four updates that disable
        the two vote buttons, the textbox and the submit button.
    """
    return _record_vote(state, chatbot, chatbot2, 0)


def vote_down_model(state, chatbot, chatbot2):
    """Handle a vote for the second panel (``chatbot2``).

    Args and return value are the same as for ``vote_up_model``; only the
    winning panel differs.
    """
    return _record_vote(state, chatbot, chatbot2, 1)
def user_ask(state, chatbot1, chatbot2, textbox):
    """Send the user's message to two bots and record both replies.

    Args:
        state: Per-session dict; ``'elo_ratings'``, ``'history'`` and (via
            ``chat_with_bots``) ``'last_bots'`` are written to it.
        chatbot1: Chat rows of the first panel.
        chatbot2: Chat rows of the second panel.
        textbox: The text the user entered.

    Returns:
        6-tuple ``(state, chatbot1, chatbot2, textbox, enable_btn,
        enable_btn)``: the updated state and chat rows, the textbox value
        unchanged, and the update that enables each of the two vote buttons.
    """
    # Limit user input to 200 characters (slicing is a no-op when shorter).
    user_input = (textbox or "")[:200]
    chatbot1 = chatbot1 or []
    chatbot2 = chatbot2 or []

    # Updating state with the current ELO ratings
    state["elo_ratings"] = read_elo_ratings()

    history = state.setdefault('history', [[], []])
    # The prompt sent to each bot is built from its history, so the new user
    # message has to be recorded before the bots are called.
    history[0].append({"role": "user", "content": user_input})
    history[1].append({"role": "user", "content": user_input})

    # Chat with bots
    try:
        bot1_response, bot2_response = chat_with_bots(user_input, state)
    except Exception:
        # Undo the message recorded above; otherwise a retry would leave two
        # consecutive user turns without a reply in between.
        history[0].pop()
        history[1].pop()
        raise

    history[0].append({"role": "bot1", "content": bot1_response})
    history[1].append({"role": "bot2", "content": bot2_response})
    chatbot1.append((user_input, bot1_response))
    chatbot2.append((user_input, bot2_response))

    # Keep only the last 10 messages of each bot's history. Slicing the outer
    # list (which always holds exactly the two per-bot lists) would trim
    # nothing, so each inner list is sliced instead.
    state["history"] = [bot_history[-10:] for bot_history in history]

    return state, chatbot1, chatbot2, textbox, enable_btn, enable_btn
import pandas as pd | |
# Function to generate leaderboard data | |
def generate_leaderboard():
    """Build the leaderboard table from the stored ELO ratings.

    Returns:
        pandas.DataFrame with the columns ``'Chatbot'`` and ``'ELO Score'``,
        ordered from the highest score to the lowest.
    """
    # read_elo_ratings() is expected to return a dict of {bot_name: elo_score}
    rows = list(read_elo_ratings().items())
    table = pd.DataFrame(rows, columns=['Chatbot', 'ELO Score'])
    return table.sort_values('ELO Score', ascending=False)
# Gradio interface setup | |
with gr.Blocks() as demo:
    # Per-session dict shared by all handlers below.
    state = gr.State({})

    with gr.Tab("Chatbot Arena"):
        with gr.Row():
            with gr.Column():
                # height is passed to the constructor: the chained
                # .style(...) method was removed in Gradio 4.
                chatbot1 = gr.Chatbot(label='Model A', height=600)
                upvote_btn_a = gr.Button(value="👍 Upvote A", interactive=False)
            with gr.Column():
                chatbot2 = gr.Chatbot(label='Model B', height=600)
                upvote_btn_b = gr.Button(value="👍 Upvote B", interactive=False)

        # No max_chars argument here: it is not a recognised gr.Textbox
        # parameter (unknown keyword arguments are rejected by Gradio 4).
        # The 200 character limit is enforced in user_ask, which truncates
        # the input.
        textbox = gr.Textbox(placeholder="Enter your prompt (up to 200 characters)")

        with gr.Row():
            submit_btn = gr.Button(value="Send")
            reset_btn = gr.Button(value="Reset")

        # NOTE(review): the vote handlers list textbox and submit_btn among
        # their outputs (and disable them), but the reset handler's outputs do
        # not include them, so after a vote "Reset" does not make them usable
        # again. Confirm whether that is intended.
        reset_btn.click(
            clear_chat,
            inputs=[state],
            outputs=[state, chatbot1, chatbot2, upvote_btn_a, upvote_btn_b],
        )

        # Pressing Enter in the textbox and clicking "Send" do the same thing.
        ask_inputs = [state, chatbot1, chatbot2, textbox]
        ask_outputs = [state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b]
        textbox.submit(user_ask, inputs=ask_inputs, outputs=ask_outputs)
        submit_btn.click(user_ask, inputs=ask_inputs, outputs=ask_outputs)

        vote_inputs = [state, chatbot1, chatbot2]
        vote_outputs = [chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn]
        upvote_btn_a.click(vote_up_model, inputs=vote_inputs, outputs=vote_outputs)
        upvote_btn_b.click(vote_down_model, inputs=vote_inputs, outputs=vote_outputs)

    with gr.Tab("Leaderboard"):
        # The table is computed once when the interface is built; the button
        # below recomputes it on demand.
        leaderboard = gr.Dataframe(generate_leaderboard())
        refresh_btn = gr.Button("Refresh Leaderboard")

        # Function to refresh leaderboard
        def refresh_leaderboard():
            return generate_leaderboard()

        # Event handler for the refresh button
        refresh_btn.click(refresh_leaderboard, inputs=[], outputs=[leaderboard])

# Launch the Gradio interface
demo.launch()