rwitz commited on
Commit
5170342
1 Parent(s): 3c23026

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -27
app.py CHANGED
@@ -72,30 +72,22 @@ import os
72
 
73
 
74
  # Function to get bot response
75
- def format_chatml_prompt(state):
76
- chatml_prompt = "<|im_start|>system You are a helpful assistant, who can think outside the box.<|im_end|>"
77
- for message in state["history"][0]:
78
  if message['role'] == 'user':
79
- chatml_prompt += "\n<|im_start|>user " + message['content'] + "<|im_end|>"
80
  else:
81
- chatml_prompt += "\n<|im_start|>assistant " + message['content'] + "<|im_end|>"
82
-
83
- if len(state["history"]) > 1:
84
- chatml_prompt2 = "<|im_start|>system You are a helpful assistant, who can think outside the box.<|im_end|>"
85
- for message in state["history"][1]:
86
- if message['role'] == 'user':
87
- chatml_prompt2 += "\n<|im_start|>user " + message['content'] + "<|im_end|>"
88
- else:
89
- chatml_prompt2 += "\n<|im_start|>assistant " + message['content'] + "<|im_end|>"
90
- return [chatml_prompt + "\n<|im_start|>assistant", chatml_prompt2 + "\n<|im_start|>assistant"]
91
- else:
92
- return [chatml_prompt + "\n<|im_start|>assistant"]
93
  import aiohttp
94
  import asyncio
95
  from tenacity import retry, stop_after_attempt, wait_exponential
96
 
97
- async def get_bot_response(adapter_id, prompt, state, bot_index):
98
- chatml_prompt = format_chatml_prompt(state)
 
99
  fireworks_adapter_name = next(entry['fireworks_adapter_name'] for entry in chatbots_data if entry['adapter'] == adapter_id)
100
 
101
  url = "https://api.fireworks.ai/inference/v1/completions"
@@ -107,8 +99,8 @@ async def get_bot_response(adapter_id, prompt, state, bot_index):
107
  "presence_penalty": 0,
108
  "frequency_penalty": 0,
109
  "temperature": 0.7,
110
- "prompt": chatml_prompt[bot_index],
111
- "stop": ["<|im_end|>"]
112
  }
113
  headers = {
114
  "Accept": "application/json",
@@ -138,10 +130,9 @@ async def chat_with_bots(user_input, state):
138
  bot1_adapter, bot2_adapter = state['last_bots'][0], state['last_bots'][1]
139
 
140
  bot1_response, bot2_response = await asyncio.gather(
141
- get_bot_response(bot1_adapter, user_input, state, 0),
142
- get_bot_response(bot2_adapter, user_input, state, 1)
143
  )
144
-
145
  return bot1_response.replace("<|im_end|>",""), bot2_response.replace("<|im_end|>","")
146
  def update_ratings(state, winner_index, collection):
147
  elo_ratings = get_user_elo_ratings(collection)
@@ -186,7 +177,8 @@ async def user_ask(state, chatbot1, chatbot2, textbox):
186
  state["history"][1].extend([
187
  {"role": "user", "content": user_input}])
188
  # Chat with bots
189
- bot1_response, bot2_response = await chat_with_bots(user_input, state)
 
190
  state["history"][0].extend([
191
  {"role": "bot1", "content": bot1_response},
192
  ])
@@ -273,13 +265,15 @@ with gr.Blocks() as demo:
273
 
274
  with gr.Row():
275
  reset_btn = gr.Button(value="🗑️ Reset")
 
 
 
 
276
 
277
  # ...
278
 
279
  reset_btn.click(clear_chat, inputs=[state], outputs=[state, chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
280
- submit_btn.click(user_ask, inputs=[state, chatbot1, chatbot2, textbox], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True)
281
-
282
- collection = init_database()
283
 
284
  upvote_btn_a.click(vote_up_model, inputs=[state, chatbot1, chatbot2], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
285
  upvote_btn_b.click(vote_down_model, inputs=[state, chatbot1, chatbot2], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
 
72
 
73
 
74
  # Function to get bot response
75
+ def format_prompt(state, bot_index, character_name, character_description):
76
+ prompt = f"{character_description}\n\n"
77
+ for message in state["history"][bot_index]:
78
  if message['role'] == 'user':
79
+ prompt += f"You: {message['content']}\n"
80
  else:
81
+ prompt += f"{character_name}: {message['content']}\n"
82
+ prompt += f"{character_name}: "
83
+ return prompt
 
 
 
 
 
 
 
 
 
84
  import aiohttp
85
  import asyncio
86
  from tenacity import retry, stop_after_attempt, wait_exponential
87
 
88
+ async def get_bot_response(adapter_id, prompt, state, bot_index, character_name, character_description):
89
+ prompt = format_prompt(state, bot_index, character_name, character_description)
90
+
91
  fireworks_adapter_name = next(entry['fireworks_adapter_name'] for entry in chatbots_data if entry['adapter'] == adapter_id)
92
 
93
  url = "https://api.fireworks.ai/inference/v1/completions"
 
99
  "presence_penalty": 0,
100
  "frequency_penalty": 0,
101
  "temperature": 0.7,
102
+ "prompt": prompt,
103
+ "stop": ["<|im_end|>","\n"]
104
  }
105
  headers = {
106
  "Accept": "application/json",
 
130
  bot1_adapter, bot2_adapter = state['last_bots'][0], state['last_bots'][1]
131
 
132
  bot1_response, bot2_response = await asyncio.gather(
133
+ get_bot_response(bot1_adapter, user_input, state, 0, character_name, character_description),
134
+ get_bot_response(bot2_adapter, user_input, state, 1, character_name, character_description)
135
  )
 
136
  return bot1_response.replace("<|im_end|>",""), bot2_response.replace("<|im_end|>","")
137
  def update_ratings(state, winner_index, collection):
138
  elo_ratings = get_user_elo_ratings(collection)
 
177
  state["history"][1].extend([
178
  {"role": "user", "content": user_input}])
179
  # Chat with bots
180
+ bot1_response, bot2_response = await chat_with_bots(user_input, state, character_name, character_description)
181
+
182
  state["history"][0].extend([
183
  {"role": "bot1", "content": bot1_response},
184
  ])
 
265
 
266
  with gr.Row():
267
  reset_btn = gr.Button(value="🗑️ Reset")
268
+ with gr.Row():
269
+ character_name = gr.Textbox(value="Ryan", placeholder="Enter character name (max 20 chars)", max_lines=1)
270
+ character_description = gr.Textbox(value="Ryan is a college student who is always willing to help. He knows a lot about math and coding.", placeholder="Enter character description (max 200 chars)", max_lines=5)
271
+
272
 
273
  # ...
274
 
275
  reset_btn.click(clear_chat, inputs=[state], outputs=[state, chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
276
+ submit_btn.click(user_ask, inputs=[state, chatbot1, chatbot2, textbox, character_name, character_description], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True) collection = init_database()
 
 
277
 
278
  upvote_btn_a.click(vote_up_model, inputs=[state, chatbot1, chatbot2], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
279
  upvote_btn_b.click(vote_down_model, inputs=[state, chatbot1, chatbot2], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])