# Sampler-Arena / app.py
# Author: rwitz
# Last commit: 0d37169 "Update app.py" (verified)
# File size: 17.1 kB
import gradio as gr
import requests
import os
import pandas as pd
import json
import ssl
import random
from elo import update_elo_ratings # Custom function for ELO ratings
# Reusable Gradio update that makes a component interactive again; user_ask returns it
# for both upvote buttons once the two bots have replied.
enable_btn = gr.Button.update(interactive=True)
import sqlite3
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
async def direct_regenerate(model, user_input, chatbot, character_name, character_description, user_name):
    """Re-sample the selected model's reply to the latest message in the Direct Chat pane.

    The prompt is rebuilt from the visible transcript: every earlier (user, bot) pair
    followed by the final user message, so the new reply is generated with the
    conversation context rather than from the last line alone. Only the chat list is
    changed; ``state["direct_history"]`` is not an input of this handler and keeps the
    previous reply.

    Args:
        model: Display name (``original_model``) selected in the model dropdown.
        user_input: Current textbox contents. Forwarded to ``get_bot_response``, which
            rebuilds the prompt from the history, so it does not affect the result.
        chatbot: List of ``(user_message, bot_message)`` pairs shown in the pane.
        character_name, character_description, user_name: Persona settings.

    Returns:
        ``("", chatbot)``: cleared textbox value and the chat list with its last reply replaced.
    """
    if not chatbot:
        # Nothing has been said yet, so there is nothing to regenerate (chatbot[-1] would raise).
        return "", chatbot
    adapter = next(entry['adapter'] for entry in chatbots_data if entry['original_model'] == model)
    history = []
    for user_message, bot_message in chatbot[:-1]:
        history.append({"role": "user", "content": user_message})
        history.append({"role": "bot", "content": bot_message})
    last_user_message = chatbot[-1][0]
    history.append({"role": "user", "content": last_user_message})
    response = await get_bot_response(adapter, user_input, {"history": [history]}, 0, character_name, character_description, user_name)
    chatbot[-1] = (last_user_message, response)  # replace only the assistant's half
    return "", chatbot
password = os.environ.get("MONGODB")

# Process-wide MongoClient, created on first use by init_database() and reused afterwards.
_mongo_client = None


def init_database():
    """Return the ``elo_ratings`` collection of the ``elo_ratings`` database.

    The MongoClient is created once and cached in ``_mongo_client``. A PyMongo client owns
    a connection pool and background monitoring threads, and nearly every Gradio handler
    in this app calls this function, so building a fresh client per call leaked
    connections and threads.

    Returns:
        The pymongo collection ``elo_ratings.elo_ratings``.
    """
    global _mongo_client
    if _mongo_client is None:
        from urllib.parse import quote_plus

        # Credentials embedded in a MongoDB URI must be percent-escaped; a password
        # containing ':', '/', '@' or '%' would otherwise corrupt the URI.
        escaped_password = quote_plus(password or "")
        uri = f"mongodb+srv://new-user:{escaped_password}@cluster0.xb2urf6.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
        _mongo_client = MongoClient(uri)
    return _mongo_client["elo_ratings"]["elo_ratings"]
def get_user_elo_ratings(collection):
    """Load every stored ELO document into a ``{bot_name: {...}}`` mapping.

    Each value carries the document's ``'elo_rating'`` and ``'games_played'``. If the
    collection holds no documents, a single placeholder entry keyed ``"default"``
    (rating 1200, 0 games) is returned instead of an empty dict.
    """
    ratings = {}
    for doc in collection.find():
        name = doc['bot_name']
        ratings[name] = {
            'elo_rating': doc['elo_rating'],
            'games_played': doc['games_played'],
        }
    if not ratings:
        return {"default": {'elo_rating': 1200, 'games_played': 0}}
    return ratings
def update_elo_rating(collection, updated_ratings, winner, loser):
    """Upsert the stored rating and game count for both bots of a finished match.

    ``updated_ratings`` maps bot name to ``{'elo_rating': ..., 'games_played': ...}``.
    The winner's document is written first, then the loser's.
    """
    for bot_name in (winner, loser):
        record = updated_ratings[bot_name]
        new_fields = {"elo_rating": record['elo_rating'], "games_played": record['games_played']}
        collection.update_one({"bot_name": bot_name}, {"$set": new_fields}, upsert=True)
