Spaces:
Runtime error
Runtime error
File size: 2,812 Bytes
a7d37a3 a6636f6 c2769e9 a6636f6 a7d37a3 c2769e9 b2dcf43 a6636f6 c2769e9 a6636f6 c2769e9 a6636f6 c2769e9 a6636f6 c2769e9 a6636f6 c2769e9 a6636f6 a7d37a3 a6636f6 a7d37a3 c2769e9 a6636f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
from threading import Thread
import gradio as gr
import torch
# --- Model / tokenizer setup (executed once, at import time) ----------------
# Adapter configuration for the PEFT checkpoint. It is loaded here but not
# read again anywhere in the visible code.
config = PeftConfig.from_pretrained("Junity/Genshin-World-Model", trust_remote_code=True)
# Load the base causal LM in float32, then wrap it with the PEFT adapter.
model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan-13B-Base", torch_dtype=torch.float32, device_map="auto", trust_remote_code=True)
model = PeftModel.from_pretrained(model, r"Junity/Genshin-World-Model", torch_dtype=torch.float32, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan-13B-Base", trust_remote_code=True)
history = []  # not referenced by any code visible in this file
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# A torch.device never compares equal to a plain string, so the previous
# `device == "cuda"` check was always False and this branch never ran.
# Compare the device *type* instead.
if device.type == "cuda":
    # NOTE(review): the model was loaded with device_map="auto", which may
    # already have placed (or offloaded) the weights — confirm that an
    # explicit .cuda() is still wanted on the target hardware.
    model.cuda()
    model = model.half()
def respond(role_name, msg, textbox):
    """Append the user's line to the transcript and stream the model's continuation.

    This is a generator used as a Gradio event handler. Every yielded value is
    a two-item list ``["", transcript]``: the first item clears the input box,
    the second is the transcript text accumulated so far.

    Args:
        role_name: Name placed in front of the user's line.
        msg: The user's line. If it does not already end with one of
            '。', '!' or '?', a '。' is appended.
        textbox: Current transcript text. ``None`` or an empty string is
            treated as an empty transcript.

    Yields:
        ``["", transcript]`` once right after the user's line is appended,
        then again after each chunk of generated text arrives.
    """
    transcript = textbox or ""

    # Nothing to submit: leave the transcript untouched instead of failing on
    # msg[-1] with an IndexError.
    if not msg:
        yield ["", transcript]
        return

    closing = "" if msg.endswith(("。", "!", "?")) else "。"
    # Start the new turn on its own line when there is earlier text.
    lead = "\n" if transcript else ""
    transcript = f"{transcript}{lead}{role_name}:{msg}{closing}\n"
    yield ["", transcript]

    # Keep only the most recent 4096 tokens of the transcript as the prompt.
    input_ids = tokenizer.encode(transcript)[-4096:]
    input_ids = torch.LongTensor([input_ids]).to(device)
    print(input_ids)

    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        input_ids=input_ids,
        temperature=1.0,
        top_p=0.75,
        repetition_penalty=1.2,
        max_new_tokens=256,
        streamer=streamer,
    )

    # generate() blocks until it is finished, so run it in a worker thread and
    # consume the streamer from this generator.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    for new_text in streamer:
        transcript += new_text
        yield ["", transcript]
    # The streamer is exhausted, so generation is finishing; reap the worker.
    thread.join()
# Introductory Markdown rendered at the top of the page.
_INTRO_MARKDOWN = """
## Genshin-World-Model
- 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
- 此模型不支持要求对方回答什么,只支持续写。
"""


def _clear_transcript():
    """Callback for the Clear button: returns None as the new transcript value."""
    return None


# NOTE(review): the nesting under Tab/Row below is reconstructed from the
# conventional layout; the original indentation was not available.
with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MARKDOWN)
    with gr.Tab("聊天") as chat:
        # Inputs: the role the user plays and the line they want to say.
        role_name = gr.Textbox(label="你将扮演的角色")
        msg = gr.Textbox(label="输入")
        with gr.Row():
            clear = gr.Button("Clear")
            sub = gr.Button("Submit")
        # Read-only transcript that respond() streams into.
        textbox = gr.Textbox(interactive=False)
        # Submit: respond() clears `msg` and updates `textbox` on every yield.
        sub.click(respond, [role_name, msg, textbox], [msg, textbox])
        # Clear: reset the transcript without going through the queue.
        clear.click(fn=_clear_transcript, inputs=None, outputs=textbox, queue=False)

# Enable the request queue, then serve on port 6006.
demo.queue()
demo.launch(server_port=6006)
|