versae commited on
Commit
95792cc
1 Parent(s): 4b4ebad

Add a chat tab

Browse files
Files changed (1) hide show
  1. gradio_app.py +106 -21
gradio_app.py CHANGED
@@ -1,5 +1,6 @@
1
- import random
2
  import os
 
 
3
 
4
  import gradio as gr
5
  import torch
@@ -9,6 +10,7 @@ import logging
9
  logger = logging.getLogger()
10
  logger.addHandler(logging.StreamHandler())
11
 
 
12
  HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
13
  DEVICE = os.environ.get("DEVICE", "cpu") # cuda:0
14
  if DEVICE != "cpu" and not torch.cuda.is_available():
@@ -60,7 +62,10 @@ Este modelo ha sido entrenado con [Mesh Transformer JAX](https://github.com/king
60
  """
61
 
62
  FOOTER = """
63
- Para más información, visite el [repositorio del modelo](https://huggingface.co/bertin-project/bertin-gpt-j-6B).
 
 
 
64
  """.strip()
65
 
66
  EXAMPLES = [
@@ -74,6 +79,13 @@ Pregunta: ¿Quién cuidaba del hogar los dioses?
74
  Respuesta:""",
75
  ]
76
 
 
 
 
 
 
 
 
77
  class Normalizer:
78
  def remove_repetitions(self, text):
79
  """Remove repetitions"""
@@ -124,8 +136,6 @@ class TextGeneration:
124
  def generate(self, text, generation_kwargs):
125
  max_length = len(self.tokenizer(text)["input_ids"]) + generation_kwargs["max_length"]
126
  generation_kwargs["max_length"] = min(max_length, self.model.config.n_positions)
127
- # generation_kwargs["num_return_sequences"] = 1
128
- # generation_kwargs["return_full_text"] = False
129
  generated_text = None
130
  if text:
131
  for _ in range(10):
@@ -196,6 +206,64 @@ def expand_with_gpt(hidden, text, max_length, top_k, top_p, temperature, do_samp
196
  }
197
  return generator.generate(hidden or text, generation_kwargs)
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  with gr.Blocks() as demo:
200
  gr.Markdown(HEADER)
201
  with gr.Row():
@@ -247,24 +315,41 @@ with gr.Blocks() as demo:
247
  # help="Si eliminar o no las palabras repetidas y recortar las últimas frases sin terminar.",
248
  )
249
  with gr.Column():
250
- textbox = gr.Textbox(label="Texto", placeholder="Escriba algo (o seleccione un ejemplo) y pulse 'Generar'...", lines=8)
251
- examples = gr.Dropdown(label="Ejemplos", choices=EXAMPLES, value=None, type="value")
252
- hidden = gr.Textbox(visible=False, show_label=False)
253
- with gr.Box():
254
- # output = gr.Markdown()
255
- output = gr.HighlightedText(label="Resultado", combine_adjacent=True, color_map={"BERTIN": "green", "ERROR": "red"})
256
- with gr.Row():
257
- btn = gr.Button("Generar")
258
- btn.click(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[textbox, hidden, output])
259
- expand_btn = gr.Button("Añadir")
260
- expand_btn.click(expand_with_gpt, inputs=[hidden, textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[textbox, hidden, output])
261
-
262
- edit_btn = gr.Button("Editar", variant="secondary")
263
- edit_btn.click(lambda x: (x, "", []), inputs=[hidden], outputs=[textbox, hidden, output])
264
- clean_btn = gr.Button("Borrar", variant="secondary")
265
- clean_btn.click(lambda: ("", "", [], ""), inputs=[], outputs=[textbox, hidden, output, examples])
266
- examples.change(lambda x: x, inputs=[examples], outputs=[textbox])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  gr.Markdown(FOOTER)
268
 
 
 
269
  demo.launch()
270
  # gr.Interface(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output]).launch()
 
 
1
  import os
2
+ import random
3
+ import string
4
 
5
  import gradio as gr
6
  import torch
 
10
  logger = logging.getLogger()
11
  logger.addHandler(logging.StreamHandler())
12
 
13
+ DEBUG = os.environ.get("DEBUG", "false")[0] in "ty1"
14
  HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
15
  DEVICE = os.environ.get("DEVICE", "cpu") # cuda:0
16
  if DEVICE != "cpu" and not torch.cuda.is_available():
 
62
  """
63
 
64
  FOOTER = """
65
+ <div align=center>
66
+ Para más información, visite el repositorio del modelo: <a href="https://huggingface.co/bertin-project/bertin-gpt-j-6B">BERTIN-GPT-J-6B</a>.
67
+ <img src="https://visitor-badge.glitch.me/badge?page_id=spaces/bertin-project/bertin-gpt-j-6B"/>
68
+ <div align=center>
69
  """.strip()
70
 
71
  EXAMPLES = [
 
79
  Respuesta:""",
80
  ]
