# Hugging Face Spaces status shown when this file was captured: Runtime error
import gradio as gr | |
import requests | |
import os | |
import pandas as pd | |
import json | |
import ssl | |
import random | |
from elo import update_elo_ratings # Custom function for ELO ratings | |
# Button update that re-enables a button; user_ask returns it for both vote buttons.
enable_btn = gr.Button.update(interactive=True)
import sqlite3 | |
from pymongo.mongo_client import MongoClient | |
from pymongo.server_api import ServerApi | |
async def direct_regenerate(model, user_input, chatbot):
    """Ask ``model`` again for the last Direct Chat turn and replace its reply.

    ``direct_chat`` clears the textbox after every send, so ``user_input`` is
    normally empty when Regenerate is clicked. The last user message shown in
    ``chatbot`` is therefore re-sent instead of the textbox contents.

    Args:
        model: adapter id selected in the dropdown.
        user_input: current textbox value (not used as the prompt).
        chatbot: list of ``(user_message, bot_reply)`` tuples.

    Returns:
        ``("", chatbot)``. The transcript is returned unchanged when it is empty.
    """
    if not chatbot:
        # Nothing has been said yet, so there is no reply to regenerate.
        return "", chatbot
    last_user_message = chatbot[-1][0]
    temp_state = {
        "history": [
            [{"role": "user", "content": last_user_message}],
            [{"role": "user", "content": last_user_message}],
        ]
    }
    response = await get_bot_response(model, last_user_message, temp_state, 0)
    chatbot[-1] = (last_user_message, response)
    return "", chatbot
async def regenerate_responses(state, chatbot1, chatbot2):
    """Re-ask both arena models for their answers to the latest user message.

    The prompt is rebuilt from each model's history *without* its previous
    reply. Otherwise the old answer would be embedded in the prompt ahead of a
    fresh "### Response:" slot. The new replies overwrite the last bot entry
    in ``state["history"]`` and the last tuple in each transcript.

    Args:
        state: session dict holding ``"history"`` and ``"last_bots"``.
        chatbot1: Model A transcript, a list of ``(user, bot)`` tuples.
        chatbot2: Model B transcript, a list of ``(user, bot)`` tuples.

    Returns:
        ``(chatbot1, chatbot2)``. Both are unchanged when no exchange has
        happened yet.
    """
    history = state.get("history") if state else None
    bots = state.get("last_bots") if state else None
    if (not history or not bots or len(history[0]) < 2 or len(history[1]) < 2
            or not chatbot1 or not chatbot2):
        # No completed exchange to regenerate.
        return chatbot1, chatbot2
    user_input = history[0][-2]["content"]
    # Drop the last bot turn so each prompt ends at the user's message.
    prompt_state = {"history": [history[0][:-1], history[1][:-1]]}
    bot1_response, bot2_response = await asyncio.gather(
        get_bot_response(bots[0], user_input, prompt_state, 0),
        get_bot_response(bots[1], user_input, prompt_state, 1),
    )
    # Keep the stored conversation consistent with what is displayed.
    history[0][-1]["content"] = bot1_response
    history[1][-1]["content"] = bot2_response
    chatbot1[-1] = (user_input, bot1_response)
    chatbot2[-1] = (user_input, bot2_response)
    return chatbot1, chatbot2
# MongoDB password for the ``new-user`` account, read from the environment (may be None).
password = os.environ.get("MONGODB")


def reset_database():
    """Drop the ``elo_ratings`` collection, deleting every stored rating.

    Returns:
        A status string shown in the UI textbox.
    """
    from urllib.parse import quote_plus

    # Credentials embedded in a MongoDB URI must be percent-escaped.
    uri = f"mongodb+srv://new-user:{quote_plus(password or '')}@cluster0.xb2urf6.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
    # NOTE(review): this is wired to a public button with no authentication,
    # so any visitor can wipe the leaderboard. Consider gating it.
    # The context manager closes the client (and its connection pool) afterwards.
    with MongoClient(uri) as client:
        client["elo_ratings"].drop_collection("elo_ratings")
    return "Database reset successfully!"
# Process-wide MongoClient, created lazily on first use.
_mongo_client = None


def _get_mongo_client():
    """Return the shared MongoClient, creating it on the first call."""
    global _mongo_client
    if _mongo_client is None:
        from urllib.parse import quote_plus

        # Credentials embedded in a MongoDB URI must be percent-escaped.
        uri = f"mongodb+srv://new-user:{quote_plus(password or '')}@cluster0.xb2urf6.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
        _mongo_client = MongoClient(uri)
    return _mongo_client


