versae commited on
Commit
bed30b4
1 Parent(s): 15c5469

Fixes chat

Browse files
Files changed (1) hide show
  1. gradio_app.py +6 -5
gradio_app.py CHANGED
@@ -170,10 +170,10 @@ class TextGeneration:
170
  input_text,
171
  **generation_kwargs,
172
  )[0]["generated_text"]
173
- if generation_kwargs["do_clean"]:
174
- generated_text = cleaner.clean_txt(generated_text)
175
  if generated_text.strip().startswith(input_text):
176
  generated_text = generated_text.replace(input_text, "", 1).strip()
 
 
177
  if generated_text:
178
  if previous_text and previous_text != text:
179
  diff = [
@@ -258,9 +258,10 @@ def chat_with_gpt(agent, user, context, user_message, history, max_length, top_k
258
  break
259
  context += history_context
260
  for _ in range(5):
261
- response = generator.generate(f"{context}\n\n{user}: {message}.\n", generation_kwargs)[0]
 
262
  if DEBUG:
263
- print("\n-----" + response + "-----\n")
264
  # response = response.split("\n")[-1]
265
  # if agent in response and response.split(agent)[-1]:
266
  # response = response.split(agent)[-1]
@@ -268,7 +269,7 @@ def chat_with_gpt(agent, user, context, user_message, history, max_length, top_k
268
  # response = response.split(user)[-1]
269
  # Take the first response
270
  response = [
271
- r for r in response.split(f"{AGENT}:") if r.strip()
272
  ][0].split(USER)[0].replace(f"{AGENT}:", "\n").strip()
273
  if response[0] in string.punctuation:
274
  response = response[1:].strip()
 
170
  input_text,
171
  **generation_kwargs,
172
  )[0]["generated_text"]
 
 
173
  if generated_text.strip().startswith(input_text):
174
  generated_text = generated_text.replace(input_text, "", 1).strip()
175
+ if generation_kwargs["do_clean"]:
176
+ generated_text = cleaner.clean_txt(generated_text)
177
  if generated_text:
178
  if previous_text and previous_text != text:
179
  diff = [
 
258
  break
259
  context += history_context
260
  for _ in range(5):
261
+ prompt = f"{context}\n\n{user}: {message}.\n"
262
+ response = generator.generate(prompt, generation_kwargs)[0]
263
  if DEBUG:
264
+ print("\n-----\n" + response + "\n-----\n")
265
  # response = response.split("\n")[-1]
266
  # if agent in response and response.split(agent)[-1]:
267
  # response = response.split(agent)[-1]
 
269
  # response = response.split(user)[-1]
270
  # Take the first response
271
  response = [
272
+ r for r in response.replace(prompt, "").split(f"{AGENT}:") if r.strip()
273
  ][0].split(USER)[0].replace(f"{AGENT}:", "\n").strip()
274
  if response[0] in string.punctuation:
275
  response = response[1:].strip()