81
 
82
+ AGENT = "BERTIN"
83
+ USER = "ENTREVISTADOR"
84
+ CONTEXT = """La siguiente conversación es un extracto de una entrevista a {AGENT} celebrada en Madrid para Radio Televisión Española:
85
+
86
+ {USER}: Bienvenido, {AGENT}. Un placer tenerlo hoy con nosotros.
87
+ {AGENT}: Gracias. El placer es mío."""
88
+
89
  class Normalizer:
90
  def remove_repetitions(self, text):
91
  """Remove repetitions"""
 
136
  def generate(self, text, generation_kwargs):
137
  max_length = len(self.tokenizer(text)["input_ids"]) + generation_kwargs["max_length"]
138
  generation_kwargs["max_length"] = min(max_length, self.model.config.n_positions)
 
 
139
  generated_text = None
140
  if text:
141
  for _ in range(10):
 
206
  }
207
  return generator.generate(hidden or text, generation_kwargs)
208
 
209
+ def chat_with_gpt(user, agent, context, user_message, history, max_length, top_k, top_p, temperature, do_sample, do_clean):
210
+ # agent = AGENT
211
+ # user = USER
212
+ generation_kwargs = {
213
+ "max_length": 25,
214
+ "top_k": top_k,
215
+ "top_p": top_p,
216
+ "temperature": temperature,
217
+ "do_sample": do_sample,
218
+ "do_clean": do_clean,
219
+ # "num_return_sequences": 1,
220
+ # "return_full_text": False,
221
+ }
222
+ message = user_message.split(" ", 1)[0].capitalize() + " " + user_message.split(" ", 1)[-1]
223
+ history = history or [] #[(f"{user}: Bienvenido. Encantado de tenerle con nosotros.", f"{agent}: Un placer, muchas gracias por la invitación.")]
224
+ context = context.format(USER=user or USER, AGENT=agent or AGENT).strip()
225
+ if context[-1] not in ".:":
226
+ context += "."
227
+ context_length = len(context.split())
228
+ history_take = 0
229
+ history_context = "\n".join(f"{user}: {history_message.capitalize()}.\n{agent}: {history_response}." for history_message, history_response in history[-len(history) + history_take:])
230
+ while len(history_context.split()) > generator.model.config.n_positions - (generation_kwargs["max_length"] + context_length):
231
+ history_take += 1
232
+ history_context = "\n".join(f"{user}: {history_message.capitalize()}.\n{agent}: {history_response}." for history_message, history_response in history[-len(history) + history_take:])
233
+ if history_take >= generator.model.config.n_positions:
234
+ break
235
+ context += history_context
236
+ for _ in range(5):
237
+ response = generator.generate(f"{context}\n\n{user}: {message}.\n", generation_kwargs)[1]
238
+ if DEBUG:
239
+ print("\n-----" + response + "-----\n")
240
+ response = response.split("\n")[-1]
241
+ if agent in response and response.split(agent)[-1]:
242
+ response = response.split(agent)[-1]
243
+ if user in response and response.split(user)[-1]:
244
+ response = response.split(user)[-1]
245
+ if response[0] in string.punctuation:
246
+ response = response[1:].strip()
247
+ if response.strip().startswith(f"{user}: {message}"):
248
+ response = response.strip().split(f"{user}: {message}")[-1]
249
+ if response.replace(".", "").strip() and message.replace(".", "").strip() != response.replace(".", "").strip():
250
+ break
251
+ if DEBUG:
252
+ print()
253
+ print("CONTEXT:")
254
+ print(context)
255
+ print()
256
+ print("MESSAGE")
257
+ print(message)
258
+ print()
259
+ print("RESPONSE:")
260
+ print(response)
261
+ if not response.strip():
262
+ response = random.choice(["No sé muy bien cómo contestar a eso.", "No estoy seguro.", "Prefiero no contestar.", "Ni idea.", "¿Podemos cambiar de tema?"])
263
+ history.append((user_message, response))
264
+ return history, history, ""
265
+
266
+
267
  with gr.Blocks() as demo:
