import gradio as gr
import transformers
from torch import bfloat16
from threading import Thread

# If you adapt this for a repo that needs auth, load the token from the environment:
# import os
# from dotenv import load_dotenv
# load_dotenv()
# HF_AUTH = os.getenv('HF_AUTH')
model_id = "stabilityai/StableBeluga-7B"
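# 4-bit NF4 quantization with nested (double) quantization keeps the 7B model's
# memory footprint small enough for a single GPU; matmuls still run in bfloat16.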
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16,
)
model_config = transformers.AutoConfig.from_pretrained(
    model_id,
    # use_auth_token=HF_AUTH
)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    device_map='auto',
    # use_auth_token=HF_AUTH
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_id,
    # use_auth_token=HF_AUTH
)
DESCRIPTION = """
# Stable Beluga 7B Chat
This is a streaming Chat Interface implementation of [StableBeluga-7B](https://huggingface.co/stabilityai/StableBeluga-7B). We'll use it to deploy a Discord bot that you can add to your server!
Sometimes the model doesn't appropriately hit its stop token. Feel free to hit "stop" and "retry" if this happens to you. Or PR a fix to stop the stream if the tokens for User: get hit or something.
"""
system_prompt = "You are a helpful AI."
def prompt_build(system_prompt, user_inp, hist):
    """Assemble a Stable Beluga prompt from the system message, chat history, and new user input."""
    prompt = f"""### System:\n{system_prompt}\n\n"""
    for pair in hist:
        prompt += f"""### User:\n{pair[0]}\n\n### Assistant:\n{pair[1]}\n\n"""
    prompt += f"""### User:\n{user_inp}\n\n### Assistant:"""
    return prompt
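# For example, prompt_build(system_prompt, "Hi!", []) yields:
#
# ### System:
# You are a helpful AI.
#
# ### User:
# Hi!
#
# ### Assistant: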
def chat(user_input, history):
    prompt = prompt_build(system_prompt, user_input, history)
    model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

    # Stream tokens to the UI as they are generated instead of waiting for the full reply.
    streamer = transformers.TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        model_inputs,  # merges input_ids and attention_mask into the kwargs
        streamer=streamer,
        max_new_tokens=2048,
        do_sample=True,
        top_p=0.95,
        temperature=0.8,
        top_k=50,
    )
    # generate() blocks until it finishes, so run it on a background thread
    # and consume the streamer incrementally on this one.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    return model_output
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    chatbot = gr.ChatInterface(fn=chat)

# queue() enables the generator-based streaming responses from chat().
demo.queue().launch()