"""Gradio chat demo for the Korean chatbot model ``heegyu/gorani-v0``.

Serves a minimal chat UI: each submitted message is appended to the
conversation, a bounded window of recent turns is formatted into a
plain-text prompt, and the model's continuation is shown as the reply.
"""
import gradio as gr
import torch
import random
import time
from transformers import pipeline

# Build the generation pipeline once at startup; prefer GPU when available.
generator = pipeline(
    'text-generation',
    model="heegyu/gorani-v0",
    device="cuda:0" if torch.cuda.is_available() else 'cpu',
)


def query(message, chat_history, max_turn=2):
    """Generate a bot reply for *message* given the prior conversation.

    Parameters
    ----------
    message : str
        The user's latest message.
    chat_history : list[tuple[str, str]]
        Prior (user, bot) exchanges, oldest first.
    max_turn : int
        Only the most recent ``max_turn`` exchanges are kept in the
        prompt, bounding its length.

    Returns
    -------
    str
        The model's newly generated text, stripped of surrounding
        whitespace.
    """
    # Bound the prompt to the last `max_turn` exchanges.
    if len(chat_history) > max_turn:
        chat_history = chat_history[-max_turn:]

    # One line per utterance; a leading space precedes each utterance and
    # the prompt ends with a newline so the model continues on a new line.
    # NOTE(review): the role markers look stripped/empty — presumably
    # special tokens belonged before each utterance; verify prompt format
    # against the model card.
    prompt = []
    for user, bot in chat_history:
        prompt.append(f" {user}")
        prompt.append(f" {bot}")
    prompt.append(f" {message}")
    prompt = "\n".join(prompt) + "\n"

    output = generator(
        prompt,
        do_sample=True,
        top_p=0.9,
        early_stopping=True,
        max_new_tokens=256,
    )[0]['generated_text']
    print(output)

    # The pipeline echoes the prompt; keep only the generated tail.
    response = output[len(prompt):]
    return response.strip()


with gr.Blocks() as demo:
    chatbot = gr.Chatbot().style(height=700)
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def respond(message, chat_history):
        """Query the model, append the exchange, and clear the textbox."""
        bot_message = query(message, chat_history)
        chat_history.append((message, bot_message))
        # First output clears the textbox; second refreshes the transcript.
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)


if __name__ == "__main__":
    # Guard the server launch so importing this module stays side-effect
    # light (the UI graph is still built at import, as before).
    demo.launch()