def init_database():
    """Return the ``elo_ratings`` collection in the ``elo_ratings`` database.

    Reuses one client instead of opening a new, never-closed client on every
    call. This function is called on each chat, vote, reset and refresh.
    """
    return _get_mongo_client()["elo_ratings"]["elo_ratings"]
def get_user_elo_ratings(collection):
    """Map every stored bot to its rating record.

    Returns ``{bot_name: {'elo_rating': ..., 'games_played': ...}}`` built from
    ``collection.find()``. If the collection has no documents, returns a
    single placeholder entry keyed ``"default"`` (rating 1200, 0 games).
    """
    ratings = {}
    for doc in collection.find():
        ratings[doc['bot_name']] = {
            'elo_rating': doc['elo_rating'],
            'games_played': doc['games_played'],
        }
    if not ratings:
        return {"default": {'elo_rating': 1200, 'games_played': 0}}
    return ratings
def update_elo_rating(collection, updated_ratings, winner, loser):
    """Upsert the rating and game count of the winner, then the loser.

    ``updated_ratings`` must contain entries for both names with
    ``'elo_rating'`` and ``'games_played'`` keys.
    """
    for bot_name in (winner, loser):
        record = updated_ratings[bot_name]
        fields = {"elo_rating": record['elo_rating'], "games_played": record['games_played']}
        collection.update_one({"bot_name": bot_name}, {"$set": fields}, upsert=True)
# Load the chatbot LoRA adapter ids (one per line, e.g. "rwitz/go-bruins-v2-lora")
# from a text file.
with open('chatbots.txt', 'r', encoding='utf-8') as file:
    # Skip blank or whitespace-only lines so an empty adapter id can never be drawn.
    chatbots = [line.strip() for line in file if line.strip()]
def clear_chat(state):
    """Start a fresh arena round with two newly drawn models.

    Args:
        state: the previous session state (discarded; may be None).

    Returns:
        In the order of the Reset button's outputs:
        - the new state, holding only ``'last_bots'``;
        - empty transcripts for both chat panes;
        - both vote buttons disabled;
        - a cleared, editable textbox;
        - an enabled submit button.
    """
    # Always start from an empty dict. The previous expression kept ``None``
    # when state was ``None`` and then failed on item assignment.
    state = {}
    collection = init_database()
    # get_user_elo_ratings returns a {"default": ...} placeholder for an empty
    # collection. That is not a real adapter, so exclude it.
    bot_names = [name for name in get_user_elo_ratings(collection) if name != "default"]
    if len(bot_names) < 2:
        # Fresh or reset database: fall back to the adapters from chatbots.txt
        # instead of letting random.sample raise ValueError.
        bot_names = list(chatbots)
    state['last_bots'] = random.sample(bot_names, 2)
    return (
        state,
        [],
        [],
        gr.Button.update(interactive=False),
        gr.Button.update(interactive=False),
        gr.Textbox.update(value='', interactive=True),
        gr.Button.update(interactive=True),
    )


# NOTE(review): not read or written anywhere else in this file; candidate for removal.
global_elo_ratings = None
from datasets import load_dataset,DatasetDict,Dataset | |
import requests | |
import os | |
# Header line of the Alpaca instruction-following prompt template.
_ALPACA_HEADER = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)


def _alpaca_from_messages(messages):
    """Render one model's message list as an Alpaca prompt ending in an open response slot."""
    turns = []
    for message in messages:
        # Only 'user' turns are instructions; every other role is a model reply.
        tag = "### Instruction:\n" if message['role'] == 'user' else "### Response:\n"
        turns.append(tag + message['content'] + "\n\n")
    # A blank line separates the header from the first section, as in the Alpaca template.
    return _ALPACA_HEADER + "\n\n" + "".join(turns) + "### Response:\n"


def format_alpaca_prompt(state):
    """Build the Alpaca-format prompts for both arena models.

    Args:
        state: dict whose ``"history"`` is a two-element list of message lists.
            Each message is a ``{"role": ..., "content": ...}`` dict.

    Returns:
        ``[prompt_for_model_0, prompt_for_model_1]``. Each prompt ends with an
        empty "### Response:" section for the model to complete.
    """
    history = state["history"]
    return [_alpaca_from_messages(history[0]), _alpaca_from_messages(history[1])]
