# Spaces:
# Runtime error
# Runtime error
import gradio as gr | |
import requests | |
import os | |
import pandas as pd | |
import json | |
import ssl | |
import random | |
from elo import update_elo_ratings # Custom function for ELO ratings | |
# Shared Gradio update that makes a component interactive again; user_ask returns it for
# both upvote buttons after each arena turn. gr.update is the generic form that works in
# Gradio 3 and 4 (component-specific .update() methods no longer exist in Gradio 4).
enable_btn = gr.update(interactive=True)
import sqlite3 | |
from pymongo.mongo_client import MongoClient | |
from pymongo.server_api import ServerApi | |
async def direct_regenerate(model, user_input, chatbot, character_name, character_description, user_name):
    """Regenerate the model's reply to the latest message in the Direct Chat tab.

    Args:
        model: display name (``original_model``) selected in the dropdown.
        user_input: current textbox contents. It is forwarded to ``get_bot_response``, but
            the prompt is built from the temporary history below.
        chatbot: list of ``(user_message, bot_reply)`` pairs shown in the chat.
        character_name, character_description, user_name: persona settings forwarded to
            the prompt builder.

    Returns:
        tuple: ``("", chatbot)``. The empty string clears the textbox, and the history has
        its last reply replaced. An empty (or ``None``) history is returned unchanged.

    Raises:
        ValueError: if ``model`` is not in the chatbot registry.
    """
    if not chatbot:
        # Nothing to regenerate yet; indexing chatbot[-1] would raise IndexError.
        return "", chatbot
    adapter = next((entry['adapter'] for entry in chatbots_data if entry['original_model'] == model), None)
    if adapter is None:
        # A bare next() would raise StopIteration, which Python converts into an opaque
        # RuntimeError inside a coroutine.
        raise ValueError(f"Unknown model: {model!r}")
    last_user_message = chatbot[-1][0]
    # Only the user's last message is used as context (same as direct_chat).
    temp_state = {"history": [[{"role": "user", "content": last_user_message}]]}
    response = await get_bot_response(adapter, user_input, temp_state, 0, character_name, character_description, user_name)
    chatbot[-1] = (last_user_message, response)  # replace only the assistant's reply
    return "", chatbot
# MongoDB password from the MONGODB environment variable (None when unset).
password = os.environ.get("MONGODB")
# Process-wide MongoClient, created by the first init_database() call and reused afterwards.
_mongo_client = None


def init_database():
    """Return the ``elo_ratings`` collection of the ``elo_ratings`` MongoDB database.

    The MongoClient is created once and then reused. A MongoClient owns a connection
    pool, and this function is called from several request handlers, so building a new
    client per call kept opening connections that were never closed.

    Returns:
        The ``elo_ratings`` collection object.
    """
    global _mongo_client
    if _mongo_client is None:
        from urllib.parse import quote_plus

        # Credentials embedded in a MongoDB URI must be percent-escaped; otherwise a
        # password containing '@', ':' or '/' corrupts the URI. An unset MONGODB
        # variable yields an empty password here.
        escaped_password = quote_plus(password or "")
        uri = f"mongodb+srv://new-user:{escaped_password}@cluster0.xb2urf6.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
        _mongo_client = MongoClient(uri)
    db = _mongo_client["elo_ratings"]
    return db["elo_ratings"]
def get_user_elo_ratings(collection):
    """Load every stored rating as ``{bot_name: {'elo_rating': ..., 'games_played': ...}}``.

    If the collection holds no documents, a single placeholder entry under the key
    ``"default"`` (rating 1200, 0 games) is returned instead of an empty dict.
    """
    ratings = {}
    for doc in collection.find():
        ratings[doc['bot_name']] = {
            'elo_rating': doc['elo_rating'],
            'games_played': doc['games_played'],
        }
    if not ratings:
        return {"default": {'elo_rating': 1200, 'games_played': 0}}
    return ratings
def update_elo_rating(collection, updated_ratings, winner, loser):
    """Persist the new rating and game count of the winner, then the loser.

    Each bot's document (matched by ``bot_name``) is upserted with ``$set`` so that a bot
    playing its first game gets a document created.
    """
    for bot in (winner, loser):
        stats = updated_ratings[bot]
        collection.update_one(
            {"bot_name": bot},
            {"$set": {"elo_rating": stats['elo_rating'], "games_played": stats['games_played']}},
            upsert=True,
        )