import json
# Arena roster loaded from JSON. Each entry is a dict; the keys read in this file are
# 'adapter', 'original_model' and 'fireworks_adapter_name'. Read as UTF-8 explicitly:
# the default text encoding is locale-dependent, while JSON text is UTF-8.
with open('chatbots.txt', 'r', encoding='utf-8') as file:
    chatbots_data = json.load(file)
# Adapter ids of every arena contestant, in file order.
chatbots = [entry['adapter'] for entry in chatbots_data]
def clear_chat(state):
    """Start a fresh arena round with a new random pair of bots and empty chat panes.

    Args:
        state: Per-session dict shared by the Arena and Direct Chat tabs (may be None).

    Returns:
        7-tuple matching the reset button's outputs: new state, two empty chat lists,
        both upvote buttons disabled, textbox cleared and editable, submit button enabled.
    """
    # Always start from a fresh dict. The previous `{} if state is not None else state`
    # kept None as None, which then crashed on the item assignment below.
    new_state = {}
    # Arena history and pairing are discarded, but the Direct Chat tab shares this state;
    # keep its history so that conversation is not silently lost while still on screen.
    if state and "direct_history" in state:
        new_state["direct_history"] = state["direct_history"]
    adapter_names = [entry['adapter'] for entry in chatbots_data]
    new_state['last_bots'] = random.sample(adapter_names, 2)
    return (
        new_state,
        [],
        [],
        gr.Button.update(interactive=False),
        gr.Button.update(interactive=False),
        gr.Textbox.update(value='', interactive=True),
        gr.Button.update(interactive=True),
    )
from datasets import load_dataset,DatasetDict,Dataset
import requests
import os
# Build the plain-text role-play prompt sent to the completion model
def format_prompt(state, bot_index, character_name, character_description, user_name, num_messages=20):
    """Render one bot's conversation as a plain-text role-play completion prompt.

    Blank or ``None`` persona fields fall back to built-in defaults. Only the last
    ``num_messages`` entries of ``state["history"][bot_index]`` are included; entries
    whose role is ``'user'`` are attributed to the user and every other role to the
    character. The prompt ends with ``"<character name>:"`` so the model continues
    speaking as the character.
    """
    def _filled(value):
        return value is not None and value.strip() != ""

    name = character_name if _filled(character_name) else "Ryan"
    background = character_description if _filled(character_description) else "Ryan is a friendly and knowledgeable college student who is always willing to help. He has a strong background in math and coding, and enjoys engaging in intellectual discussions on a wide range of topics."
    speaker = user_name if _filled(user_name) else "You"

    header = (
        f"The following is a conversation between {speaker} and {name}.\n\n"
        f"{name}'s background:\n{background}\n\n"
    )
    transcript = "".join(
        f"{speaker if turn['role'] == 'user' else name}: {turn['content']}\n"
        for turn in state["history"][bot_index][-num_messages:]
    )
    return header + transcript + f"{name}:"
