rwitz commited on
Commit
5425a03
β€’
1 Parent(s): c872b22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -121
app.py CHANGED
@@ -9,31 +9,7 @@ from elo import update_elo_ratings # Custom function for ELO ratings
9
  enable_btn = gr.Button.update(interactive=True)
10
 
11
  import sqlite3
12
- import requests
13
 
14
- async def classify_vote(user_input):
15
- url = "https://api-inference.huggingface.co/models/facebook/bart-large-mnli"
16
- headers = {
17
- "accept": "*/*",
18
- "accept-language": "en-US,en;q=0.9",
19
- "content-type": "application/json",
20
- }
21
- payload = {
22
- "inputs": user_input,
23
- "parameters": {
24
- "candidate_labels": "character roleplay,small talk conversation,mathematics calculations and logic,creative writing,factual knowledge",
25
- }
26
- }
27
- async with aiohttp.ClientSession() as session:
28
- async with session.post(url, headers=headers, json=payload) as response:
29
- if response.status == 200:
30
- response_data = await response.json()
31
- top_category = response_data["labels"][0].strip()
32
- return top_category
33
- else:
34
- print(f"Error: {response.status}")
35
- return None
36
-
37
  from pymongo.mongo_client import MongoClient
38
  from pymongo.server_api import ServerApi
39
  async def direct_regenerate(model, user_input, chatbot, character_name, character_description, user_name):
@@ -51,10 +27,21 @@ password=os.environ.get("MONGODB")
51
  def init_database():
52
  uri = f"mongodb+srv://new-user:{password}@cluster0.xb2urf6.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
53
  client = MongoClient(uri)
54
- db = client["elo_ratings6"]
55
  collection = db["elo_ratings"]
56
  return collection
57
 
 
 
 
 
 
 
 
 
 
 
 
58
  import json
59
 
60
  with open('chatbots.txt', 'r') as file:
@@ -76,7 +63,8 @@ def clear_chat(state):
76
  state['last_bots'] = selected_adapters
77
 
78
  # Reset other components specific to the "Chatbot Arena" tab
79
- return state, [], [], gr.Button.update(interactive=False, value="πŸ‘ Upvote A"), gr.Button.update(interactive=False, value="πŸ‘ Upvote B"), gr.Textbox.update(value='', interactive=True), gr.Button.update(interactive=True)
 
80
  from datasets import load_dataset,DatasetDict,Dataset
81
  import requests
82
  import os
@@ -153,56 +141,32 @@ async def chat_with_bots(user_input, state, character_name, character_descriptio
153
  get_bot_response(bot2_adapter, user_input, state, 1, character_name, character_description, user_name)
154
  )
155
  return bot1_response.replace("<|im_end|>",""), bot2_response.replace("<|im_end|>","")
156
- def update_ratings(state, winner_index, collection, category):
157
- try:
158
- elo_ratings = get_user_elo_ratings(collection)
159
- winner_adapter = state['last_bots'][winner_index]
160
- loser_adapter = state['last_bots'][1 - winner_index]
161
-
162
- winner = next(entry['original_model'] for entry in chatbots_data if entry['adapter'] == winner_adapter)
163
- loser = next(entry['original_model'] for entry in chatbots_data if entry['adapter'] == loser_adapter)
164
-
165
- if category != "overall":
166
- elo_ratings = update_elo_ratings(elo_ratings, winner_adapter, loser_adapter, category)
167
- update_elo_rating(collection, elo_ratings, winner_adapter, loser_adapter, category)
168
-
169
- return [('Winner: ', winner), ('Loser: ', loser)]
170
- except Exception as e:
171
- print(f"Error updating ratings: {str(e)}")
172
- return [('Error', 'Failed to update ratings.'), ('Error', 'Failed to update ratings.')]
173
- async def vote_up_model(state, chatbot, chatbot2, character_name, character_description, user_name):
174
- user_input = format_prompt(state, 0, character_name, character_description, user_name)
175
- collection = init_database()
176
 
