import logging
import os
import random
import string

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, set_seed

logger = logging.getLogger()
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)

DEBUG = os.environ.get("DEBUG", "false")[0] in "ty1"
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
DEVICE = os.environ.get("DEVICE", "cpu")  # e.g. "cuda:0"
if DEVICE != "cpu" and not torch.cuda.is_available():
    DEVICE = "cpu"
logger.info(f"DEVICE {DEVICE}")
# Half precision on GPU, full precision on CPU.
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))

HEADER_INFO = """
# BERTIN GPT-J-6B

Spanish BERTIN GPT-J-6B Model.
""".strip()
LOGO = "https://huggingface.co/bertin-project/bertin-roberta-base-spanish/resolve/main/images/bertin.png"
# NOTE: the body of this f-string was garbled in the source; this is a minimal
# reconstruction from the constants above.
HEADER = f"""
<div align="center">
<img src="{LOGO}" width="150"/>
</div>

{HEADER_INFO}
""".strip()
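
# --- Assumed helper classes --------------------------------------------------
# `Normalizer` and `TextGeneration` are used below but defined outside this
# section. The sketch here is an assumption, not the original implementation:
# it only pins down the interface the callbacks rely on, i.e. `load()` builds
# the tokenizer/model and `generate()` returns a (full_text, full_text,
# highlighted_spans) triple for the [textbox, hidden, output] components.


class Normalizer:
    # Assumption: the original cleaner strips repeated words and trims
    # unfinished trailing sentences; this stand-in only trims whitespace.
    def clean_txt(self, text):
        return text.strip()


class TextGeneration:
    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME, use_auth_token=HF_AUTH_TOKEN
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME, use_auth_token=HF_AUTH_TOKEN, torch_dtype=DTYPE
        ).to(DEVICE)
        self.model.eval()

    def generate(self, text, generation_kwargs):
        # Copy so the retry loop in chat_with_gpt can reuse its kwargs dict.
        generation_kwargs = dict(generation_kwargs)
        do_clean = generation_kwargs.pop("do_clean", False)
        # Assumption: "max_length" counts newly generated tokens.
        max_new_tokens = generation_kwargs.pop("max_length", 50)
        inputs = self.tokenizer(text, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,
                **generation_kwargs,
            )
        decoded = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # Approximate prompt/continuation split after decoding.
        generated = decoded[len(text):]
        if do_clean:
            generated = cleaner.clean_txt(generated)  # global cleaner, defined below
        full_text = text + generated
        return full_text, full_text, [(text, None), (generated, "BERTIN")]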
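# --- Assumed UI constants ----------------------------------------------------
# EXAMPLES, CONTEXT, AGENT, USER, and FOOTER are referenced by the interface
# below but defined outside this section. Placeholder values (assumptions) so
# the script is self-contained; the original texts differ.
AGENT = "BERTIN"
USER = "Usuario"
CONTEXT = "Conversación entre {USER} y {AGENT}, un modelo de lenguaje en español"
EXAMPLES = [
    "",
    "Érase una vez,",
    "¿Cuál es la capital de Francia? La capital de Francia es",
]
FOOTER = "Demo de BERTIN GPT-J-6B."
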
# @st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
# @st.cache(allow_output_mutation=True)
# @st.cache(allow_output_mutation=True, hash_funcs={TextGeneration: lambda _: None})
def load_text_generator():
    text_generator = TextGeneration()
    text_generator.load()
    return text_generator


cleaner = Normalizer()
generator = load_text_generator()


def complete_with_gpt(text, max_length, top_k, top_p, temperature, do_sample, do_clean):
    generation_kwargs = {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "do_sample": do_sample,
        "do_clean": do_clean,
    }
    return generator.generate(text, generation_kwargs)


def expand_with_gpt(hidden, text, max_length, top_k, top_p, temperature, do_sample, do_clean):
    generation_kwargs = {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "do_sample": do_sample,
        "do_clean": do_clean,
    }
    # Expand the previously generated text (kept in the hidden textbox) if any.
    return generator.generate(hidden or text, generation_kwargs)


def chat_with_gpt(user, agent, context, user_message, history, max_length, top_k, top_p, temperature, do_sample, do_clean):
    # agent = AGENT
    # user = USER
    generation_kwargs = {
        "max_length": 25,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "do_sample": do_sample,
        "do_clean": do_clean,
        # "num_return_sequences": 1,
        # "return_full_text": False,
    }
    # Capitalize only the first word of the user's message.
    first_word, _, rest = user_message.partition(" ")
    message = (first_word.capitalize() + " " + rest).strip()
    history = history or []  # [(f"{user}: Bienvenido. Encantado de tenerle con nosotros.", f"{agent}: Un placer, muchas gracias por la invitación.")]
    context = context.format(USER=user or USER, AGENT=agent or AGENT).strip()
    if context[-1] not in ".:":
        context += "."
    context_length = len(context.split())
    # Drop the oldest turns until context + history + response fit in the
    # model's context window.
    history_take = 0
    history_context = "\n".join(
        f"{user}: {history_message.capitalize()}.\n{agent}: {history_response}."
        for history_message, history_response in history[-len(history) + history_take:]
    )
    while len(history_context.split()) > generator.model.config.n_positions - (generation_kwargs["max_length"] + context_length):
        history_take += 1
        history_context = "\n".join(
            f"{user}: {history_message.capitalize()}.\n{agent}: {history_response}."
            for history_message, history_response in history[-len(history) + history_take:]
        )
        if history_take >= generator.model.config.n_positions:
            break
    context += history_context
    # Retry a few times until the model produces a non-empty response that is
    # not a plain echo of the user's message.
    for _ in range(5):
        response = generator.generate(f"{context}\n\n{user}: {message}.\n", generation_kwargs)[1]
        if DEBUG:
            print("\n-----" + response + "-----\n")
        response = response.split("\n")[-1]
        if agent in response and response.split(agent)[-1]:
            response = response.split(agent)[-1]
        if user in response and response.split(user)[-1]:
            response = response.split(user)[-1]
        if response and response[0] in string.punctuation:
            response = response[1:].strip()
        if response.strip().startswith(f"{user}: {message}"):
            response = response.strip().split(f"{user}: {message}")[-1]
        if response.replace(".", "").strip() and message.replace(".", "").strip() != response.replace(".", "").strip():
            break
    if DEBUG:
        print()
        print("CONTEXT:")
        print(context)
        print()
        print("MESSAGE")
        print(message)
        print()
        print("RESPONSE:")
        print(response)
    if not response.strip():
        response = random.choice([
            "No sé muy bien cómo contestar a eso.",
            "No estoy seguro.",
            "Prefiero no contestar.",
            "Ni idea.",
            "¿Podemos cambiar de tema?",
        ])
    history.append((user_message, response))
    return history, history, ""


with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        with gr.Group():
            with gr.Box():
                gr.Markdown("Opciones")
                max_length = gr.Slider(
                    label='Longitud máxima',
                    # help="Número máximo (aproximado) de palabras a generar.",
                    minimum=1,
                    maximum=MAX_LENGTH,
                    value=50,
                    step=1,
                )
                top_k = gr.Slider(
                    label='Top-k',
                    # help="Número de palabras con alta probabilidad a mantener para el filtrado `top-k`",
                    minimum=40,
                    maximum=80,
                    value=50,
                    step=1,
                )
                top_p = gr.Slider(
                    label='Top-p',
                    # help="Solo las palabras más probables con probabilidades que sumen `top_p` o más se mantienen para la generación.",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.95,
                    step=0.01,
                )
                temperature = gr.Slider(
                    label='Temperatura',
                    # help="Valor utilizado para modular las probabilidades de las siguientes palabras generadas.",
                    minimum=0.1,
                    maximum=10.0,
                    value=0.8,
                    step=0.05,
                )
                do_sample = gr.Checkbox(
                    label='¿Muestrear?',
                    value=True,
                    # options=(True, False),
                    # help="Si no se muestrea se usará una decodificación voraz (_greedy_).",
                )
                do_clean = gr.Checkbox(
                    label='¿Limpiar texto?',
                    value=True,
                    # options=(True, False),
                    # help="Si eliminar o no las palabras repetidas y recortar las últimas frases sin terminar.",
                )
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Generar"):
                    textbox = gr.Textbox(label="Texto", placeholder="Escriba algo (o seleccione un ejemplo) y pulse 'Generar'...", lines=8)
                    examples = gr.Dropdown(label="Ejemplos", choices=EXAMPLES, value=None, type="value")
                    hidden = gr.Textbox(visible=False, show_label=False)
                    with gr.Box():
                        # output = gr.Markdown()
                        output = gr.HighlightedText(label="Resultado", combine_adjacent=True, color_map={"BERTIN": "green", "ERROR": "red"})
                    with gr.Row():
                        generate_btn = gr.Button("Generar")
                        generate_btn.click(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[textbox, hidden, output])
                        expand_btn = gr.Button("Añadir")
                        expand_btn.click(expand_with_gpt, inputs=[hidden, textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[textbox, hidden, output])
                        edit_btn = gr.Button("Editar", variant="secondary")
                        edit_btn.click(lambda x: (x, "", []), inputs=[hidden], outputs=[textbox, hidden, output])
                        clean_btn = gr.Button("Borrar", variant="secondary")
                        clean_btn.click(lambda: ("", "", [], ""), inputs=[], outputs=[textbox, hidden, output, examples])
                    examples.change(lambda x: x, inputs=[examples], outputs=[textbox])
                with gr.TabItem("Charlar") as tab_chat:
                    # Chat responses are short; lower the length slider on entering this tab.
                    tab_chat.select(lambda: 25, inputs=[], outputs=[max_length])
                    context = gr.Textbox(label="Contexto", value=CONTEXT, lines=5)
                    with gr.Row():
                        agent = gr.Textbox(label="Agente", value=AGENT)
                        user = gr.Textbox(label="Usuario", value=USER)
                    history = gr.Variable(default_value=[])
                    chatbot = gr.Chatbot(color_map=("green", "gray"))
                    with gr.Row():
                        message = gr.Textbox(placeholder="Escriba aquí su mensaje y pulse 'Enviar'", show_label=False)
                        chat_btn = gr.Button("Enviar")
                        # Input order matches the chat_with_gpt signature (user first, then agent).
                        chat_btn.click(chat_with_gpt, inputs=[user, agent, context, message, history, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[chatbot, history, message])
    gr.Markdown(FOOTER)

demo.launch()
# gr.Interface(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output]).launch()
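
# Note: this script targets the Gradio 3.x Blocks API (gr.Box, gr.Variable,
# Chatbot(color_map=...)). Gradio 4.x removed these, so running it is assumed
# to require a pinned 3.x release.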