# Sampler-Arena / app.py
# Captured from the Hugging Face Space file page (author: rwitz;
# commit bde2561 "Update app.py", verified; 13.7 kB).
import gradio as gr
import requests
import os
import pandas as pd
import json
import ssl
import random
from elo import update_elo_ratings # Custom function for ELO ratings
# Shared update payload that makes a button clickable again; user_ask returns it
# for both upvote buttons once the two model replies are shown.
enable_btn = gr.Button.update(interactive=True)
import sqlite3
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
async def direct_regenerate(model, user_input, chatbot):
    """Replace the last Direct Chat reply with a freshly generated one.

    Args:
        model: Adapter ID selected in the Direct Chat dropdown.
        user_input: Current textbox contents. When non-blank it becomes the prompt
            for the regenerated turn. When blank, the last user message shown in
            ``chatbot`` is reused. This is the normal case, because direct_chat
            clears the textbox after every message.
        chatbot: List of (user, bot) tuples shown in the Direct Chat pane.

    Returns:
        ("", chatbot) with the last turn replaced. If there is no previous turn,
        returns (user_input, chatbot) unchanged, so typed text is not lost.
    """
    if not chatbot:
        # Nothing to regenerate yet.
        return user_input, chatbot
    prompt = user_input if user_input and user_input.strip() else chatbot[-1][0]
    temp_state = {
        "history": [
            [{"role": "user", "content": prompt}],
            [{"role": "user", "content": prompt}]
        ]
    }
    response = await get_bot_response(model, prompt, temp_state, 0)
    chatbot[-1] = (prompt, response)
    return "", chatbot
async def regenerate_responses(state, chatbot1, chatbot2):
    """Regenerate both arena replies to the most recent user message.

    The previous bot replies are removed from the prompt context, so each model
    answers the last user message again rather than continuing after its old
    answer. The new replies replace the old ones in ``state["history"]`` and in
    both chat panes. Both requests run concurrently.

    Args:
        state: Session dict with 'history' ([side-A messages, side-B messages])
            and 'last_bots' (the two adapter IDs in play).
        chatbot1: (user, bot) tuples shown in the Model A pane.
        chatbot2: (user, bot) tuples shown in the Model B pane.

    Returns:
        (chatbot1, chatbot2). They are returned unchanged when there is no
        completed turn to regenerate, e.g. before the first message.
    """
    session = state or {}
    history = session.get("history")
    bots = session.get("last_bots")
    if not history or not bots or not chatbot1 or not chatbot2:
        return chatbot1, chatbot2

    # Drop a trailing bot reply from each side so the prompt ends at the user turn.
    trimmed = [
        side[:-1] if side and side[-1]["role"] != "user" else list(side)
        for side in history
    ]
    if not trimmed[0] or trimmed[0][-1]["role"] != "user":
        return chatbot1, chatbot2
    user_input = trimmed[0][-1]["content"]
    context = {"history": trimmed}

    bot1_response, bot2_response = await asyncio.gather(
        get_bot_response(bots[0], user_input, context, 0),
        get_bot_response(bots[1], user_input, context, 1),
    )

    # Overwrite the stale replies so later turns see the regenerated text.
    for side, reply in zip(history, (bot1_response, bot2_response)):
        if side and side[-1]["role"] != "user":
            side[-1]["content"] = reply

    chatbot1[-1] = (user_input, bot1_response)
    chatbot2[-1] = (user_input, bot2_response)
    return chatbot1, chatbot2
password = os.environ.get("MONGODB")

# Atlas connection string; {password} is filled in percent-escaped by _mongo_uri.
_MONGO_URI_TEMPLATE = "mongodb+srv://new-user:{password}@cluster0.xb2urf6.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"

# Process-wide client, created lazily by _get_mongo_client.
_mongo_client = None


def _mongo_uri(secret):
    """Build the MongoDB URI with ``secret`` percent-escaped (None becomes empty).

    Characters such as '@', ':' or '/' in the password would otherwise break URI
    parsing.
    """
    from urllib.parse import quote_plus
    return _MONGO_URI_TEMPLATE.format(password=quote_plus(secret or ""))


def _get_mongo_client():
    """Return the shared MongoClient, creating it on first use.

    PyMongo clients own connection pools and are meant to be created once and
    reused. Every UI event reaches init_database(), so building a new client per
    call would open pools that are never closed.
    """
    global _mongo_client
    if _mongo_client is None:
        _mongo_client = MongoClient(_mongo_uri(password))
    return _mongo_client