import aiohttp
import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential
async def get_bot_response(adapter_id, prompt, state, bot_index, character_name, character_description, user_name):
    """Ask the Fireworks completions API for one bot's next reply.

    Args:
        adapter_id: ``adapter`` value of a ``chatbots_data`` entry; selects the Fireworks model.
        prompt: Unused. The prompt is always rebuilt from ``state`` by ``format_prompt``;
            the parameter is kept so existing call sites keep working.
        state: Dict whose ``"history"`` list holds one message list per bot.
        bot_index: Which list in ``state["history"]`` to render.
        character_name, character_description, user_name: Persona settings.

    Returns:
        The stripped completion text, or a fixed apology string when the request fails,
        times out, returns a non-200 status, or the body is not in the expected shape.

    Raises:
        ValueError: if ``adapter_id`` is not listed in ``chatbots_data``.
    """
    prompt = format_prompt(state, bot_index, character_name, character_description, user_name)
    for entry in chatbots_data:
        if entry['adapter'] == adapter_id:
            fireworks_adapter_name = entry['fireworks_adapter_name']
            break
    else:
        # A bare next() raised StopIteration here, which a coroutine surfaces as an opaque RuntimeError.
        raise ValueError(f"Unknown adapter: {adapter_id!r}")

    # Stop markers must use the same names format_prompt wrote into the prompt (keep these
    # defaults in sync with it). With the raw inputs, a blank name produced the stop string
    # ":" (cutting every reply at its first colon) and a None name produced "None:".
    speaker = character_name if character_name is not None and character_name.strip() else "Ryan"
    listener = user_name if user_name is not None and user_name.strip() else "You"

    url = "https://api.fireworks.ai/inference/v1/completions"
    payload = {
        "model": f"accounts/gaingg19-432d9f/models/{fireworks_adapter_name}",
        "max_tokens": 250,
        "temperature": 0.7,
        "prompt": prompt,
        "stop": ["<|im_end|>", f"{speaker}:", f"{listener}:"],
    }
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}",
    }
    fallback = "Sorry, I couldn't generate a response."
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload, headers=headers, timeout=aiohttp.ClientTimeout(total=30)) as response:
                if response.status != 200:
                    print(await response.text())
                    return fallback
                response_data = await response.json()
        return response_data['choices'][0]['text'].strip()
    except (aiohttp.ClientError, asyncio.TimeoutError):
        return fallback
    except (KeyError, IndexError, TypeError, ValueError):
        # Body was not the expected {"choices": [{"text": ...}]} shape, or not valid JSON.
        return fallback
async def chat_with_bots(user_input, state, character_name, character_description, user_name):
    """Query both arena bots concurrently and return their replies.

    On the first turn of a round (no ``state['last_bots']``), two distinct adapters are
    drawn at random and stored in ``state`` so later turns keep the same pairing.

    Returns:
        ``(bot1_response, bot2_response)`` with any ``<|im_end|>`` markers removed.
    """
    if not state.get('last_bots'):
        # random.sample draws the pair without reordering the module-level `chatbots`
        # list; random.shuffle mutated that list, which every session shares.
        state['last_bots'] = random.sample(chatbots, 2)
    bot1_adapter, bot2_adapter = state['last_bots'][0], state['last_bots'][1]
    bot1_response, bot2_response = await asyncio.gather(
        get_bot_response(bot1_adapter, user_input, state, 0, character_name, character_description, user_name),
        get_bot_response(bot2_adapter, user_input, state, 1, character_name, character_description, user_name),
    )
    return bot1_response.replace("<|im_end|>", ""), bot2_response.replace("<|im_end|>", "")
def update_ratings(state, winner_index, collection):
    """Record one arena vote and persist the resulting ELO ratings.

    Args:
        state: Session dict; ``state['last_bots']`` holds the two adapters of this round.
        winner_index: 0 if Model A won, 1 if Model B won.
        collection: Mongo collection handed to the ELO read/write helpers.

    Returns:
        ``[('Winner: ', <winner display name>), ('Loser: ', <loser display name>)]``,
        tuples appended to the chat panes to announce the result.
    """
    def _display_name(adapter):
        return next(e['original_model'] for e in chatbots_data if e['adapter'] == adapter)

    current = get_user_elo_ratings(collection)
    pair = state['last_bots']
    won = pair[winner_index]
    lost = pair[1 - winner_index]
    winner_name = _display_name(won)
    loser_name = _display_name(lost)
    new_ratings = update_elo_ratings(current, won, lost)
    update_elo_rating(collection, new_ratings, won, lost)
    return [('Winner: ', winner_name), ('Loser: ', loser_name)]
def _record_vote(state, chatbot, chatbot2, winner_index):
    """Apply a vote for the bot at ``winner_index`` and annotate both chat panes.

    The winning pane (``chatbot`` for index 0, ``chatbot2`` for index 1) receives the
    winner tuple first, then the other pane receives the loser tuple. Returns the six
    updates wired to the upvote buttons: both chat lists, then both upvote buttons, the
    textbox and the submit button, all made non-interactive.
    """
    outcome = update_ratings(state, winner_index, init_database())
    winning_pane, losing_pane = (chatbot, chatbot2) if winner_index == 0 else (chatbot2, chatbot)
    winning_pane.append(outcome[0])
    losing_pane.append(outcome[1])

    def _disabled_button():
        return gr.Button.update(interactive=False)

    return (
        chatbot,
        chatbot2,
        _disabled_button(),
        _disabled_button(),
        gr.Textbox.update(interactive=False),
        _disabled_button(),
    )


