File size: 4,619 Bytes
227a9e8 38f2963 227a9e8 38f2963 3caea9e 38f2963 3caea9e 38f2963 3caea9e 38f2963 3caea9e 38f2963 3caea9e 38f2963 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
import gradio as gr
import mdtex2html
import torch
# Hugging Face Hub id shared by the model weights, generation config and tokenizer.
MODEL_ID = "baichuan-inc/Baichuan-13B-Chat"

# Load the fp16 checkpoint WITHOUT device_map="auto".  Per the Baichuan-13B
# README, online quantization is meant to run on weights that are still on the
# CPU; dispatching the full-precision model to the GPU first defeats the
# memory saving that int8 quantization is used for.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model.generation_config = GenerationConfig.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=False,
    trust_remote_code=True,
)
# Quantize to 8 bits (method supplied by the model's remote code), then move
# the quantized model to the GPU.
model = model.quantize(8).cuda()
"""Override Chatbot.postprocess"""
def postprocess(self, y):
    """Render every chat turn to HTML; replacement for ``Chatbot.postprocess``.

    Args:
        y: List of ``(message, response)`` pairs, or ``None``.

    Returns:
        An empty list when ``y`` is ``None``; otherwise ``y`` itself, with each
        entry replaced in place by a tuple whose non-``None`` members have been
        passed through ``mdtex2html.convert``.
    """
    if y is None:
        return []

    def _render(text):
        # None marks a missing side of the exchange and is passed through as-is.
        return None if text is None else mdtex2html.convert(text)

    for index, pair in enumerate(y):
        message, response = pair
        y[index] = (_render(message), _render(response))
    return y
# Monkey-patch Gradio's Chatbot class so that it uses this module's
# `postprocess` function instead of its built-in method.
gr.Chatbot.postprocess = postprocess
# Characters escaped inside a fenced code block so that the later Markdown/HTML
# rendering shows them literally.  None of the replacement strings contains a
# character that is itself a key, so a single translate() pass is equivalent
# to applying the replacements one after another.
_CODE_BLOCK_ESCAPES = str.maketrans({
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Turn chat text containing ``` fences into an HTML fragment for display.

    Adapted from https://github.com/GaiZhenbiao/ChuanhuChatGPT/.

    Empty lines are dropped.  A line containing ``` toggles a code block: an
    odd-numbered fence becomes ``<pre><code class="language-X">`` (X is the
    text after the last backtick on that line), an even-numbered fence becomes
    ``<br></code></pre>``.  Every other line except the very first is prefixed
    with ``<br>``; lines inside an open code block additionally have their
    Markdown/HTML special characters escaped.

    Args:
        text: Raw message text.

    Returns:
        The converted lines concatenated without separators.
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                language = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            # An odd fence count means we are currently inside a code block.
            if fence_count % 2 == 1:
                line = line.translate(_CODE_BLOCK_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)
# When True, predict() streams partial replies to the UI as they are generated.
stream = True


def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
    """Generate the assistant reply for one user message and update the UI state.

    This is a generator so Gradio can refresh the chat window while the reply
    is being produced.

    Args:
        input: Raw text typed by the user.
        chatbot: List of (user_html, assistant_html) pairs shown in the chat
            widget; the new turn is appended and its reply filled in place.
        max_length: Value assigned to ``generation_config.max_new_tokens``.
        top_p: Value assigned to ``generation_config.top_p``.
        temperature: Value assigned to ``generation_config.temperature``.
        history: List of ``{"role": ..., "content": ...}`` dicts passed to
            ``model.chat``; the user turn and the final reply are appended.
        past_key_values: Passed through to the output unchanged.

    Yields:
        ``(chatbot, history, past_key_values)`` after each partial reply when
        streaming, and once more after the final reply is stored in history.
    """
    model.generation_config.temperature = temperature
    model.generation_config.top_p = top_p
    # Sliders may deliver a float; the token budget must be an integer.
    model.generation_config.max_new_tokens = int(max_length)

    # Render the user's message once; it does not change while streaming.
    rendered_input = parse_text(input)
    chatbot.append((rendered_input, ""))
    # The model must see the user's raw text; the HTML produced by parse_text
    # is only meant for display.
    history.append({"role": "user", "content": input})

    if stream:
        response = ""  # stays defined even if the stream produces nothing
        for response in model.chat(tokenizer, history, stream=True):
            chatbot[-1] = (rendered_input, parse_text(response))
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
            yield chatbot, history, past_key_values
    else:
        response = model.chat(tokenizer, history)
        chatbot[-1] = (rendered_input, parse_text(response))

    print(response)
    # Record the reply in both modes, then yield so the stored history that
    # the next turn receives includes it.
    history.append({"role": "assistant", "content": response})
    yield chatbot, history, past_key_values
def reset_user_input():
    """Return a Gradio update that sets the bound component's value to ''."""
    cleared = gr.update(value='')
    return cleared
def reset_state():
    """Return fresh values for a cleared conversation.

    Returns:
        A 3-tuple ``([], [], None)``; the event binding maps it onto the chat
        widget, the history state and the past_key_values state, in that order.
    """
    empty_chatbot, empty_history, cleared_cache = [], [], None
    return empty_chatbot, empty_history, cleared_cache
# Build the chat UI and wire its events.
# NOTE(review): the nesting of the layout containers below is reconstructed;
# confirm it against the intended layout.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">BaiChuan-13B-Int8</h1>""")
    chatbot = gr.Chatbot()

    with gr.Row():
        # Left: message box with the submit button.
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(
                    show_label=False,
                    placeholder="Input...",
                    lines=10,
                ).style(container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        # Right: clear button and generation settings.
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(
                0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True
            )
            top_p = gr.Slider(
                0, 1, value=0.85, step=0.01, label="Top P", interactive=True
            )
            temperature = gr.Slider(
                0, 1, value=0.9, step=0.01, label="Temperature", interactive=True
            )

    # Per-session state handed to predict() and returned by it.
    history = gr.State([])
    past_key_values = gr.State(None)

    # Submitting runs generation and, as a separate event, clears the text box.
    submitBtn.click(
        fn=predict,
        inputs=[user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
        outputs=[chatbot, history, past_key_values],
        show_progress=True,
    )
    submitBtn.click(fn=reset_user_input, inputs=[], outputs=[user_input])

    emptyBtn.click(
        fn=reset_state,
        outputs=[chatbot, history, past_key_values],
        show_progress=True,
    )

# queue() is enabled before launching; predict() is a generator handler.
demo.queue().launch(share=False, inbrowser=True)
|