def reset_database():
    """Drop the ``elo_ratings`` collection, wiping all stored ratings.

    Returns:
        A status message for the UI.
    """
    _get_mongo_client()["elo_ratings"].drop_collection("elo_ratings")
    return "Database reset successfully!"


def init_database():
    """Return the ``elo_ratings`` collection of the ``elo_ratings`` database."""
    return _get_mongo_client()["elo_ratings"]["elo_ratings"]
def get_user_elo_ratings(collection):
    """Read every stored rating document into a dict keyed by bot name.

    Args:
        collection: Collection-like object whose ``find()`` yields documents
            carrying 'bot_name', 'elo_rating' and 'games_played'.

    Returns:
        {bot_name: {'elo_rating': ..., 'games_played': ...}} per document. When
        the collection is empty, returns a single "default" entry with rating
        1200 and 0 games.
    """
    ratings = {}
    for doc in collection.find():
        ratings[doc['bot_name']] = {
            'elo_rating': doc['elo_rating'],
            'games_played': doc['games_played'],
        }
    if not ratings:
        return {"default": {'elo_rating': 1200, 'games_played': 0}}
    return ratings
def update_elo_rating(collection, updated_ratings, winner, loser):
    """Persist the new rating and game count of both bots from one match.

    Each bot's document is upserted by 'bot_name'. The winner is written first,
    then the loser.

    Args:
        collection: Collection-like object exposing ``update_one``.
        updated_ratings: Mapping of bot name -> {'elo_rating', 'games_played'}.
        winner: Name of the winning bot.
        loser: Name of the losing bot.
    """
    for name in (winner, loser):
        entry = updated_ratings[name]
        collection.update_one(
            {"bot_name": name},
            {"$set": {"elo_rating": entry['elo_rating'], "games_played": entry['games_played']}},
            upsert=True,
        )
# Load the chatbot adapter IDs (one per line, e.g. "rwitz/go-bruins-v2-lora")
# from a plain-text file.
with open('chatbots.txt', 'r', encoding='utf-8') as file:
    # Strip whitespace and skip blank lines, so a trailing or empty line never
    # becomes an adapter ID that can be drawn for a match or shown in the dropdown.
    chatbots = [line.strip() for line in file if line.strip()]
def clear_chat(state):
    """Start a fresh arena match.

    Discards the previous session state and tries to pre-select two rated
    adapters for the next match. It then returns updates that empty both panes,
    disable voting, and re-enable text entry.

    Args:
        state: Previous session dict (replaced entirely; may be None).

    Returns:
        Tuple matching the reset button outputs: (state, chatbot1, chatbot2,
        upvote_btn_a, upvote_btn_b, textbox, submit_btn).
    """
    # Always start from an empty dict so a None state cannot break the
    # assignment below.
    state = {}
    collection = init_database()
    # get_user_elo_ratings returns a lone "default" placeholder when nothing is
    # stored yet; that is not a real adapter and must never be matched.
    bot_names = [name for name in get_user_elo_ratings(collection) if name != "default"]
    if len(bot_names) >= 2:
        # Randomly select two new Loras
        state['last_bots'] = random.sample(bot_names, 2)
    # With fewer than two rated bots, 'last_bots' stays unset and chat_with_bots
    # picks a random pair from chatbots.txt on the first message.
    return state, [], [], gr.Button.update(interactive=False), gr.Button.update(interactive=False), gr.Textbox.update(value='', interactive=True), gr.Button.update(interactive=True)
# Module-level placeholder; it is not read or assigned anywhere else in this file.
global_elo_ratings=None
from datasets import load_dataset,DatasetDict,Dataset
import requests
import os
# Alpaca instruction preamble placed at the top of every prompt.
_ALPACA_PREAMBLE = ("Below is an instruction that describes a task. "
                    "Write a response that appropriately completes the request.")


def _alpaca_turns(messages):
    """Render one side's messages as alternating Instruction/Response sections."""
    sections = []
    for message in messages:
        header = "### Instruction:\n" if message['role'] == 'user' else "### Response:\n"
        sections.append(header + message['content'] + "\n\n")
    return "".join(sections)


def format_alpaca_prompt(state):
    """Build one Alpaca-format prompt per conversation side.

    The Alpaca template separates the preamble from the first
    "### Instruction:" section with a blank line. Each prompt ends with an open
    "### Response:" header for the model to complete.

    Args:
        state: Dict whose "history" is a list of per-side message lists; each
            message has 'role' ('user' or a bot role) and 'content'.

    Returns:
        A list with one prompt string per history side (two in the arena).
    """
    return [
        _ALPACA_PREAMBLE + "\n\n" + _alpaca_turns(side) + "### Response:\n"
        for side in state["history"]
    ]
