LinkangZhan
fix gpu version
f595202
raw
history blame
5 kB
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
from threading import Thread
import gradio as gr
import torch
# lora_folder = ''
# model_folder = ''
#
# config = PeftConfig.from_pretrained(("Junity/Genshin-World-Model" if lora_folder == ''
# else lora_folder),
# trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
# else model_folder),
# torch_dtype=torch.float16,
# device_map="auto",
# trust_remote_code=True)
# model = PeftModel.from_pretrained(model,
# ("Junity/Genshin-World-Model" if lora_folder == ''
# else lora_folder),
# device_map="auto",
# torch_dtype=torch.float32,
# trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
# else model_folder),
# trust_remote_code=True)
# Module-level conversation history; nothing in the code visible here reads it.
history = list()

# Pick the compute device once at import time: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def _append_user_turn(role_name, msg, textbox):
    """Return ``textbox`` with the user's line appended.

    An empty ``msg`` leaves ``textbox`` untouched, so the caller can ask for a
    plain continuation of the existing text.
    """
    if not msg:
        return textbox

    speaker = role_name + (":" if role_name != '' else '')
    last_char = msg[-1]

    if textbox != '':
        # Follow-up turn: start on a new line and close the sentence with
        # '。' plus a newline unless the user already ended it.
        closing = '。\n' if last_char not in ('。', '!', '?') else ''
        return textbox + "\n" + speaker + msg + closing

    # First turn: an opening bracket or colon at the end is left open so the
    # model can continue the same sentence.
    closing = '。' if last_char not in ('。', '!', '?', ')', '}', ':', ':', '(') else ''
    if last_char in ('。', '!', '?', ')', '}'):
        closing += '\n'
    return speaker + msg + closing


def respond(role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k):
    """Append the user's line to ``textbox`` and stream the model's continuation.

    This is a generator. Every yielded value is ``["", text]``: an empty
    string followed by the full text accumulated so far.

    Args:
        role_name: Speaker label for the user's line; may be empty.
        character_name: Speaker label the model should continue as; may be empty.
        msg: The user's line; may be empty, in which case nothing is appended.
        textbox: The full text so far.
        temp: Passed to ``model.generate`` as ``temperature``.
        rep: Passed to ``model.generate`` as ``repetition_penalty``.
        max_len: Passed to ``model.generate`` as ``max_new_tokens``.
        top_p: Passed to ``model.generate`` as ``top_p``.
        top_k: Passed to ``model.generate`` as ``top_k``.

    Relies on the module-level names ``tokenizer``, ``model`` and ``device``.
    NOTE(review): ``tokenizer`` and ``model`` must be defined elsewhere in the
    module before generation can run.
    """
    textbox = _append_user_turn(role_name, msg, textbox)
    # First update: show the user's line immediately, before generation starts.
    yield ["", textbox]

    if character_name != '':
        # Put the character's label on its own line; an empty textbox needs
        # no leading newline.
        separator = '\n' if textbox and not textbox.endswith('\n') else ''
        textbox += separator + character_name + ':'

    if textbox == '':
        # Nothing to continue from, so do not call the model at all.
        return

    # Only the last 3200 token ids are kept as the prompt.
    input_ids = tokenizer.encode(textbox)[-3200:]
    input_ids = torch.LongTensor([input_ids]).to(device)

    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        input_ids=input_ids,
        temperature=temp,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=rep,
        max_new_tokens=max_len,
        do_sample=True,
        streamer=streamer,
    )

    # Generation runs in a background thread; this generator consumes the
    # streamer and yields the growing text as new pieces arrive.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    for new_text in streamer:
        textbox += new_text
        yield ["", textbox]
# Gradio front end: a single tab that collects the speaker names, the user's
# line and the sampling settings, then streams respond()'s output back into
# the message box (cleared) and the full-text box.
# NOTE(review): the nesting below is reconstructed from the order of the
# components -- confirm the intended Row/Tab layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """
## Genshin-World-Model
- 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
- 此模型不支持要求对方回答什么,只支持续写。
- 目前运行不了,因为没有钱租卡。
"""
    )
    with gr.Tab("创作") as chat:
        role_name = gr.Textbox(label="你将扮演的角色(可留空)")
        character_name = gr.Textbox(label="对方的角色(可留空)")
        msg = gr.Textbox(label="你说的话")
        with gr.Row():
            clear = gr.ClearButton()
            sub = gr.Button("Submit", variant="primary")
        with gr.Row():
            # respond() always samples (do_sample=True), so temperature and
            # repetition penalty must stay strictly positive; the sliders
            # therefore start at 0.1 instead of 0.
            temp = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.5, label="温度(调大则更随机)", interactive=True)
            rep = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="对重复生成的惩罚", interactive=True)
            max_len = gr.Slider(minimum=4, maximum=512, step=4, value=256, label="对方回答的最大长度", interactive=True)
            top_p = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.7, label="Top-p(调大则更随机)", interactive=True)
            top_k = gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Top-k(调大则更随机)", interactive=True)
        textbox = gr.Textbox(interactive=True, label="全部文本(可修改)")
        # The clear button resets the message, the user's role and the full text.
        clear.add([msg, role_name, textbox])
        sub.click(fn=respond,
                  inputs=[role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k],
                  outputs=[msg, textbox])
    gr.Markdown(
        """
#### 特别鸣谢 XXXX
"""
    )

# queue() is enabled before launch; respond() is a generator that yields
# incremental updates.
demo.queue().launch()