import json | |
# Chatbot registry: a JSON list of entries. Other code in this file reads the 'adapter',
# 'original_model' and 'fireworks_adapter_name' fields of each entry. JSON text is UTF-8,
# so the encoding is explicit instead of the locale-dependent default.
with open('chatbots.txt', 'r', encoding='utf-8') as file:
    chatbots_data = json.load(file)
# Adapter ids of every registered chatbot.
chatbots = [entry['adapter'] for entry in chatbots_data]
def clear_chat(state):
    """Start a new arena round: fresh session state, new random bot pair, empty chats.

    Args:
        state: the previous session state. It is discarded (``None`` is accepted).

    Returns:
        tuple: ``(new_state, [], [], upvote-A update, upvote-B update, textbox update,
        submit-button update)``. The upvote buttons are disabled, and the textbox is cleared
        and re-enabled along with the submit button.
    """
    # Always start from an empty dict. The previous ``{} if state is not None else state``
    # was inverted: a None state stayed None and crashed on the assignment below.
    state = {}
    adapter_names = [entry['adapter'] for entry in chatbots_data]
    state['last_bots'] = random.sample(adapter_names, 2)
    # gr.update is the generic update form (component-specific .update() is gone in Gradio 4).
    return (
        state,
        [],
        [],
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(value='', interactive=True),
        gr.update(interactive=True),
    )
from datasets import load_dataset,DatasetDict,Dataset | |
import requests | |
import os | |
def format_prompt(state, bot_index, character_name, character_description, user_name, num_messages=5):
    """Build a plain-text role-play prompt from one bot's recent conversation history.

    The prompt layout is:
      - the character description on the first line;
      - one ``"<speaker>: <text>"`` line per recent message;
      - a final ``"<character_name>: "`` so the model continues as the character.

    Messages whose role is ``'user'`` are attributed to ``user_name``; any other role
    (e.g. ``'bot1'``/``'bot2'``) is attributed to ``character_name``. A ``None`` or blank
    name/description falls back to "Ryan", the default Ryan description, and "You".

    Args:
        state: dict whose ``"history"`` entry is indexed by ``bot_index`` and holds a list
            of ``{"role": ..., "content": ...}`` dicts.
        bot_index: which transcript in ``state["history"]`` to use.
        character_name: speaker label for the bot.
        character_description: persona text placed at the top of the prompt.
        user_name: speaker label for the user.
        num_messages: maximum number of most recent messages to include; values <= 0
            include none.

    Returns:
        str: the prompt, which is also printed to stdout.
    """
    if character_name is None or character_name.strip() == "":
        character_name = "Ryan"
    if character_description is None or character_description.strip() == "":
        character_description = "Ryan is a college student who is always willing to help. He knows a lot about math and coding."
    if user_name is None or user_name.strip() == "":
        user_name = "You"

    history = state["history"][bot_index]
    # history[-0:] is the whole list, so a non-positive limit must be special-cased.
    recent_messages = history[-num_messages:] if num_messages > 0 else []

    lines = [character_description]
    for message in recent_messages:
        speaker = user_name if message['role'] == 'user' else character_name
        lines.append(f"{speaker}: {message['content']}")
    prompt = "\n".join(lines) + f"\n{character_name}: "
    print(prompt)
    return prompt
