Blossom-9B-Demo / app.py
Azure99's picture
Update app.py
454a3f9 verified
raw
history blame contribute delete
No virus
3.71 kB
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Prompts whose token count exceeds this limit are rejected by `chat`.
MAX_INPUT_LIMIT = 3584
# Passed to `model.generate` as `max_new_tokens`.
MAX_NEW_TOKENS = 1536
# Identifier handed to `from_pretrained` for both the model and the tokenizer.
MODEL_NAME = "Azure99/blossom-v5.1-9b"

# Loaded once at import time and shared by every request.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def get_input_ids(inst, history):
    """Build the token-id prompt for the current instruction plus prior turns.

    Args:
        inst: The new human message to answer.
        history: Iterable of past turns; each turn is indexed as
            ``turn[0]`` (human text) and ``turn[1]`` (bot text).

    Returns:
        A flat list of token ids: the system prompt, every past turn, and the
        new human message followed by the open ``|Bot|: `` marker.
    """
    system_prompt = (
        "A chat between a human and an artificial intelligence bot. "
        "The bot gives helpful, detailed, and polite answers to the human's questions."
    )

    # Segments alternate human prompt (even index) / bot reply (odd index);
    # the final segment is the new, still unanswered human prompt.
    segments = []
    for turn in history:
        segments.extend((f'\n|Human|: {turn[0]}\n|Bot|: ', f'{turn[1]}'))
    segments.append(f'\n|Human|: {inst}\n|Bot|: ')
    segments[0] = system_prompt + segments[0]

    token_ids = []
    for index, segment in enumerate(segments):
        # Only the very first segment is encoded with special tokens.
        is_first_segment = index == 0
        token_ids.extend(tokenizer.encode(segment, add_special_tokens=is_first_segment))
        if index % 2:
            # Terminate every completed bot reply with EOS.
            token_ids.append(tokenizer.eos_token_id)
    return token_ids
def generate(generation_kwargs):
    """Start ``model.generate`` on a background thread and return immediately.

    Args:
        generation_kwargs: Keyword arguments forwarded unchanged to
            ``model.generate``.

    Returns:
        None. Output is expected to be consumed through whatever streamer the
        caller placed in ``generation_kwargs``.
    """

    def _run_generation():
        # Grad mode is thread-local in PyTorch, so ``no_grad`` must be entered
        # inside the worker thread; entering it in the caller's thread would
        # not affect the computation performed here.
        with torch.no_grad():
            model.generate(**generation_kwargs)

    Thread(target=_run_generation).start()
@spaces.GPU
def chat(inst, history, temperature, top_p, repetition_penalty):
    """Stream the bot's reply to ``inst`` given the conversation ``history``.

    Args:
        inst: The new human message.
        history: Past turns, forwarded as-is to ``get_input_ids``.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold (used only when sampling).
        repetition_penalty: Repetition penalty passed to ``model.generate``.

    Yields:
        The reply text accumulated so far, once per chunk produced by the
        streamer. If the prompt exceeds ``MAX_INPUT_LIMIT`` tokens, a single
        error message is yielded instead and nothing is generated.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    input_ids = get_input_ids(inst, history)
    if len(input_ids) > MAX_INPUT_LIMIT:
        yield "The input is too long, please clear the history."
        return

    # UI values may arrive as ints (e.g. 1 instead of 1.0); the generation
    # logits processors insist on real floats, so normalise them here.
    temperature = float(temperature)
    generation_kwargs = dict(
        input_ids=torch.tensor([input_ids]).to(model.device),
        streamer=streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        repetition_penalty=float(repetition_penalty),
    )
    if temperature > 0.0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
    else:
        # Sampling requires a strictly positive temperature. With 0 the
        # background generation would raise and the loop below would wait on
        # a streamer that never receives text, so decode greedily instead.
        generation_kwargs.update(do_sample=False)
    generate(generation_kwargs)

    outputs = ""
    for new_text in streamer:
        outputs += new_text
        yield outputs
# Extra generation controls shown under the chat box; their order matches the
# trailing parameters of `chat` (temperature, top_p, repetition_penalty).
additional_inputs = [
    gr.Slider(interactive=True, **slider_spec)
    for slider_spec in (
        {
            "label": "Temperature",
            "value": 0.5,
            "minimum": 0.0,
            "maximum": 1.0,
            "step": 0.05,
            "info": "Controls randomness in choosing words.",
        },
        {
            "label": "Top-P",
            "value": 0.85,
            "minimum": 0.0,
            "maximum": 1.0,
            "step": 0.05,
            "info": "Picks words until their combined probability is at least top_p.",
        },
        {
            "label": "Repetition penalty",
            "value": 1.05,
            "minimum": 1.0,
            "maximum": 1.2,
            "step": 0.01,
            "info": "Repetition Penalty: Controls how much repetition is penalized.",
        },
    )
]
# Assemble the chat UI around `chat`, then enable queuing and start the server.
demo = gr.ChatInterface(
    chat,
    chatbot=gr.Chatbot(
        show_label=False,
        height=500,
        show_copy_button=True,
        render_markdown=True,
    ),
    textbox=gr.Textbox(placeholder="", container=False, scale=7),
    title="Blossom 9B Demo",
    description=(
        'Hello, I am Blossom, an open source conversational large language model.🌠'
        '<a href="https://github.com/Azure99/BlossomLM">GitHub</a>'
    ),
    theme="soft",
    examples=[
        ["Hello"],
        ["What is MBTI"],
        ["用Python实现二分查找"],
        ["为switch写一篇小红书种草文案,带上emoji"],
    ],
    cache_examples=False,
    additional_inputs=additional_inputs,
    additional_inputs_accordion=gr.Accordion(label="Config", open=True),
    clear_btn="🗑️Clear",
    undo_btn="↩️Undo",
    retry_btn="🔄Retry",
    submit_btn="➡️Submit",
)
demo.queue().launch()