177
- # Disable both upvote buttons immediately
178
- yield chatbot, chatbot2, gr.Button.update(interactive=False, value="πŸ‘ Upvoted"), gr.Button.update(interactive=False), gr.Textbox.update(interactive=False), gr.Button.update(interactive=False)
179
 
180
- # Update ratings and yield winner/loser immediately
181
- update_message = update_ratings(state, 0, collection, "overall")
182
- chatbot.append(update_message[0])
183
- chatbot2.append(update_message[1])
184
- # Process sentiment analysis asynchronously
185
- top_category = await classify_vote(user_input)
186
 
187
- if top_category:
188
- update_ratings(state, 0, collection, top_category)
189
- async def vote_down_model(state, chatbot, chatbot2, character_name, character_description, user_name):
190
- user_input = format_prompt(state, 1, character_name, character_description, user_name)
191
  collection = init_database()
192
-
193
- # Disable both upvote buttons immediately
194
- yield chatbot, chatbot2, gr.Button.update(interactive=False), gr.Button.update(interactive=False, value="πŸ‘ Upvoted"), gr.Textbox.update(interactive=False), gr.Button.update(interactive=False)
195
-
196
- # Update ratings and yield winner/loser immediately
197
- update_message = update_ratings(state, 1, collection, "overall")
 
 
198
  chatbot2.append(update_message[0])
199
  chatbot.append(update_message[1])
200
-
201
- # Process sentiment analysis asynchronously
202
- top_category = await classify_vote(user_input)
203
-
204
- if top_category:
205
- update_ratings(state, 1, collection, top_category)
206
  async def user_ask(state, chatbot1, chatbot2, textbox, character_name, character_description, user_name):
207
  if character_name and len(character_name) > 20:
208
  character_name = character_name[:20] # Limit character name to 20 characters
@@ -265,55 +229,16 @@ def submit_model(model_name):
265
  else:
266
  return "Discord webhook URL not configured."
267
 
268
- def get_user_elo_ratings(collection):
269
- rows = list(collection.find())
270
- if rows:
271
- elo_ratings = {}
272
- for row in rows:
273
- bot_name = row['bot_name']
274
- if bot_name not in elo_ratings:
275
- elo_ratings[bot_name] = {}
276
- for category in row['categories']:
277
- elo_ratings[bot_name][category] = {'elo_rating': row['categories'][category]['elo_rating'], 'games_played': row['categories'][category]['games_played']}
278
- return elo_ratings
279
- else:
280
- return {"default": {'overall': {'elo_rating': 1200, 'games_played': 0}}}
281
-
282
- def update_elo_rating(collection, updated_ratings, winner, loser, category):
283
- collection.update_one({"bot_name": winner}, {"$set": {f"categories.{category}.elo_rating": updated_ratings[winner][category]['elo_rating'], f"categories.{category}.games_played": updated_ratings[winner][category]['games_played']}}, upsert=True)
284
- collection.update_one({"bot_name": loser}, {"$set": {f"categories.{category}.elo_rating": updated_ratings[loser][category]['elo_rating'], f"categories.{category}.games_played": updated_ratings[loser][category]['games_played']}}, upsert=True)
285
-
286
  def generate_leaderboard(collection):
287
  rows = list(collection.find())
