# Sampler-Arena / app.py
# Hugging Face Space source view: commit f1d1dd8 ("Update app.py") by rwitz; file size 6.95 kB.
import json
import os
import random
import tempfile

import gradio as gr
import pandas as pd
import requests

from elo import update_elo_ratings  # Custom function for ELO ratings
# Update object returned by user_ask to re-enable both vote buttons once a turn
# has been answered.
enable_btn = gr.Button.update(interactive=True)

# Load the arena roster: a JSON object mapping model name -> endpoint URL.
# JSON is UTF-8 by specification, so don't rely on the platform's default
# locale encoding.
with open('chatbot_urls.json', 'r', encoding='utf-8') as file:
    chatbots = json.load(file)
def clear_chat(state):
    """Reset handler for the "Reset" button.

    Produces values for the outputs [state, chatbot1, chatbot2, upvote_btn_a,
    upvote_btn_b]: a fresh empty session dict (``None`` is passed through
    unchanged), two cleared chat panes, and both vote buttons disabled.
    """
    new_state = None if state is None else {}
    return (
        new_state,
        None,  # Model A pane
        None,  # Model B pane
        gr.Button.update(interactive=False),
        gr.Button.update(interactive=False),
    )
# Accessor for the ratings snapshot cached in a session's state.
def get_user_elo_ratings(state):
    """Return the ELO ratings dict stored under ``state['elo_ratings']``.

    The ratings themselves are shared by all users: ``user_ask`` copies the
    output of ``read_elo_ratings()`` into the session state on each message.
    Nothing is initialized here; a KeyError is raised if the key is absent.
    """
    cached = state['elo_ratings']
    return cached
# Persistence of the ELO ratings, kept in one JSON file shared by every session.
def read_elo_ratings(path='elo_ratings.json'):
    """Load the shared ELO ratings from ``path``.

    Every model listed in ``chatbots`` starts from a default rating of 1200
    unless the file already stores a rating for it. Models added to
    chatbot_urls.json after the ratings file was created therefore appear on
    the leaderboard and can be voted on. If the file does not exist yet, all
    models get the default.

    Parameters:
        path: ratings file location; defaults to 'elo_ratings.json' in the
            working directory.

    Returns:
        dict mapping model name -> rating.

    Raises:
        json.JSONDecodeError if the file exists but is not valid JSON. This is
        deliberately not replaced by defaults, because the next write would
        then wipe every stored rating.
    """
    ratings = {model: 1200 for model in chatbots}
    try:
        with open(path, 'r', encoding='utf-8') as file:
            ratings.update(json.load(file))
    except FileNotFoundError:
        pass  # No votes recorded yet: everyone keeps the default rating.
    return ratings