import aiohttp | |
import asyncio | |
from tenacity import retry, stop_after_attempt, wait_exponential | |
async def get_bot_response(adapter_id, prompt, state, bot_index, character_name, character_description, user_name):
    """Ask one fine-tuned model, via the Fireworks completions API, for its next reply.

    The ``prompt`` argument is not used. The prompt is rebuilt from
    ``state["history"][bot_index]`` with ``format_prompt``; the parameter is kept only so
    existing callers keep working.

    Args:
        adapter_id: ``adapter`` value of a ``chatbots_data`` entry. That entry's
            ``fireworks_adapter_name`` selects the deployed model.
        prompt: ignored (see above).
        state: dict with a ``"history"`` list of per-bot message lists.
        bot_index: which transcript in ``state["history"]`` to use.
        character_name, character_description, user_name: persona settings. Blank values
            fall back to the same defaults format_prompt uses.

    Returns:
        str: the stripped completion text, or "Sorry, I couldn't generate a response."
        in any of these cases:
          - the request fails or times out;
          - the API returns a non-200 status;
          - the body is not shaped like ``{"choices": [{"text": "..."}]}``.
    """
    fallback = "Sorry, I couldn't generate a response."
    prompt = format_prompt(state, bot_index, character_name, character_description, user_name)
    fireworks_adapter_name = next(entry['fireworks_adapter_name'] for entry in chatbots_data if entry['adapter'] == adapter_id)

    # Stop on the same speaker labels format_prompt writes into the prompt. Using the raw
    # names, a blank name would produce the stop sequence ":" (cutting the reply at its
    # first colon), and None would produce "None:" while the prompt says "Ryan:"/"You:".
    bot_label = character_name if character_name and character_name.strip() else "Ryan"
    user_label = user_name if user_name and user_name.strip() else "You"

    url = "https://api.fireworks.ai/inference/v1/completions"
    payload = {
        "model": f"accounts/gaingg19-432d9f/models/{fireworks_adapter_name}",
        "max_tokens": 250,
        "temperature": 0.7,
        "prompt": prompt,
        "stop": ["<|im_end|>", f"{bot_label}:", f"{user_label}:"],
    }
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}",
    }
    async with aiohttp.ClientSession() as session:
        try:
            async with session.post(url, json=payload, headers=headers, timeout=aiohttp.ClientTimeout(total=30)) as response:
                if response.status != 200:
                    print(await response.text())
                    return fallback
                response_data = await response.json()
        except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as exc:
            # ValueError covers a 200 response whose body is not valid JSON.
            print(f"Fireworks request failed: {exc!r}")
            return fallback

    try:
        response_text = response_data['choices'][0]['text']
    except (KeyError, IndexError, TypeError):
        print(f"Unexpected Fireworks response: {response_data!r}")
        return fallback
    if not isinstance(response_text, str):
        return fallback
    return response_text.strip()
async def chat_with_bots(user_input, state, character_name, character_description, user_name):
    """Get replies from the session's two arena bots concurrently.

    On a session's first turn (``state['last_bots']`` missing or empty), two distinct
    adapters are picked at random and stored in ``state``, so later turns reuse the pair.

    Returns:
        tuple[str, str]: replies of bot A (index 0) and bot B (index 1), with any
        "<|im_end|>" markers removed.
    """
    if not state.get('last_bots'):
        # random.sample leaves the module-level ``chatbots`` list untouched. Shuffling it
        # in place mutated a list shared by every concurrent session.
        state['last_bots'] = random.sample(chatbots, 2)
    bot1_adapter, bot2_adapter = state['last_bots'][0], state['last_bots'][1]
    bot1_response, bot2_response = await asyncio.gather(
        get_bot_response(bot1_adapter, user_input, state, 0, character_name, character_description, user_name),
        get_bot_response(bot2_adapter, user_input, state, 1, character_name, character_description, user_name),
    )
    return bot1_response.replace("<|im_end|>", ""), bot2_response.replace("<|im_end|>", "")
def update_ratings(state, winner_index, collection):
    """Record a vote for arena bot ``winner_index`` and return the reveal messages.

    The bot pair is read from ``state['last_bots']``. The stored ratings are updated with
    ``update_elo_ratings`` and written back through ``update_elo_rating``.

    Returns:
        list: ``[('Winner: ', winner_display_name), ('Loser: ', loser_display_name)]``.
    """
    ratings = get_user_elo_ratings(collection)
    pair = state['last_bots']
    winner_adapter, loser_adapter = pair[winner_index], pair[1 - winner_index]

    def display_name(adapter):
        # Map an adapter id to its registry ``original_model`` name.
        return next(e['original_model'] for e in chatbots_data if e['adapter'] == adapter)

    winner_label = display_name(winner_adapter)
    loser_label = display_name(loser_adapter)
    ratings = update_elo_ratings(ratings, winner_adapter, loser_adapter)
    update_elo_rating(collection, ratings, winner_adapter, loser_adapter)
    return [('Winner: ', winner_label), ('Loser: ', loser_label)]