import aiohttp
import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential
async def get_bot_response(adapter_id, prompt, state, bot_index):
    """Generate one model reply through the Predibase LoRA deployment.

    Args:
        adapter_id: Adapter ID sent as ``adapter_id`` with ``adapter_source`` "hub".
        prompt: Not used. The prompt is rebuilt from ``state["history"]``; the
            parameter is kept so existing call sites keep working.
        state: Dict whose "history" holds the per-side message lists consumed by
            ``format_alpaca_prompt``.
        bot_index: Which history side (0 or 1) to build the prompt from.

    Returns:
        The generated text cut at the first "### Instruction" marker. Returns
        "Sorry, I couldn't generate a response." on a non-200 status, an
        unexpected payload, or a network error/timeout.
    """
    alpaca_prompt = format_alpaca_prompt(state)
    print(alpaca_prompt)
    payload = {
        "inputs": alpaca_prompt[bot_index],
        "parameters": {
            "adapter_id": adapter_id,
            "adapter_source": "hub",
            "temperature": 0.7,
            "max_new_tokens": 100
        }
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('PREDIBASE_TOKEN')}"
    }
    response_text = "Sorry, I couldn't generate a response."
    async with aiohttp.ClientSession() as session:
        try:
            async with session.post("https://serving.app.predibase.com/79957f/deployments/v2/llms/mistral-7b/generate",
                                    json=payload, headers=headers,
                                    timeout=aiohttp.ClientTimeout(total=30)) as response:
                if response.status == 200:
                    response_data = await response.json()
                    if isinstance(response_data, dict):
                        # A missing or null field counts as an empty generation.
                        response_text = response_data.get('generated_text') or ''
                else:
                    # ``response.text`` is a coroutine method; await it to log the body.
                    print(response.status, await response.text())
        except (aiohttp.ClientError, asyncio.TimeoutError) as err:
            print(f"Predibase request failed: {err!r}")
    return response_text.split('### Instruction')[0]
async def chat_with_bots(user_input, state):
    """Ask both arena models for a reply to the current conversation.

    If the session has no matched pair yet, a random ordered pair is drawn from
    ``chatbots`` and stored in ``state['last_bots']``. Both requests run
    concurrently.

    Args:
        user_input: The latest user message. It is forwarded, but
            get_bot_response builds prompts from ``state["history"]``.
        state: Session dict with 'history' and, optionally, 'last_bots'.

    Returns:
        (bot1_response, bot2_response) for Model A and Model B.
    """
    if not state.get('last_bots'):
        # random.sample picks a uniformly random ordered pair without reordering
        # the shared module-level list, which also feeds the Direct Chat dropdown.
        state['last_bots'] = random.sample(chatbots, 2)
    bot1_adapter, bot2_adapter = state['last_bots'][0], state['last_bots'][1]
    bot1_response, bot2_response = await asyncio.gather(
        get_bot_response(bot1_adapter, user_input, state, 0),
        get_bot_response(bot2_adapter, user_input, state, 1)
    )
    return bot1_response, bot2_response
def update_ratings(state, winner_index, collection):
    """Apply one arena vote to the stored ELO ratings.

    Args:
        state: Session dict whose 'last_bots' holds the two adapter IDs shown.
        winner_index: 0 if Model A won, 1 if Model B won.
        collection: Ratings collection that is read, then updated.

    Returns:
        [('Winner: ', name), ('Loser: ', name)], with every "rwitz/" and "-lora"
        substring removed from the adapter IDs for display.
    """
    current = get_user_elo_ratings(collection)
    pair = state['last_bots']
    winner, loser = pair[winner_index], pair[1 - winner_index]
    new_ratings = update_elo_ratings(current, winner, loser)
    update_elo_rating(collection, new_ratings, winner, loser)

    def display(adapter):
        return adapter.replace('rwitz/', '').replace('-lora', '')

    return [('Winner: ', display(winner)), ('Loser: ', display(loser))]
def _record_vote(state, chatbot, chatbot2, winner_index):
    """Store a vote for side ``winner_index`` and lock the arena controls.

    The winner's pane receives the "Winner: " tuple first, then the loser's
    pane receives the "Loser: " tuple.
    """
    result = update_ratings(state, winner_index, init_database())
    panes = (chatbot, chatbot2)
    panes[winner_index].append(result[0])
    panes[1 - winner_index].append(result[1])
    return (
        chatbot,
        chatbot2,
        gr.Button.update(interactive=False),
        gr.Button.update(interactive=False),
        gr.Textbox.update(interactive=False),
        gr.Button.update(interactive=False),
    )