def vote_up_model(state, chatbot, chatbot2):
    """Handle the "Upvote A" button: record a win for the left-hand Model A."""
    return _record_vote(state, chatbot, chatbot2, 0)


def vote_down_model(state, chatbot, chatbot2):
    """Handle the "Upvote B" button: despite the name, this records a win for Model B."""
    return _record_vote(state, chatbot, chatbot2, 1)
async def user_ask(state, chatbot1, chatbot2, textbox, character_name, character_description, user_name):
    """Send one user message to both arena bots and append their replies.

    Inputs are clipped: names to 20 characters, the description and the message to 500.
    Each bot keeps its own message list in ``state["history"]`` (index 0 feeds Model A,
    index 1 feeds Model B), trimmed to about the last 20 entries. Blank messages are
    ignored.

    Returns:
        6-tuple for the submit outputs: state, both chat lists, a textbox update that
        clears it, and updates for the two upvote buttons.
    """
    user_input = textbox or ""
    if not user_input.strip():
        # Nothing to send: leave the chats untouched and the vote buttons as they are
        # (gr.update() with no arguments is a no-op update).
        return state, chatbot1, chatbot2, gr.update(value=''), gr.update(), gr.update()

    # Clip free-text fields to the limits advertised in the UI placeholders.
    if character_name:
        character_name = character_name[:20]
    if character_description:
        character_description = character_description[:500]
    if user_name:
        user_name = user_name[:20]
    user_input = user_input[:500]

    # No database access here: update_ratings loads fresh ratings when a vote is cast, and
    # a synchronous pymongo query inside this coroutine would block the event loop.
    history = state.setdefault("history", [[], []])
    history[0].append({"role": "user", "content": user_input})
    history[1].append({"role": "user", "content": user_input})
    if len(history[0]) > 20:
        history[0] = history[0][-20:]
        history[1] = history[1][-20:]

    bot1_response, bot2_response = await chat_with_bots(user_input, state, character_name, character_description, user_name)
    history[0].append({"role": "bot1", "content": bot1_response})
    history[1].append({"role": "bot2", "content": bot2_response})
    chatbot1.append((user_input, bot1_response))
    chatbot2.append((user_input, bot2_response))
    # Clear the textbox and enable both upvote buttons now that there is something to judge.
    return state, chatbot1, chatbot2, gr.update(value=''), enable_btn, enable_btn
import pandas as pd
# Helpers for the Submit Model and Leaderboard tabs
import requests
def submit_model(model_name):
    """Announce a user-submitted model name on the Discord webhook given by ``DISCORD_URL``.

    Args:
        model_name: Free text typed by a visitor.

    Returns:
        A human-readable status string (displayed in the Submit Model textbox).
    """
    discord_url = os.environ.get("DISCORD_URL")
    if not discord_url:
        return "Discord webhook URL not configured."
    if not model_name or not model_name.strip():
        return "Please enter a model name."
    payload = {
        "content": f"New model submitted: {model_name}",
        # model_name is untrusted visitor text: do not let it ping @everyone/@here,
        # roles or users in the Discord channel.
        "allowed_mentions": {"parse": []},
    }
    try:
        # Without a timeout a stalled webhook would hang this worker indefinitely.
        response = requests.post(discord_url, json=payload, timeout=10)
    except requests.RequestException:
        return "Failed to submit the model."
    # Discord replies 204 No Content by default, or 200 with the message body when the
    # webhook URL carries ?wait=true.
    if response.status_code in (200, 204):
        return "Model submitted successfully!"
    return "Failed to submit the model."