288
- categories = set()
289
- for row in rows:
290
- categories.update(row['categories'].keys())
291
-
292
- leaderboard_data = []
293
- for row in rows:
294
- bot_name = row['bot_name']
295
- original_model = next(entry['original_model'] for entry in chatbots_data if entry['adapter'] == bot_name)
296
-
297
- category_data = {category: row['categories'].get(category, {}).get('elo_rating', 1200) for category in categories}
298
-
299
- total_elo = sum(data['elo_rating'] for data in row['categories'].values() if 'elo_rating' in data)
300
- total_games = sum(data['games_played'] for data in row['categories'].values() if 'games_played' in data)
301
- avg_elo = total_elo / len(row['categories']) if len(row['categories']) > 0 else 0
302
-
303
- leaderboard_data.append([original_model, avg_elo, total_games] + [round(category_data[category],0) for category in categories])
304
-
305
- columns = ['Chatbot', 'Avg ELO Score', 'Total Games Played'] + list(categories)
306
-
307
- leaderboard_data = pd.DataFrame(leaderboard_data, columns=columns)
308
- leaderboard_data['Avg ELO Score'] = leaderboard_data['Avg ELO Score'].round().astype(int)
309
-
310
- # Replace '-' with empty cells
311
- leaderboard_data = leaderboard_data.replace('-', '')
312
-
313
- # Sort by average ELO score in descending order
314
- leaderboard_data = leaderboard_data.sort_values('Avg ELO Score', ascending=False)
315
-
316
  return leaderboard_data
 
317
  def refresh_leaderboard():
318
  collection = init_database()
319
  leaderboard_data = generate_leaderboard(collection)
@@ -378,13 +303,13 @@ with gr.Blocks() as demo:
378
  # ...
379
 
380
  reset_btn.click(clear_chat, inputs=[state], outputs=[state, chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
381
- submit_btn.click(user_ask, inputs=[state, chatbot1, chatbot2, textbox, character_name, character_description, user_name], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=False)
382
- textbox.submit(user_ask, inputs=[state, chatbot1, chatbot2, textbox, character_name, character_description, user_name], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=False)
383
  collection = init_database()
384
 
385
- upvote_btn_a.click(vote_up_model, inputs=[state, chatbot1, chatbot2, character_name, character_description, user_name], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn], api_name="upvote_a")
386
- upvote_btn_b.click(vote_down_model, inputs=[state, chatbot1, chatbot2, character_name, character_description, user_name], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn], api_name="upvote_b")
387
- # ...
388
  with gr.Tab("πŸ’¬ Direct Chat"):
389
  gr.Markdown("## πŸ—£οΈ Chat directly with a model!")
390
 
 
9
  enable_btn = gr.Button.update(interactive=True)
10
 
11
  import sqlite3
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  from pymongo.mongo_client import MongoClient
14
  from pymongo.server_api import ServerApi
15
  async def direct_regenerate(model, user_input, chatbot, character_name, character_description, user_name):
 
27
  def init_database():
28
  uri = f"mongodb+srv://new-user:{password}@cluster0.xb2urf6.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
29
  client = MongoClient(uri)
30
+ db = client["elo_ratings"]
31
  collection = db["elo_ratings"]
32
  return collection
33
 
34
+ def get_user_elo_ratings(collection):
35
+ rows = list(collection.find())
36
+ if rows:
37
+ return {row['bot_name']: {'elo_rating': row['elo_rating'], 'games_played': row['games_played']} for row in rows}
38
+ else:
39
+ return {"default": {'elo_rating': 1200, 'games_played': 0}}
40
+
41
+ def update_elo_rating(collection, updated_ratings, winner, loser):
42
+ collection.update_one({"bot_name": winner}, {"$set": {"elo_rating": updated_ratings[winner]['elo_rating'], "games_played": updated_ratings[winner]['games_played']}}, upsert=True)
43
+ collection.update_one({"bot_name": loser}, {"$set": {"elo_rating": updated_ratings[loser]['elo_rating'], "games_played": updated_ratings[loser]['games_played']}}, upsert=True)
44
+
45
  import json
46
 
47
  with open('chatbots.txt', 'r') as file:
 
63
  state['last_bots'] = selected_adapters
64
 
65
  # Reset other components specific to the "Chatbot Arena" tab
66
+ return state, [], [], gr.Button.update(interactive=False), gr.Button.update(interactive=False), gr.Textbox.update(value='', interactive=True), gr.Button.update(interactive=True)
67
+
68
  from datasets import load_dataset,DatasetDict,Dataset
69
  import requests
70
  import os
 