def _record_vote(state, chatbot, chatbot2, winner_index):
    """Shared implementation of the two upvote handlers.

    Updates the ELO ratings with bot ``winner_index`` (0 = Model A, 1 = Model B) as the
    winner. It appends the winner/loser reveal to the matching chat panels and disables
    both upvote buttons, the textbox and the submit button.
    """
    collection = init_database()
    winner_msg, loser_msg = update_ratings(state, winner_index, collection)
    if winner_index == 0:
        chatbot.append(winner_msg)
        chatbot2.append(loser_msg)
    else:
        chatbot2.append(winner_msg)
        chatbot.append(loser_msg)
    # One generic update per output: upvote A, upvote B, textbox, submit button.
    return (
        chatbot,
        chatbot2,
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(interactive=False),
    )


def vote_up_model(state, chatbot, chatbot2):
    """Record a win for Model A (left panel) and lock the round."""
    return _record_vote(state, chatbot, chatbot2, 0)


def vote_down_model(state, chatbot, chatbot2):
    """Record a win for Model B (right panel) and lock the round."""
    return _record_vote(state, chatbot, chatbot2, 1)
async def user_ask(state, chatbot1, chatbot2, textbox, character_name, character_description, user_name):
    """Run one arena turn: send the prompt to both anonymous bots and show their replies.

    Inputs are capped before use: the character and user names at 20 characters, and the
    character description and the prompt at 500 characters.

    Args:
        state: per-session dict. ``"history"`` (one message list per bot) is created on the
            first turn, and ``"elo_ratings"`` is refreshed from the database every turn.
        chatbot1, chatbot2: ``(user, bot)`` message pairs shown for Model A / Model B.
        textbox: the user's prompt.
        character_name, character_description, user_name: persona settings.

    Returns:
        tuple: ``(state, chatbot1, chatbot2, textbox update clearing the input,
        upvote-A update, upvote-B update)``. Both upvote buttons become interactive.
    """
    max_history = 20  # messages kept in each bot's transcript

    # Slicing is a no-op for shorter strings; the truthiness checks leave None untouched.
    if character_name:
        character_name = character_name[:20]
    if character_description:
        character_description = character_description[:500]
    if user_name:
        user_name = user_name[:20]
    user_input = textbox[:500]

    collection = init_database()
    # Snapshot of the current ratings kept in the session state.
    state["elo_ratings"] = get_user_elo_ratings(collection)

    if "history" not in state:
        state["history"] = [[], []]
    state["history"][0].append({"role": "user", "content": user_input})
    state["history"][1].append({"role": "user", "content": user_input})

    bot1_response, bot2_response = await chat_with_bots(user_input, state, character_name, character_description, user_name)
    state["history"][0].append({"role": "bot1", "content": bot1_response})
    state["history"][1].append({"role": "bot2", "content": bot2_response})

    # Bound each bot's transcript. The old check compared len(state["history"]) against
    # 20, but that outer list always has exactly two entries, so nothing was ever trimmed.
    for transcript in state["history"]:
        del transcript[:-max_history]

    chatbot1.append((user_input, bot1_response))
    chatbot2.append((user_input, bot2_response))
    return state, chatbot1, chatbot2, gr.update(value=''), enable_btn, enable_btn
import pandas as pd | |
# Function to generate leaderboard data | |
import requests | |
def submit_model(model_name):
    """Announce a user-submitted model name on the Discord webhook in ``DISCORD_URL``.

    Args:
        model_name: free text typed by the user.

    Returns:
        str: a status message for the UI.
    """
    discord_url = os.environ.get("DISCORD_URL")
    if not discord_url:
        return "Discord webhook URL not configured."
    if not model_name or not model_name.strip():
        return "Please enter a model name."
    payload = {
        "content": f"New model submitted: {model_name}",
        # model_name is arbitrary user text; stop it from pinging @everyone/@here, roles
        # or users.
        "allowed_mentions": {"parse": []},
    }
    try:
        # Without a timeout, requests can block this handler indefinitely.
        response = requests.post(discord_url, json=payload, timeout=10)
    except requests.RequestException as exc:
        print(f"Discord webhook request failed: {exc!r}")
        return "Failed to submit the model."
    # Discord answers 204 No Content by default and 200 when the webhook uses ?wait=true.
    if response.status_code in (200, 204):
        return "Model submitted successfully!"
    return "Failed to submit the model."