268
  gr.Markdown(HEADER)
269
  with gr.Row():
 
315
  # help="Si eliminar o no las palabras repetidas y recortar las últimas frases sin terminar.",
316
  )
317
  with gr.Column():
318
+ with gr.Tabs():
319
+ with gr.TabItem("Generar"):
320
+ textbox = gr.Textbox(label="Texto", placeholder="Escriba algo (o seleccione un ejemplo) y pulse 'Generar'...", lines=8)
321
+ examples = gr.Dropdown(label="Ejemplos", choices=EXAMPLES, value=None, type="value")
322
+ hidden = gr.Textbox(visible=False, show_label=False)
323
+ with gr.Box():
324
+ # output = gr.Markdown()
325
+ output = gr.HighlightedText(label="Resultado", combine_adjacent=True, color_map={"BERTIN": "green", "ERROR": "red"})
326
+ with gr.Row():
327
+ generate_btn = gr.Button("Generar")
328
+ generate_btn.click(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[textbox, hidden, output])
329
+ expand_btn = gr.Button("Añadir")
330
+ expand_btn.click(expand_with_gpt, inputs=[hidden, textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[textbox, hidden, output])
331
+
332
+ edit_btn = gr.Button("Editar", variant="secondary")
333
+ edit_btn.click(lambda x: (x, "", []), inputs=[hidden], outputs=[textbox, hidden, output])
334
+ clean_btn = gr.Button("Borrar", variant="secondary")
335
+ clean_btn.click(lambda: ("", "", [], ""), inputs=[], outputs=[textbox, hidden, output, examples])
336
+ examples.change(lambda x: x, inputs=[examples], outputs=[textbox])
337
+
338
+ with gr.TabItem("Charlar") as tab_chat:
339
+ tab_chat.select(lambda: 25, inputs=[], outputs=[max_length])
340
+ context = gr.Textbox(label="Contexto", value=CONTEXT, lines=5)
341
+ with gr.Row():
342
+ agent = gr.Textbox(label="Agente", value=AGENT)
343
+ user = gr.Textbox(label="Usuario", value=USER)
344
+ history = gr.Variable(default_value=[])
345
+ chatbot = gr.Chatbot(color_map=("green", "gray"))
346
+ with gr.Row():
347
+ message = gr.Textbox(placeholder="Escriba aquí su mensaje y pulse 'Enviar'", show_label=False)
348
+ chat_btn = gr.Button("Enviar")
349
+ chat_btn.click(chat_with_gpt, inputs=[agent, user, context, message, history, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[chatbot, history, message])
350
  gr.Markdown(FOOTER)
351
 
352
+
353
+
354
  demo.launch()
355
  # gr.Interface(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output]).launch()