def generate_leaderboard(collection):
    """Build the leaderboard table from the stored ELO documents.

    Args:
        collection: Mongo collection whose documents carry ``bot_name``, ``elo_rating``
            and ``games_played``.

    Returns:
        DataFrame with columns ``Chatbot``, ``ELO Score`` (rounded to int) and
        ``Games Played``, sorted by score, highest first. A bot whose adapter is no longer
        in ``chatbots_data`` is listed under its stored ``bot_name`` instead of raising.
    """
    columns = ['Chatbot', 'ELO Score', 'Games Played']
    rows = list(collection.find())
    if not rows:
        # An empty frame has object-dtype columns; the numeric rounding below fails on them.
        return pd.DataFrame(columns=columns)
    # Adapter -> display name, built once (first roster match wins, as with the old linear scan).
    display_names = {}
    for entry in chatbots_data:
        display_names.setdefault(entry['adapter'], entry['original_model'])
    leaderboard_data = pd.DataFrame(rows, columns=['bot_name', 'elo_rating', 'games_played'])
    leaderboard_data['original_model'] = leaderboard_data['bot_name'].map(lambda name: display_names.get(name, name))
    leaderboard_data = leaderboard_data[['original_model', 'elo_rating', 'games_played']].copy()
    leaderboard_data.columns = columns
    leaderboard_data['ELO Score'] = leaderboard_data['ELO Score'].round().astype(int)
    return leaderboard_data.sort_values('ELO Score', ascending=False)
def refresh_leaderboard():
    """Return the current leaderboard DataFrame, read fresh from the ELO collection.

    Used to populate the Leaderboard tab and by its Refresh button.
    """
    return generate_leaderboard(init_database())
async def direct_chat(model, user_input, state, chatbot, character_name, character_description, user_name):
    """Send one message to the single model chosen in the Direct Chat tab.

    The conversation is kept in ``state["direct_history"]`` (trimmed to about the last
    20 entries) and rendered into the prompt by ``get_bot_response``. Blank messages are
    ignored.

    Returns:
        ``("", chatbot, state)``: cleared textbox, chat list with the new exchange, state.
    """
    if not user_input or not user_input.strip():
        # Ignore blank submissions instead of sending an empty user turn to the model.
        return "", chatbot, state
    adapter = next(entry['adapter'] for entry in chatbots_data if entry['original_model'] == model)
    history = state.setdefault("direct_history", [])
    if len(history) > 20:
        history = state["direct_history"] = history[-20:]
    history.append({"role": "user", "content": user_input})
    # get_bot_response reads state["history"][bot_index]; both slots share the same list.
    temp_state = {"history": [history, history]}
    response = await get_bot_response(adapter, user_input, temp_state, 0, character_name, character_description, user_name)
    chatbot.append((user_input, response))
    history.append({"role": "bot", "content": response})
    return "", chatbot, state
def reset_direct_chat(state):
    """Forget the Direct Chat conversation and clear its widgets.

    Returns:
        ``([], <textbox update clearing the value>, state)`` for the Direct Chat reset outputs.
    """
    state["direct_history"] = []
    cleared_textbox = gr.Textbox.update(value='')
    return [], cleared_textbox, state