def generate_leaderboard(collection, models=None):
    """Build the leaderboard table from the ratings stored in ``collection``.

    Args:
        collection: object whose ``find()`` yields rating documents with ``bot_name``,
            ``elo_rating`` and ``games_played`` fields.
        models: optional registry entries (dicts with ``'adapter'`` and
            ``'original_model'``) used to map adapter ids to display names. Defaults to
            ``chatbots_data``.

    Returns:
        pandas.DataFrame with columns ``Chatbot``, ``ELO Score`` (rounded to int) and
        ``Games Played``, sorted by ``ELO Score`` descending. A rating whose ``bot_name``
        is not in the registry (e.g. a model removed from chatbots.txt) is listed under
        its raw ``bot_name`` instead of making the whole leaderboard fail.
    """
    if models is None:
        models = chatbots_data
    # The first registry entry wins for a duplicated adapter, like the former next() scan.
    display_names = {}
    for entry in models:
        display_names.setdefault(entry['adapter'], entry['original_model'])

    raw = pd.DataFrame(list(collection.find()), columns=['bot_name', 'elo_rating', 'games_played'])
    leaderboard_data = pd.DataFrame({
        'Chatbot': raw['bot_name'].map(lambda name: display_names.get(name, name)),
        # Cast to float first: an empty collection yields an object column, which round()
        # may reject.
        'ELO Score': raw['elo_rating'].astype(float).round().astype(int),
        'Games Played': raw['games_played'],
    })
    return leaderboard_data.sort_values('ELO Score', ascending=False)
def refresh_leaderboard():
    """Rebuild the leaderboard DataFrame from the current contents of the ratings collection."""
    return generate_leaderboard(init_database())
async def direct_chat(model, user_input, chatbot, character_name, character_description, user_name):
    """Send one message to the selected model in the Direct Chat tab.

    Only the current message is used as context; earlier turns are not sent to the model.

    Args:
        model: display name (``original_model``) selected in the dropdown.
        user_input: the message typed by the user.
        chatbot: list of ``(user_message, bot_reply)`` pairs shown in the chat.
        character_name, character_description, user_name: persona settings forwarded to
            the prompt builder.

    Returns:
        tuple: ``("", chatbot)``. The textbox is cleared, and the new exchange is appended
        unless the message was empty or whitespace.

    Raises:
        ValueError: if ``model`` is not in the chatbot registry.
    """
    if chatbot is None:
        chatbot = []
    if not user_input or not user_input.strip():
        # Don't spend an API call on an empty message.
        return "", chatbot
    adapter = next((entry['adapter'] for entry in chatbots_data if entry['original_model'] == model), None)
    if adapter is None:
        # A bare next() would surface as an opaque RuntimeError inside a coroutine.
        raise ValueError(f"Unknown model: {model!r}")
    temp_state = {"history": [[{"role": "user", "content": user_input}]]}
    response = await get_bot_response(adapter, user_input, temp_state, 0, character_name, character_description, user_name)
    chatbot.append((user_input, response))
    return "", chatbot
def reset_direct_chat():
    """Clear the Direct Chat history and empty its textbox.

    Uses the generic ``gr.update`` (component-specific ``.update()`` no longer exists in
    Gradio 4).
    """
    return [], gr.update(value='')
# Build the leaderboard once at import time. The result is discarded, so this only checks
# that the database and the chatbot registry are usable. A failure is logged instead of
# aborting startup, matching the guarded leaderboard construction in the UI.
try:
    refresh_leaderboard()
except Exception as exc:  # top-level boundary: report and keep the app running
    print(f"Initial leaderboard load failed: {exc!r}")