import aiohttp | |
import asyncio | |
from tenacity import retry, stop_after_attempt, wait_exponential | |
async def get_bot_response(adapter_id, prompt, state, bot_index):
    """Generate one reply from a Predibase-served LoRA adapter.

    Args:
        adapter_id: adapter repo id passed to Predibase as ``adapter_id``.
        prompt: the raw user message. It is unused here because the full prompt
            is rebuilt from ``state["history"]``.
        state: dict whose ``"history"`` holds the two models' conversations.
        bot_index: which formatted prompt (0 or 1) to send.

    Returns:
        The generated text, cut at the first ``### Instruction`` marker.
        Returns an apology string on a non-200 status, a client error or a
        timeout.
    """
    alpaca_prompt = format_alpaca_prompt(state)
    print(alpaca_prompt)
    payload = {
        "inputs": alpaca_prompt[bot_index],
        "parameters": {
            "adapter_id": adapter_id,
            "adapter_source": "hub",
            "temperature": 0.7,
            "max_new_tokens": 100,
        },
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('PREDIBASE_TOKEN')}",
    }
    # aiohttp expects a ClientTimeout object; bare numeric timeouts are deprecated.
    timeout = aiohttp.ClientTimeout(total=30)
    async with aiohttp.ClientSession() as session:
        try:
            async with session.post("https://serving.app.predibase.com/79957f/deployments/v2/llms/mistral-7b/generate",
                                    json=payload, headers=headers, timeout=timeout) as response:
                if response.status == 200:
                    response_data = await response.json()
                    response_text = response_data.get('generated_text', '')
                else:
                    # ``response.text`` is a coroutine method; it must be awaited
                    # to log the error body (printing it unawaited shows a method repr).
                    print(response.status, await response.text())
                    response_text = "Sorry, I couldn't generate a response."
        except (aiohttp.ClientError, asyncio.TimeoutError):
            response_text = "Sorry, I couldn't generate a response."
    # The model may continue with a new instruction turn; keep only its answer.
    return response_text.split('### Instruction')[0]
async def chat_with_bots(user_input, state):
    """Query both arena models concurrently for the current conversation.

    If ``state`` has no (or an empty) ``'last_bots'`` entry, two distinct
    adapters are drawn from ``chatbots`` and stored there first.

    Returns:
        ``(bot1_response, bot2_response)``.
    """
    if not state.get('last_bots'):
        # Sample without shuffling the shared module-level list in place.
        # That list is global to every session and also feeds the Direct Chat dropdown.
        state['last_bots'] = random.sample(chatbots, 2)
    bot1_adapter, bot2_adapter = state['last_bots'][0], state['last_bots'][1]
    bot1_response, bot2_response = await asyncio.gather(
        get_bot_response(bot1_adapter, user_input, state, 0),
        get_bot_response(bot2_adapter, user_input, state, 1),
    )
    return bot1_response, bot2_response
def update_ratings(state, winner_index, collection):
    """Apply one arena vote to the stored ELO ratings.

    Args:
        state: session dict whose ``'last_bots'`` holds the two adapter ids.
        winner_index: index (0 or 1) of the winning model in ``'last_bots'``.
        collection: the ratings collection.

    Returns:
        ``[('Winner: ', name), ('Loser: ', name)]`` with shortened display names.
    """
    def _display_name(adapter):
        # Remove every 'rwitz/' and '-lora' occurrence for a compact label.
        return adapter.replace('rwitz/', '').replace('-lora', '')

    current = get_user_elo_ratings(collection)
    pair = state['last_bots']
    winner = pair[winner_index]
    loser = pair[1 - winner_index]
    new_ratings = update_elo_ratings(current, winner, loser)
    update_elo_rating(collection, new_ratings, winner, loser)
    return [('Winner: ', _display_name(winner)), ('Loser: ', _display_name(loser))]
def _record_vote(state, chatbot, chatbot2, winner_index):
    """Persist one vote, post the result in both panes, and lock the round.

    ``winner_index`` is 0 when Model A won and 1 when Model B won. Each pane
    receives the result line that describes its own model.
    """
    collection = init_database()
    winner_line, loser_line = update_ratings(state, winner_index, collection)
    a_line, b_line = (winner_line, loser_line) if winner_index == 0 else (loser_line, winner_line)
    chatbot.append(a_line)
    chatbot2.append(b_line)
    # Disable both vote buttons, the textbox and submit until Reset (clear_chat) re-enables input.
    return (
        chatbot,
        chatbot2,
        gr.Button.update(interactive=False),
        gr.Button.update(interactive=False),
        gr.Textbox.update(interactive=False),
        gr.Button.update(interactive=False),
    )