def vote_up_model(state, chatbot, chatbot2):
    """Record a win for Model A (left pane); wired to the "Upvote A" button."""
    return _record_vote(state, chatbot, chatbot2, 0)


def vote_down_model(state, chatbot, chatbot2):
    """Record a win for Model B (right pane); wired to the "Upvote B" button."""
    return _record_vote(state, chatbot, chatbot2, 1)
async def user_ask(state, chatbot1, chatbot2, textbox):
    """Send the user's prompt to both arena models and show their replies.

    Args:
        state: Per-session dict. It gains 'elo_ratings' (a fresh ratings
            snapshot) and 'history' ([side-A messages, side-B messages]).
            chat_with_bots fills 'last_bots' if no pair has been chosen yet.
        chatbot1: (user, bot) tuples shown in the Model A pane.
        chatbot2: (user, bot) tuples shown in the Model B pane.
        textbox: Raw prompt text; only the first 200 characters are used.

    Returns:
        (state, chatbot1, chatbot2, textbox update, upvote A update, upvote B update).
        A blank prompt is not sent: the panes and vote buttons are left as they are.
    """
    user_input = (textbox or "")[:200]  # Limit user input to 200 characters
    if not user_input.strip():
        return state, chatbot1, chatbot2, gr.update(value=''), gr.update(), gr.update()

    collection = init_database()
    state["elo_ratings"] = get_user_elo_ratings(collection)

    history = state.setdefault("history", [[], []])
    history[0].append({"role": "user", "content": user_input})
    history[1].append({"role": "user", "content": user_input})

    bot1_response, bot2_response = await chat_with_bots(user_input, state)

    history[0].append({"role": "bot1", "content": bot1_response})
    history[1].append({"role": "bot2", "content": bot2_response})
    chatbot1.append((user_input, bot1_response))
    chatbot2.append((user_input, bot2_response))

    # Keep only the last 10 messages of EACH side. The outer list always has two
    # entries, so slicing it would trim nothing. Each turn adds a user+bot pair,
    # so the trimmed sides normally still start with a user message.
    state["history"] = [history[0][-10:], history[1][-10:]]
    return state, chatbot1, chatbot2, gr.update(value=''), enable_btn, enable_btn
import pandas as pd
def generate_leaderboard(collection):
    """Build the leaderboard table from every stored rating document.

    Args:
        collection: Collection-like object whose ``find()`` yields rating documents.

    Returns:
        DataFrame with columns 'Chatbot', 'ELO Score' (rounded to int) and
        'Games Played', sorted by ELO Score from highest to lowest.
    """
    column_map = {'bot_name': 'Chatbot', 'elo_rating': 'ELO Score', 'games_played': 'Games Played'}
    table = pd.DataFrame(list(collection.find()), columns=list(column_map)).rename(columns=column_map)
    table['ELO Score'] = table['ELO Score'].round().astype(int)
    return table.sort_values('ELO Score', ascending=False)
def refresh_leaderboard():
    """Query the ratings collection and return a fresh leaderboard DataFrame."""
    return generate_leaderboard(init_database())
async def direct_chat(model, user_input, chatbot):
    """Send one single-turn message to the selected model in Direct Chat.

    Only the current message is used as the prompt; earlier turns are not sent.

    Args:
        model: Adapter ID selected in the Direct Chat dropdown.
        user_input: Message text from the textbox.
        chatbot: List of (user, bot) tuples shown in the Direct Chat pane.

    Returns:
        ("", chatbot). This clears the textbox and appends the new turn. Blank
        input is not sent and leaves ``chatbot`` unchanged.
    """
    if not (user_input or "").strip():
        return "", chatbot
    temp_state = {
        "history": [
            [{"role": "user", "content": user_input}],
            [{"role": "user", "content": user_input}]
        ]
    }
    response = await get_bot_response(model, user_input, temp_state, 0)
    chatbot.append((user_input, response))
    return "", chatbot
def reset_direct_chat():
    """Reset the Direct Chat tab.

    Clears the textbox and the conversation, and sets the dropdown back to
    ``model_dropdown.value`` (its configured default; presumably
    "rwitz/go-bruins-v2-lora", where the dropdown is built).
    """
    initial_model = model_dropdown.value
    return "", [], gr.Dropdown.update(value=initial_model)