# Gradio interface setup
with gr.Blocks() as demo:
    # Per-session state: 'last_bots' (the arena pair) and 'history' (per-bot transcripts).
    state = gr.State({})

    with gr.Tab("π€ Chatbot Arena"):
        gr.Markdown("## π₯ Let's see which chatbot wins!")
        with gr.Row():
            with gr.Column():
                chatbot1 = gr.Chatbot(label='π€ Model A').style(height=350)
                upvote_btn_a = gr.Button(value="π Upvote A", interactive=False).style(full_width=True)
            with gr.Column():
                chatbot2 = gr.Chatbot(label='π€ Model B').style(height=350)
                upvote_btn_b = gr.Button(value="π Upvote B", interactive=False).style(full_width=True)
        with gr.Row():
            with gr.Column(scale=5):
                textbox = gr.Textbox(placeholder="π€ Enter your prompt (up to 500 characters)")
            submit_btn = gr.Button(value="Submit")
        with gr.Row():
            reset_btn = gr.Button(value="ποΈ Reset")
        with gr.Row():
            character_name = gr.Textbox(label="Character Name", value="Ryan", placeholder="Enter character name (max 20 chars)")
            character_description = gr.Textbox(label="Character Description", value="Ryan is a college student who is always willing to help. He knows a lot about math and coding.", placeholder="Enter character description (max 500 chars)")
        with gr.Row():
            user_name = gr.Textbox(label="Your Name", value="You", placeholder="Enter your name (max 20 chars)")

        arena_inputs = [state, chatbot1, chatbot2, textbox, character_name, character_description, user_name]
        arena_outputs = [state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b]
        vote_outputs = [chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn]
        reset_btn.click(clear_chat, inputs=[state], outputs=[state, chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
        submit_btn.click(user_ask, inputs=arena_inputs, outputs=arena_outputs, queue=True)
        textbox.submit(user_ask, inputs=arena_inputs, outputs=arena_outputs, queue=True)
        upvote_btn_a.click(vote_up_model, inputs=[state, chatbot1, chatbot2], outputs=vote_outputs)
        upvote_btn_b.click(vote_down_model, inputs=[state, chatbot1, chatbot2], outputs=vote_outputs)

    with gr.Tab("π¬ Direct Chat"):
        gr.Markdown("## π£οΈ Chat directly with a model!")
        with gr.Row():
            model_dropdown = gr.Dropdown(choices=[entry['original_model'] for entry in chatbots_data], value=chatbots_data[0]['original_model'], label="π€ Select a model")
        with gr.Row():
            direct_chatbot = gr.Chatbot(label="π¬ Direct Chat").style(height=500)
        with gr.Row():
            with gr.Column(scale=5):
                direct_textbox = gr.Textbox(placeholder="π Enter your message")
            direct_submit_btn = gr.Button(value="Submit")
        with gr.Row():
            direct_regenerate_btn = gr.Button(value="π Regenerate")
            direct_reset_btn = gr.Button(value="ποΈ Reset Chat")

        # The persona fields defined in the arena tab are shared with this tab.
        direct_inputs = [model_dropdown, direct_textbox, direct_chatbot, character_name, character_description, user_name]
        direct_outputs = [direct_textbox, direct_chatbot]
        direct_regenerate_btn.click(direct_regenerate, inputs=direct_inputs, outputs=direct_outputs)
        direct_textbox.submit(direct_chat, inputs=direct_inputs, outputs=direct_outputs)
        direct_submit_btn.click(direct_chat, inputs=direct_inputs, outputs=direct_outputs)
        direct_reset_btn.click(reset_direct_chat, None, [direct_chatbot, direct_textbox])

    with gr.Tab("π Leaderboard"):
        gr.Markdown("## π Check out the top-performing models!")
        try:
            leaderboard = gr.Dataframe(refresh_leaderboard())
        except Exception as exc:
            # Keep the UI usable when the database is unreachable; the refresh button
            # can retry later. (A bare except would also swallow KeyboardInterrupt.)
            print(f"Could not load leaderboard: {exc!r}")
            leaderboard = gr.Dataframe()
        with gr.Row():
            refresh_btn = gr.Button("π Refresh Leaderboard")
        refresh_btn.click(refresh_leaderboard, outputs=[leaderboard])

    with gr.Tab("π¨ Submit Model"):
        gr.Markdown("## π¨ Submit a new model to be added to the chatbot arena!")
        with gr.Row():
            model_name_input = gr.Textbox(placeholder="Enter the model name")
            submit_model_btn = gr.Button(value="Submit Model")
        # submit_model's status message replaces the textbox contents.
        submit_model_btn.click(submit_model, inputs=[model_name_input], outputs=[model_name_input])

# Launch the Gradio interface
if __name__ == "__main__":
    demo.launch(share=False)