# Sampler-Arena / app.py
# rwitz's picture
# Update app.py
# 66afc6f
# raw
# history blame
# 4.07 kB
import gradio as gr
import requests
import os
import json
import random
import threading
from elo import update_elo_ratings # Custom function for ELO ratings
# Load the chatbot URLs and their respective model names from a JSON file.
# The result maps each model name (key) to its endpoint URL (value) and is
# read once, at import time.  JSON is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's locale-dependent default.
with open('chatbot_urls.json', 'r', encoding='utf-8') as file:
    chatbots = json.load(file)

# Thread-local storage: every thread sees its own, initially empty, set of
# attributes on this object.
# NOTE(review): this is per *thread*, not per user -- confirm that matches
# how the web server dispatches requests to worker threads.
user_data = threading.local()
# Initialize or get the ELO ratings cached for the calling thread
def get_user_elo_ratings():
    """Return the ELO ratings cached on ``user_data`` for this thread.

    The first call on a given thread loads the ratings through
    ``read_elo_ratings()`` and stores them on ``user_data``; subsequent calls
    on that thread return the cached object without reading the file again.

    NOTE(review): because the cache is never refreshed, ratings written by
    other threads are not visible here -- confirm this is intended.
    """
    try:
        return user_data.elo_ratings
    except AttributeError:
        # Nothing cached on this thread yet: load once and remember it.
        user_data.elo_ratings = read_elo_ratings()
        return user_data.elo_ratings
# One module-level lock shared by every reader and writer of the ratings
# file.  Creating a fresh ``threading.Lock()`` inside each ``with`` statement
# gives every call its own private lock, which excludes nobody.
_ELO_FILE_LOCK = threading.Lock()


# Read ELO ratings from file (thread-safe)
def read_elo_ratings():
    """Load the ELO ratings stored in ``elo_ratings.json``.

    Returns:
        The decoded JSON content of the file.  If the file does not exist,
        a new dict giving every model in ``chatbots`` a rating of 1200.
    """
    with _ELO_FILE_LOCK:
        try:
            with open('elo_ratings.json', 'r') as file:
                return json.load(file)
        except FileNotFoundError:
            # First run: no ratings saved yet, start every model at 1200.
            return {model: 1200 for model in chatbots}


# Write ELO ratings to file (thread-safe)
def write_elo_ratings(elo_ratings):
    """Overwrite ``elo_ratings.json`` with ``elo_ratings`` as indented JSON.

    Args:
        elo_ratings: JSON-serialisable ratings to persist.
    """
    with _ELO_FILE_LOCK:
        with open('elo_ratings.json', 'w') as file:
            json.dump(elo_ratings, file, indent=4)
def get_bot_response(url, prompt, timeout=120):
    """POST ``prompt`` to the model endpoint at ``url`` and return its JSON reply.

    Args:
        url: Endpoint URL to send the request to.
        prompt: Prompt text placed in the request payload.
        timeout: Seconds to wait for the server before giving up, passed
            straight to ``requests.post``.  Pass ``None`` to wait
            indefinitely, which was the behaviour before this parameter
            existed.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
        requests.exceptions.RequestException: On other transport failures.
        ValueError: If the response body is not valid JSON (raised by
            ``response.json()``; may be a subclass).
    """
    payload = {
        "input": {
            "prompt": prompt,
            "sampling_params": {
                "max_new_tokens": 16,
                "temperature": 0.7,
            },
        },
    }
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        # None when the RUNPOD_TOKEN environment variable is not set.
        "authorization": os.environ.get("RUNPOD_TOKEN"),
    }
    # Without a timeout a stalled endpoint would block this call forever.
    response = requests.post(url, json=payload, headers=headers, timeout=timeout)
    return response.json()
def chat_with_bots(user_input):
    """Send ``user_input`` to two randomly chosen chatbots.

    The model names in ``chatbots`` are shuffled and the first two after the
    shuffle are queried, one after the other, via ``get_bot_response``.

    Returns:
        A 2-tuple with the responses of the first and the second chosen bot.
    """
    contenders = list(chatbots.keys())
    random.shuffle(contenders)
    # Resolve both endpoints before any request is sent.
    endpoints = (chatbots[contenders[0]], chatbots[contenders[1]])
    return tuple(get_bot_response(endpoint, user_input) for endpoint in endpoints)