# The leaderboard is loaded inside the Leaderboard tab's try/except below. A bare
# module-level refresh_leaderboard() call here discarded its result and made the whole
# app fail to start whenever MongoDB was unreachable.
# Gradio interface setup
with gr.Blocks() as demo:
    # Per-session dict shared by every tab (arena pairing/history and direct_history).
    state = gr.State({})

    with gr.Tab("πŸ€– Chatbot Arena"):
        gr.Markdown("## πŸ₯Š Let's see which chatbot wins!")
        with gr.Row():
            with gr.Column():
                chatbot1 = gr.Chatbot(label='πŸ€– Model A').style(height=350)
                upvote_btn_a = gr.Button(value="πŸ‘ Upvote A", interactive=False).style(full_width=True)
            with gr.Column():
                chatbot2 = gr.Chatbot(label='πŸ€– Model B').style(height=350)
                upvote_btn_b = gr.Button(value="πŸ‘ Upvote B", interactive=False).style(full_width=True)
        with gr.Row():
            with gr.Column(scale=5):
                textbox = gr.Textbox(placeholder="🎀 Enter your prompt (up to 500 characters)")
            submit_btn = gr.Button(value="Submit")
        with gr.Row():
            reset_btn = gr.Button(value="πŸ—‘οΈ Reset")
        with gr.Row():
            character_name = gr.Textbox(label="Character Name", value="Ryan", placeholder="Enter character name (max 20 chars)")
            character_description = gr.Textbox(label="Character Description", value="Ryan is a college student who is always willing to help. He knows a lot about math and coding.", placeholder="Enter character description (max 500 chars)")
        with gr.Row():
            user_name = gr.Textbox(label="Your Name", value="You", placeholder="Enter your name (max 20 chars)")

        # Arena event wiring.
        reset_btn.click(clear_chat, inputs=[state], outputs=[state, chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
        submit_btn.click(user_ask, inputs=[state, chatbot1, chatbot2, textbox, character_name, character_description, user_name], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True)
        textbox.submit(user_ask, inputs=[state, chatbot1, chatbot2, textbox, character_name, character_description, user_name], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True)
        upvote_btn_a.click(vote_up_model, inputs=[state, chatbot1, chatbot2], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
        upvote_btn_b.click(vote_down_model, inputs=[state, chatbot1, chatbot2], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])

    with gr.Tab("πŸ’¬ Direct Chat"):
        gr.Markdown("## πŸ—£οΈ Chat directly with a model!")
        with gr.Row():
            model_dropdown = gr.Dropdown(choices=[entry['original_model'] for entry in chatbots_data], value=chatbots_data[0]['original_model'], label="πŸ€– Select a model")
        with gr.Row():
            direct_chatbot = gr.Chatbot(label="πŸ’¬ Direct Chat").style(height=500)
        with gr.Row():
            with gr.Column(scale=5):
                direct_textbox = gr.Textbox(placeholder="πŸ’­ Enter your message")
            direct_submit_btn = gr.Button(value="Submit")
        with gr.Row():
            direct_regenerate_btn = gr.Button(value="πŸ”„ Regenerate")
            direct_reset_btn = gr.Button(value="πŸ—‘οΈ Reset Chat")

        # Direct Chat event wiring; the persona textboxes are shared with the Arena tab.
        direct_regenerate_btn.click(direct_regenerate, inputs=[model_dropdown, direct_textbox, direct_chatbot, character_name, character_description, user_name], outputs=[direct_textbox, direct_chatbot])
        direct_textbox.submit(direct_chat, inputs=[model_dropdown, direct_textbox, state, direct_chatbot, character_name, character_description, user_name], outputs=[direct_textbox, direct_chatbot, state])
        direct_submit_btn.click(direct_chat, inputs=[model_dropdown, direct_textbox, state, direct_chatbot, character_name, character_description, user_name], outputs=[direct_textbox, direct_chatbot, state])
        direct_reset_btn.click(reset_direct_chat, inputs=[state], outputs=[direct_chatbot, direct_textbox, state])

    with gr.Tab("πŸ† Leaderboard"):
        gr.Markdown("## πŸ“Š Check out the top-performing models!")
        try:
            leaderboard = gr.Dataframe(refresh_leaderboard())
        except Exception as exc:
            # Best effort: the UI must still launch when the ratings store is unavailable.
            print(f"Could not load leaderboard: {exc}")
            leaderboard = gr.Dataframe()
        with gr.Row():
            refresh_btn = gr.Button("πŸ”„ Refresh Leaderboard")
        refresh_btn.click(refresh_leaderboard, outputs=[leaderboard])

    with gr.Tab("πŸ“¨ Submit Model"):
        gr.Markdown("## πŸ“¨ Submit a new model to be added to the chatbot arena!")
        with gr.Row():
            model_name_input = gr.Textbox(placeholder="Enter the model name")
            submit_model_btn = gr.Button(value="Submit Model")
        # The status message returned by submit_model is written back into the input box.
        submit_model_btn.click(submit_model, inputs=[model_name_input], outputs=[model_name_input])
# Launch the Gradio interface
def _main():
    """Start the Gradio server for ``demo`` without creating a public share link."""
    demo.launch(share=False)


if __name__ == "__main__":
    _main()