# Run one leaderboard query at import time (the result is discarded). A
# database failure is logged instead of aborting startup, matching the
# Leaderboard tab's own fallback below.
try:
    refresh_leaderboard()
except Exception as err:
    print(f"Initial leaderboard query failed: {err!r}")

# Gradio interface setup
with gr.Blocks() as demo:
    # Per-session arena state; the handlers fill 'history', 'last_bots', 'elo_ratings'.
    state = gr.State({})

    with gr.Tab("🤖 Chatbot Arena"):
        gr.Markdown("## 🥊 Let's see which chatbot wins!")
        with gr.Row():
            with gr.Column():
                chatbot1 = gr.Chatbot(label='🤖 Model A').style(height=500)
                upvote_btn_a = gr.Button(value="👍 Upvote A", interactive=False).style(full_width=True)
            with gr.Column():
                chatbot2 = gr.Chatbot(label='🤖 Model B').style(height=500)
                upvote_btn_b = gr.Button(value="👍 Upvote B", interactive=False).style(full_width=True)
        with gr.Row():
            with gr.Column(scale=5):
                textbox = gr.Textbox(placeholder="🎀 Enter your prompt (up to 200 characters)")
            submit_btn = gr.Button(value="Submit")
        with gr.Row():
            regenerate_btn = gr.Button(value="🔄 Regenerate")
            reset_btn = gr.Button(value="🗑️ Reset")

        regenerate_btn.click(regenerate_responses, inputs=[state, chatbot1, chatbot2],
                             outputs=[chatbot1, chatbot2])
        reset_btn.click(clear_chat, inputs=[state],
                        outputs=[state, chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
        submit_btn.click(user_ask, inputs=[state, chatbot1, chatbot2, textbox],
                         outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True)
        upvote_btn_a.click(vote_up_model, inputs=[state, chatbot1, chatbot2],
                           outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
        upvote_btn_b.click(vote_down_model, inputs=[state, chatbot1, chatbot2],
                           outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
        textbox.submit(user_ask, inputs=[state, chatbot1, chatbot2, textbox],
                       outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True)

    with gr.Tab("💬 Direct Chat"):
        gr.Markdown("## 🗣️ Chat directly with a model!")
        with gr.Row():
            model_dropdown = gr.Dropdown(choices=chatbots, value="rwitz/go-bruins-v2-lora", label="🤖 Select a model")
        with gr.Row():
            direct_chatbot = gr.Chatbot(label="💬 Direct Chat").style(height=500)
        with gr.Row():
            with gr.Column(scale=5):
                direct_textbox = gr.Textbox(placeholder="💭 Enter your message")
            direct_submit_btn = gr.Button(value="Submit")
        with gr.Row():
            direct_regenerate_btn = gr.Button(value="🔄 Regenerate")
            direct_reset_btn = gr.Button(value="🗑️ Reset Chat")

        direct_regenerate_btn.click(direct_regenerate, inputs=[model_dropdown, direct_textbox, direct_chatbot],
                                    outputs=[direct_textbox, direct_chatbot])
        direct_textbox.submit(direct_chat, inputs=[model_dropdown, direct_textbox, direct_chatbot],
                              outputs=[direct_textbox, direct_chatbot])
        direct_submit_btn.click(direct_chat, inputs=[model_dropdown, direct_textbox, direct_chatbot],
                                outputs=[direct_textbox, direct_chatbot])
        direct_reset_btn.click(reset_direct_chat, None, [direct_textbox, direct_chatbot, model_dropdown])

    with gr.Tab("🏆 Leaderboard"):
        gr.Markdown("## 📊 Check out the top-performing models!")
        try:
            leaderboard = gr.Dataframe(refresh_leaderboard())
        except Exception:
            # Keep the tab usable (empty table) when the database is unreachable.
            leaderboard = gr.Dataframe()
        with gr.Row():
            refresh_btn = gr.Button("🔄 Refresh Leaderboard")
            # NOTE(review): this button drops the ratings collection for anyone who
            # can open the app (launch() below sets no auth); consider gating it.
            reset_db_btn = gr.Button("🗑️ Reset Database")
        reset_message = gr.Textbox()
        reset_db_btn.click(reset_database, outputs=[reset_message])
        refresh_btn.click(refresh_leaderboard, outputs=[leaderboard])

# Launch the Gradio interface
if __name__ == "__main__":
    demo.launch(share=False)