141
  get_bot_response(bot2_adapter, user_input, state, 1, character_name, character_description, user_name)
142
  )
143
  return bot1_response.replace("<|im_end|>",""), bot2_response.replace("<|im_end|>","")
144
+ def update_ratings(state, winner_index, collection):
145
+ elo_ratings = get_user_elo_ratings(collection)
146
+ winner_adapter = state['last_bots'][winner_index]
147
+ loser_adapter = state['last_bots'][1 - winner_index]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
+ winner = next(entry['original_model'] for entry in chatbots_data if entry['adapter'] == winner_adapter)
150
+ loser = next(entry['original_model'] for entry in chatbots_data if entry['adapter'] == loser_adapter)
151
 
152
+ elo_ratings = update_elo_ratings(elo_ratings, winner_adapter, loser_adapter)
153
+ update_elo_rating(collection, elo_ratings, winner_adapter, loser_adapter)
154
+ return [('Winner: ', winner), ('Loser: ', loser)]
 
 
 
155
 
156
+ def vote_up_model(state, chatbot, chatbot2):
 
 
 
157
  collection = init_database()
158
+ update_message = update_ratings(state, 0, collection)
159
+ chatbot.append(update_message[0])
160
+ chatbot2.append(update_message[1])
161
+ return chatbot, chatbot2, gr.Button.update(interactive=False), gr.Button.update(interactive=False), gr.Textbox.update(interactive=False), gr.Button.update(interactive=False)
162
+
163
+ def vote_down_model(state, chatbot, chatbot2):
164
+ collection = init_database()
165
+ update_message = update_ratings(state, 1, collection)
166
  chatbot2.append(update_message[0])
167
  chatbot.append(update_message[1])
168
+ return chatbot, chatbot2, gr.Button.update(interactive=False), gr.Button.update(interactive=False), gr.Textbox.update(interactive=False), gr.Button.update(interactive=False)
169
+
 
 
 
 
170
  async def user_ask(state, chatbot1, chatbot2, textbox, character_name, character_description, user_name):
171
  if character_name and len(character_name) > 20:
172
  character_name = character_name[:20] # Limit character name to 20 characters
 
229
  else:
230
  return "Discord webhook URL not configured."
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  def generate_leaderboard(collection):
233
  rows = list(collection.find())
234
+ leaderboard_data = pd.DataFrame(rows, columns=['bot_name', 'elo_rating', 'games_played'])
235
+ leaderboard_data['original_model'] = leaderboard_data['bot_name'].apply(lambda x: next(entry['original_model'] for entry in chatbots_data if entry['adapter'] == x))
236
+ leaderboard_data = leaderboard_data[['original_model', 'elo_rating', 'games_played']]
237
+ leaderboard_data.columns = ['Chatbot', 'ELO Score', 'Games Played']
238
+ leaderboard_data['ELO Score'] = leaderboard_data['ELO Score'].round().astype(int)
239
+ leaderboard_data = leaderboard_data.sort_values('ELO Score', ascending=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  return leaderboard_data
241
+
242
  def refresh_leaderboard():
243
  collection = init_database()
244
  leaderboard_data = generate_leaderboard(collection)
 
303
  # ...
304
 
305
  reset_btn.click(clear_chat, inputs=[state], outputs=[state, chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
306
+ submit_btn.click(user_ask, inputs=[state, chatbot1, chatbot2, textbox, character_name, character_description, user_name], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True)
307
+ textbox.submit(user_ask, inputs=[state, chatbot1, chatbot2, textbox, character_name, character_description, user_name], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True)
308
  collection = init_database()
309
 
310
+ upvote_btn_a.click(vote_up_model, inputs=[state, chatbot1, chatbot2], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
311
+ upvote_btn_b.click(vote_down_model, inputs=[state, chatbot1, chatbot2], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
312
+
313
  with gr.Tab("πŸ’¬ Direct Chat"):
314
  gr.Markdown("## πŸ—£οΈ Chat directly with a model!")
315