# zephyr_chatbot / app.py
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import gradio as gr
import torch
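
# Longest prompt (in tokens) passed to the model; older turns beyond this are dropped.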
MAX_INPUT_TOKEN_LENGTH = 4096
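
# Load Zephyr-7B-beta in half precision; device_map='auto' places the weights
# on the available device(s) via accelerate.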
model_id = 'HuggingFaceH4/zephyr-7b-beta'
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.use_default_system_prompt = False
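
# Streaming chat handler: builds the conversation list, applies the chat template,
# and yields the growing response string as tokens arrive.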
def generate(message, chat_history=[], system_prompt='', max_new_tokens=512, temperature=0.5, top_p=0.95, top_k=50, repetition_penalty=1.2):
    conversation = []
    if system_prompt:
        conversation.append({'role': 'system', 'content': system_prompt})
    for user, assistant in chat_history:
        conversation.extend([
            {'role': 'user', 'content': user},
            {'role': 'assistant', 'content': assistant}
        ])
    conversation.append({'role': 'user', 'content': message})
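
    # Apply the chat template and keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens.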
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors='pt')
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
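
    # Stream decoded tokens as they are produced so the UI can update incrementally.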
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty
    )
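
    # model.generate runs on a background thread; the streamer yields text on this one.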
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield ''.join(outputs)
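
# Gradio chat UI: ChatInterface passes (message, history) to `generate` and renders
# each yielded string as the partially streamed reply.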
chat_interface = gr.ChatInterface(
    fn=generate,
    examples=[
        'What is GPT?',
        'What is Life?',
        'Who is Alan Turing?'
    ]
)
chat_interface.queue(max_size=20).launch()