def vote_up_model(state, chatbot, chatbot2):
    """Record a vote for Model A (the left pane)."""
    return _record_vote(state, chatbot, chatbot2, 0)


def vote_down_model(state, chatbot, chatbot2):
    """Record a vote for Model B (the right pane)."""
    return _record_vote(state, chatbot, chatbot2, 1)
async def user_ask(state, chatbot1, chatbot2, textbox):
    """Send the user's prompt to both arena models and append their replies.

    Args:
        state: session dict. ``"history"`` and ``"elo_ratings"`` are updated in place.
        chatbot1: Model A transcript, a list of ``(user, bot)`` tuples.
        chatbot2: Model B transcript, a list of ``(user, bot)`` tuples.
        textbox: the raw prompt text. Only the first 200 characters are used.

    Returns:
        ``(state, chatbot1, chatbot2, textbox_update, vote_a_update, vote_b_update)``.
        Blank input is ignored: the transcripts are returned unchanged and the
        vote buttons are left as they were.
    """
    user_input = (textbox or "")[:200]  # Limit user input to 200 characters
    if not user_input.strip():
        # Don't call the model API for an empty prompt, and don't enable voting.
        return state, chatbot1, chatbot2, gr.update(value=''), gr.update(), gr.update()
    collection = init_database()
    # Updating state with the current ELO ratings
    state["elo_ratings"] = get_user_elo_ratings(collection)
    history = state.setdefault("history", [[], []])
    history[0].append({"role": "user", "content": user_input})
    history[1].append({"role": "user", "content": user_input})
    # Both prompts are built from state["history"], which already holds the new user turn.
    bot1_response, bot2_response = await chat_with_bots(user_input, state)
    history[0].append({"role": "bot1", "content": bot1_response})
    history[1].append({"role": "bot2", "content": bot2_response})
    chatbot1.append((user_input, bot1_response))
    chatbot2.append((user_input, bot2_response))
    # Keep only the last 10 messages of *each* model's history. Slicing the
    # outer two-element list was a no-op, so prompts grew without bound.
    state["history"] = [messages[-10:] for messages in history]
    return state, chatbot1, chatbot2, gr.update(value=''), enable_btn, enable_btn
import pandas as pd | |
# Function to generate leaderboard data
def generate_leaderboard(collection):
    """Build the leaderboard table from the ratings collection.

    Returns:
        A DataFrame with columns ``Chatbot``, ``ELO Score`` (rounded to int) and
        ``Games Played``, sorted by score, highest first. For an empty
        collection, returns an empty frame with the same columns.
    """
    rows = list(collection.find())
    leaderboard_data = pd.DataFrame(rows, columns=['bot_name', 'elo_rating', 'games_played'])
    leaderboard_data.columns = ['Chatbot', 'ELO Score', 'Games Played']
    if leaderboard_data.empty:
        # Empty columns have object dtype, and rounding/astype on them can fail,
        # so skip the numeric conversion.
        return leaderboard_data
    leaderboard_data['ELO Score'] = leaderboard_data['ELO Score'].round().astype(int)
    return leaderboard_data.sort_values('ELO Score', ascending=False)
def refresh_leaderboard():
    """Return a freshly generated leaderboard DataFrame from the ratings collection."""
    return generate_leaderboard(init_database())
async def direct_chat(model, user_input, chatbot):
    """Send one Direct Chat message to ``model`` and append the reply.

    Only the current message is sent; earlier turns are not included in the
    prompt. Blank input is ignored, so no empty prompt reaches the endpoint.

    Args:
        model: adapter id selected in the dropdown.
        user_input: the message typed by the user.
        chatbot: list of ``(user_message, bot_reply)`` tuples.

    Returns:
        ``("", chatbot)``: a cleared textbox and the (possibly updated) transcript.
    """
    if not user_input or not user_input.strip():
        return "", chatbot
    temp_state = {
        "history": [
            [{"role": "user", "content": user_input}],
            [{"role": "user", "content": user_input}],
        ]
    }
    response = await get_bot_response(model, user_input, temp_state, 0)
    chatbot.append((user_input, response))
    return "", chatbot
