# Hugging Face Space (status: Running)
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
from threading import Thread | |
import gradio as gr | |
import torch | |
# Maximum number of prompt tokens fed to the model; generate() keeps only the
# most recent MAX_INPUT_TOKEN_LENGTH tokens when the conversation is longer.
MAX_INPUT_TOKEN_LENGTH = 4096

# Hub checkpoint served by this app.
model_id = 'HuggingFaceH4/zephyr-7b-beta'

# Tokenizer first (it is independent of the weights). Presumably disables any
# tokenizer-provided default system prompt, so a system turn only appears when
# generate() is given one explicitly — TODO confirm for this tokenizer class.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.use_default_system_prompt = False

# Weights in half precision; device_map='auto' lets transformers decide device
# placement.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype=torch.float16,
)
def generate(
    input,
    chat_history=None,
    system_prompt=False,
    max_new_tokens=512,
    temperature=0.5,
    top_p=0.95,
    top_k=50,
    repetition_penalty=1.2,
):
    """Stream a chat reply to ``input`` given the prior conversation.

    Used as the ``fn`` of the ``gr.ChatInterface`` in this file; every yielded
    string is the reply text generated so far.

    Args:
        input: The new user message.
        chat_history: Previous turns as ``(user, assistant)`` pairs. ``None``
            (the default) or an empty sequence means a fresh conversation.
        system_prompt: Optional system message. Any falsy value means no
            system turn is added.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling options forwarded unchanged to ``model.generate``.

    Yields:
        The accumulated reply text (the full text so far, not just the delta).
    """
    conversation = []
    if system_prompt:
        conversation.append({'role': 'system', 'content': system_prompt})
    for user, assistant in chat_history or []:
        # list.extend takes ONE iterable: both turns must go in a single list.
        conversation.extend([
            {'role': 'user', 'content': user},
            {'role': 'assistant', 'content': assistant},
        ])
    conversation.append({'role': 'user', 'content': input})

    # add_generation_prompt=True appends the assistant header so the model
    # answers as the assistant rather than continuing the user's turn.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors='pt'
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Drop the oldest tokens so the newest message is kept.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(
            f"Trimmed input from conversation as it was longer than "
            f"{MAX_INPUT_TOKEN_LENGTH} tokens."
        )
    input_ids = input_ids.to(model.device)

    # skip_prompt=True: the streamer emits only newly generated text.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks until it finishes, so run it on a worker thread
    # and consume its output incrementally from the streamer here.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield ''.join(outputs)
    # The streamer is exhausted, so generation has ended; reap the thread.
    thread.join()
# Chat UI backed by generate(), which yields the growing reply text; the
# examples are offered as starter prompts.
chat_interface = gr.ChatInterface(
    fn=generate,
    examples=['What is GPT?', 'What is Life?', 'Who is Alan Turing'],
)

# Enable request queuing (capped at 20 pending requests), then start serving.
# queue() returns the interface itself, so this matches the chained form.
chat_interface.queue(max_size=20)
chat_interface.launch()