def write_elo_ratings(elo_ratings, path='elo_ratings.json'):
    """Persist ``elo_ratings`` to ``path`` as indented JSON.

    The JSON is first written to a temporary file in the same directory. That
    file is then moved over ``path`` with ``os.replace``, which swaps the file
    in atomically. A concurrent ``read_elo_ratings`` therefore sees either the
    previous file or the complete new one, never a truncated or half-written
    file. If serialization fails, the existing file is left untouched.

    Note: this does not serialize concurrent read-modify-write cycles. Two
    simultaneous votes can still race, and the later write wins.

    Parameters:
        elo_ratings: JSON-serializable mapping of model name -> rating.
        path: target file; defaults to 'elo_ratings.json' in the working
            directory.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as tmp_file:
            json.dump(elo_ratings, tmp_file, indent=4)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a stray temp file behind if serialization or the
        # rename fails; the error itself is re-raised.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
# Prompt construction for the model endpoints.
def format_alpaca_prompt(state):
    """Build one Alpaca-style prompt per arena side from the conversation so far.

    ``state["history"][0]`` and ``state["history"][1]`` are the message lists
    for model A and model B. Each message is a dict with 'role' and 'content'.
    Messages with role 'user' become "### Instruction:" sections; any other
    role becomes a "### Response:" section. Each prompt ends with an open
    "### Response:\\n" header for the model to complete.

    Returns:
        [prompt for model A, prompt for model B]
    """
    preamble = ("Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.")

    def render(messages):
        # The standard Alpaca template separates the preamble from the first
        # section with a blank line; without it the preamble runs straight
        # into "### Instruction:".
        parts = [preamble + "\n\n"]
        for message in messages:
            header = "### Instruction:\n" if message['role'] == 'user' else "### Response:\n"
            parts.append(header + message['content'] + "\n\n")
        parts.append("### Response:\n")
        return "".join(parts)

    return [render(state["history"][0]), render(state["history"][1])]
def get_bot_response(url, prompt, state, bot_index, timeout=120):
    """Query one model endpoint and return its reply for the current turn.

    Parameters:
        url: endpoint to POST to (a value from ``chatbots``).
        prompt: not used. The prompt is rebuilt from ``state["history"]`` by
            ``format_alpaca_prompt``; the argument is kept for compatibility
            with existing callers.
        state: session dict holding the two-sided 'history'.
        bot_index: 0 for model A, 1 for model B; selects which side's
            conversation is sent.
        timeout: seconds to wait for the HTTP response before
            ``requests.Timeout`` is raised. Without a timeout a stalled
            endpoint would block this handler forever; the default is a
            generous guess and can be tuned.

    Returns:
        The generated text, cut at the first "### Instruction" marker. This
        keeps a model that continues past its answer from putting an invented
        next user turn into the chat.

    Raises:
        requests.HTTPError for a non-2xx response.
        requests.Timeout if the endpoint does not answer in time.
    """
    prompts = format_alpaca_prompt(state)
    payload = {
        "input": {
            "prompt": prompts[bot_index],
            "sampling_params": {
                "max_new_tokens": 50,
                "temperature": 0.7,
                "top_p": 0.95,
            },
        }
    }
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        # Presumably a RunPod API key, judging by the variable name; if it is
        # unset the value is None and requests omits the header.
        "authorization": os.environ.get("RUNPOD_TOKEN"),
    }
    response = requests.post(url, json=payload, headers=headers, timeout=timeout)
    response.raise_for_status()
    return response.json()['output'].split('### Instruction')[0]
def chat_with_bots(user_input, state):
    """Get this turn's replies from the two models in the current match.

    The pair of models is drawn at random on the first turn of a conversation
    and stored in ``state['last_bots']`` as [model A, model B]. Later turns
    reuse the same pair. As a result, each side's history, which is sent back
    as context, comes from a single model, and a vote credits the models that
    actually produced the conversation. ``clear_chat`` resets ``state``, so
    the next turn draws a new pair.

    Parameters:
        user_input: forwarded unchanged to ``get_bot_response``; the prompt
            itself is built from ``state['history']``.
        state: per-session dict; 'last_bots' is read and, on the first turn,
            written.

    Returns:
        (reply from model A, reply from model B)

    Raises:
        ValueError (from ``random.sample``) if fewer than two models are
        configured.
    """
    bots = state.get('last_bots')
    if not bots:
        # First turn of this conversation: pick two distinct models in random order.
        bots = random.sample(list(chatbots), 2)
        state['last_bots'] = bots
    bot1_response = get_bot_response(chatbots[bots[0]], user_input, state, 0)
    bot2_response = get_bot_response(chatbots[bots[1]], user_input, state, 1)
    return bot1_response, bot2_response
def update_ratings(state, winner_index):
    """Record a vote between the two models of the current match and persist it.

    Parameters:
        state: session dict. It must contain 'last_bots' ([model A, model B]),
            which ``chat_with_bots`` sets. 'elo_ratings' is refreshed with the
            post-vote ratings.
        winner_index: 0 if model A won, 1 if model B won.

    Returns:
        A human-readable summary of the two updated ratings.

    Raises:
        ValueError if ``winner_index`` is not 0 or 1.
        KeyError if no match has been played in this session.
    """
    if winner_index not in (0, 1):
        raise ValueError(f"winner_index must be 0 or 1, got {winner_index!r}")
    # Re-read the shared ratings instead of using the snapshot cached in state
    # by user_ask. Other sessions may have voted since then, and writing back
    # a stale snapshot would silently discard their updates.
    elo_ratings = read_elo_ratings()
    winner = state['last_bots'][winner_index]
    loser = state['last_bots'][1 - winner_index]
    # Models missing from the ratings file start at the default rating.
    for name in (winner, loser):
        elo_ratings.setdefault(name, 1200)
    elo_ratings = update_elo_ratings(elo_ratings, winner, loser)
    write_elo_ratings(elo_ratings)
    state['elo_ratings'] = elo_ratings
    return f"Updated ELO ratings:\n{winner}: {elo_ratings[winner]}\n{loser}: {elo_ratings[loser]}"
def _record_vote(state, chatbot, winner_index):
    """Shared vote handler: update ratings, report them in the chat, lock voting.

    Returns values for the outputs [chatbot, upvote_btn_a, upvote_btn_b].
    """
    update_message = update_ratings(state, winner_index)
    # The chat pane holds (user, bot) pairs (see user_ask), so show the result
    # as a bot-only pair; a bare string is not a valid chat entry.
    chatbot.append((None, update_message))
    # Disable both voting buttons until the next message.
    return chatbot, gr.Button.update(interactive=False), gr.Button.update(interactive=False)


def vote_up_model(state, chatbot):
    """Handle "Upvote A": model A (``state['last_bots'][0]``) wins."""
    return _record_vote(state, chatbot, 0)


def vote_down_model(state, chatbot):
    """Handle "Upvote B": model B (``state['last_bots'][1]``) wins, i.e. A loses.

    The UI wires this to the "Upvote B" button, but it was never defined
    before, so building the interface failed with a NameError.
    """
    return _record_vote(state, chatbot, 1)
def user_ask(state, chatbot1, chatbot2, textbox):
    """Handle one user message: query both arena models and update the UI.

    Wired to both ``textbox.submit`` and ``submit_btn.click``. The returned
    tuple maps to the outputs [state, chatbot1, chatbot2, textbox,
    upvote_btn_a, upvote_btn_b].

    Parameters:
        state: per-session dict. It gains 'elo_ratings', 'history' (two
            message lists, one per model) and, via ``chat_with_bots``,
            'last_bots'.
        chatbot1, chatbot2: displayed conversations as (user, bot) pairs; the
            new turn is appended to each in place.
        textbox: the raw user text.

    Returns:
        (state, chatbot1, chatbot2, textbox, enable_btn, enable_btn). The
        textbox value is returned unchanged and both vote buttons are enabled.
    """
    # The Textbox sets max_chars=200, but enforce the limit server-side too.
    user_input = textbox[:200]
    # Refresh this session's snapshot of the shared ratings.
    state["elo_ratings"] = read_elo_ratings()
    history = state.setdefault("history", [[], []])
    # The user message must be in history before chat_with_bots runs, because
    # the prompts sent to the models are built from state["history"].
    for side in history:
        side.append({"role": "user", "content": user_input})
    bot1_response, bot2_response = chat_with_bots(user_input, state)
    history[0].append({"role": "bot1", "content": bot1_response})
    history[1].append({"role": "bot2", "content": bot2_response})
    chatbot1.append((user_input, bot1_response))
    chatbot2.append((user_input, bot2_response))
    # Keep only the last 10 messages (5 turns) of EACH model's history.
    # Slicing state["history"] itself would be a no-op, since it is a
    # 2-element list of per-model histories.
    state["history"] = [side[-10:] for side in history]
    return state, chatbot1, chatbot2, textbox, enable_btn, enable_btn
# Leaderboard data. This must be defined before demo.launch(): launch() blocks
# while the app is serving, so anything defined after it would not exist when
# the Refresh button is clicked.
def generate_leaderboard():
    """Return the shared ELO ratings as a DataFrame, highest score first.

    Columns: 'Chatbot' (model name) and 'ELO Score'.
    """
    elo_ratings = read_elo_ratings()  # {model name: rating}
    leaderboard_data = pd.DataFrame(list(elo_ratings.items()), columns=['Chatbot', 'ELO Score'])
    return leaderboard_data.sort_values('ELO Score', ascending=False)


def refresh_leaderboard():
    """Click handler for the "Refresh Leaderboard" button."""
    return generate_leaderboard()


# Gradio interface setup
with gr.Blocks() as demo:
    state = gr.State({})
    with gr.Tab("Chatbot Arena"):
        with gr.Row():
            with gr.Column():
                chatbot1 = gr.Chatbot(label='Model A').style(height=600)
                upvote_btn_a = gr.Button(value="👍 Upvote A", interactive=False)
            with gr.Column():
                chatbot2 = gr.Chatbot(label='Model B').style(height=600)
                upvote_btn_b = gr.Button(value="👍 Upvote B", interactive=False)
        textbox = gr.Textbox(placeholder="Enter your prompt (up to 200 characters)", max_chars=200)
        with gr.Row():
            submit_btn = gr.Button(value="Send")
            reset_btn = gr.Button(value="Reset")
        reset_btn.click(clear_chat, inputs=[state], outputs=[state, chatbot1, chatbot2, upvote_btn_a, upvote_btn_b])
        # Pressing Enter in the textbox and clicking "Send" do the same thing.
        for trigger in (textbox.submit, submit_btn.click):
            trigger(
                user_ask,
                inputs=[state, chatbot1, chatbot2, textbox],
                outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b],
            )
        upvote_btn_a.click(vote_up_model, inputs=[state, chatbot1], outputs=[chatbot1, upvote_btn_a, upvote_btn_b])
        upvote_btn_b.click(vote_down_model, inputs=[state, chatbot2], outputs=[chatbot2, upvote_btn_a, upvote_btn_b])
    with gr.Tab("Leaderboard"):
        # A callable value is evaluated on each page load, so the table is
        # filled before the first manual refresh.
        leaderboard = gr.Dataframe(value=generate_leaderboard)
        refresh_btn = gr.Button("Refresh Leaderboard")
        refresh_btn.click(refresh_leaderboard, inputs=[], outputs=[leaderboard])

# Launch the Gradio interface
demo.launch()