def reset_direct_chat():
    """Clear the Direct Chat tab.

    Returns an empty textbox value, an empty transcript, and a dropdown update
    that restores ``model_dropdown``'s configured value.
    """
    configured_model = model_dropdown.value
    return "", [], gr.Dropdown.update(value=configured_model)
# Gradio interface setup.
# No database calls happen here at import time: the Leaderboard tab loads its
# data inside a try block, so an unreachable MongoDB cannot stop the UI from building.
with gr.Blocks() as demo:
    state = gr.State({})
    with gr.Tab("π€ Chatbot Arena"):
        gr.Markdown("## π₯ Let's see which chatbot wins!")
        with gr.Row():
            with gr.Column():
                chatbot1 = gr.Chatbot(label='π€ Model A').style(height=500)
                upvote_btn_a = gr.Button(value="π Upvote A", interactive=False).style(full_width=True)
            with gr.Column():
                chatbot2 = gr.Chatbot(label='π€ Model B').style(height=500)
                upvote_btn_b = gr.Button(value="π Upvote B", interactive=False).style(full_width=True)
        with gr.Row():
            with gr.Column(scale=5):
                textbox = gr.Textbox(placeholder="π€ Enter your prompt (up to 200 characters)")
            submit_btn = gr.Button(value="Submit")
        with gr.Row():
            regenerate_btn = gr.Button(value="π Regenerate")
            reset_btn = gr.Button(value="ποΈ Reset")
        # Arena event wiring.
        regenerate_btn.click(regenerate_responses, inputs=[state, chatbot1, chatbot2], outputs=[chatbot1, chatbot2])
        reset_btn.click(clear_chat, inputs=[state], outputs=[state, chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
        submit_btn.click(user_ask, inputs=[state, chatbot1, chatbot2, textbox], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True)
        upvote_btn_a.click(vote_up_model, inputs=[state, chatbot1, chatbot2], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
        upvote_btn_b.click(vote_down_model, inputs=[state, chatbot1, chatbot2], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
        textbox.submit(user_ask, inputs=[state, chatbot1, chatbot2, textbox], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True)
    with gr.Tab("π¬ Direct Chat"):
        gr.Markdown("## π£οΈ Chat directly with a model!")
        with gr.Row():
            model_dropdown = gr.Dropdown(choices=chatbots, value="rwitz/go-bruins-v2-lora", label="π€ Select a model")
        with gr.Row():
            direct_chatbot = gr.Chatbot(label="π¬ Direct Chat").style(height=500)
        with gr.Row():
            with gr.Column(scale=5):
                direct_textbox = gr.Textbox(placeholder="π Enter your message")
            direct_submit_btn = gr.Button(value="Submit")
        with gr.Row():
            direct_regenerate_btn = gr.Button(value="π Regenerate")
            direct_reset_btn = gr.Button(value="ποΈ Reset Chat")
        # Direct Chat event wiring.
        direct_regenerate_btn.click(direct_regenerate, inputs=[model_dropdown, direct_textbox, direct_chatbot], outputs=[direct_textbox, direct_chatbot])
        direct_textbox.submit(direct_chat, inputs=[model_dropdown, direct_textbox, direct_chatbot], outputs=[direct_textbox, direct_chatbot])
        direct_submit_btn.click(direct_chat, inputs=[model_dropdown, direct_textbox, direct_chatbot], outputs=[direct_textbox, direct_chatbot])
        direct_reset_btn.click(reset_direct_chat, None, [direct_textbox, direct_chatbot, model_dropdown])
    with gr.Tab("π Leaderboard"):
        gr.Markdown("## π Check out the top-performing models!")
        try:
            leaderboard = gr.Dataframe(refresh_leaderboard())
        except Exception as err:
            # Best effort: show an empty table if the database is unreachable at startup.
            print(f"Could not load leaderboard: {err}")
            leaderboard = gr.Dataframe()
        with gr.Row():
            refresh_btn = gr.Button("π Refresh Leaderboard")
            reset_db_btn = gr.Button("ποΈ Reset Database")
        reset_message = gr.Textbox()
        # NOTE(review): any visitor can click this and drop all ratings; consider gating it.
        reset_db_btn.click(reset_database, outputs=[reset_message])
        refresh_btn.click(refresh_leaderboard, outputs=[leaderboard])
# Launch the Gradio interface
if __name__ == "__main__":
    demo.launch(share=False)