def user_ask(state, chatbot, textbox):
    """Handle a submitted prompt: query two bots and extend the chat history.

    Args:
        state: Per-session dict; ``state['last_bots']`` is set to the two
            model names reported by the responses so a later vote can refer
            to them.
        chatbot: Current chat history (list of ``[user, bot]`` message
            pairs), or ``None`` when there is no history yet.
        textbox: The submitted text.  Gradio passes the component's *value*
            (a string) to event handlers; an object exposing ``.value`` is
            still accepted for backward compatibility.

    Returns:
        ``(state, chatbot, "")`` -- the updated state, the updated history
        and an empty string that clears the textbox.
    """
    # Event handlers receive plain values, so ``textbox`` is normally a str
    # and has no ``.value`` attribute.
    user_input = getattr(textbox, 'value', textbox)
    if chatbot is None:
        chatbot = []

    bot1_response, bot2_response = chat_with_bots(user_input)

    # The Chatbot component expects [user_message, bot_message] pairs; a
    # None entry leaves that side of the exchange empty.  f-strings keep
    # this working even if an 'output' value is not a string.
    chatbot.append([f"User: {user_input}", f"Bot 1: {bot1_response['output']}"])
    chatbot.append([None, f"Bot 2: {bot2_response['output']}"])

    # NOTE(review): assumes each response carries a 'model_name' key --
    # confirm against the endpoint's actual response schema.
    state['last_bots'] = [bot1_response['model_name'], bot2_response['model_name']]
    return state, chatbot, ""
def update_ratings(state, winner_index):
    """Apply a vote to the ELO ratings and persist the result.

    Args:
        state: Per-session dict; ``state['last_bots']`` holds the two model
            names of the most recent comparison.
        winner_index: 0 or 1 -- position in ``state['last_bots']`` of the
            model that won; the other entry is treated as the loser.

    Returns:
        A human-readable message with the new ratings of both models, or a
        notice when there is no comparison to vote on yet.
    """
    last_bots = state.get('last_bots')
    if not last_bots:
        # A vote arrived before any prompt was submitted in this session.
        return "No comparison to vote on yet. Ask a question first."

    elo_ratings = get_user_elo_ratings()
    winner = last_bots[winner_index]
    loser = last_bots[1 - winner_index]
    # NOTE(review): if update_elo_ratings returns a new dict instead of
    # mutating its argument, the cached ratings are not refreshed here.
    elo_ratings = update_elo_ratings(elo_ratings, winner, loser)
    write_elo_ratings(elo_ratings)
    return f"Updated ELO ratings:\n{winner}: {elo_ratings[winner]}\n{loser}: {elo_ratings[loser]}"
def vote_up_model(state, chatbot):
    """Record a vote for the first bot and report the new ratings in the chat.

    Args:
        state: Per-session dict passed through to ``update_ratings``.
        chatbot: Current chat history (list of ``[user, bot]`` message
            pairs), or ``None`` when there is no history yet.

    Returns:
        The chat history with the rating update appended.
    """
    update_message = update_ratings(state, 0)
    if chatbot is None:
        chatbot = []
    # History entries must be [user, bot] pairs, not bare strings; None
    # leaves the user side empty.
    chatbot.append([None, update_message])
    return chatbot
def vote_down_model(state, chatbot):
    """Record a vote for the second bot and report the new ratings in the chat.

    Args:
        state: Per-session dict passed through to ``update_ratings``.
        chatbot: Current chat history (list of ``[user, bot]`` message
            pairs), or ``None`` when there is no history yet.

    Returns:
        The chat history with the rating update appended.
    """
    update_message = update_ratings(state, 1)
    if chatbot is None:
        chatbot = []
    # History entries must be [user, bot] pairs, not bare strings; None
    # leaves the user side empty.
    chatbot.append([None, update_message])
    return chatbot
# Build the UI: a chat panel with a prompt box on the left, voting and clear
# buttons on the right.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been stripped -- confirm the layout is as intended.
with gr.Blocks() as demo:
    # Per-session state; user_ask stores the last pair of model names in it.
    state = gr.State({})

    with gr.Row():
        with gr.Column(scale=0.5):
            chatbot = gr.Chatbot(label='ChatBox')
            with gr.Row():
                textbox = gr.Textbox(placeholder="Enter text and press ENTER")
                submit_btn = gr.Button(value="Submit")
        with gr.Column():
            # Labels restored from mis-decoded UTF-8 emoji.
            upvote_btn = gr.Button(value="👍 Upvote Bot 1")
            downvote_btn = gr.Button(value="👎 Upvote Bot 2")
            clear_btn = gr.Button(value="🗑️ Clear Chat")

    # Pressing ENTER and clicking Submit trigger the same handler.
    textbox.submit(user_ask, [state, chatbot, textbox], [state, chatbot, textbox])
    submit_btn.click(user_ask, [state, chatbot, textbox], [state, chatbot, textbox])
    upvote_btn.click(vote_up_model, [state, chatbot], [chatbot])
    downvote_btn.click(vote_down_model, [state, chatbot], [chatbot])
    # No inputs are wired, so the handler takes no arguments; returning an
    # empty history is what resets the Chatbot output.
    clear_btn.click(lambda: [], inputs=[], outputs=[chatbot])

# Enable request queuing via queue(); the enable_queue= argument of launch()
# is deprecated.
demo.queue()
demo.launch(share=True, server_name='0.0.0.0', server_port=7860)