"""Gradio demo that continues role-play dialogue with Genshin-World-Model.

At import time the script loads the ``baichuan-inc/Baichuan-13B-Base`` causal
LM, wraps it with the ``Junity/Genshin-World-Model`` PEFT adapter, and then
builds and launches a small Gradio UI on port 6006. Each submitted line is
appended to a running transcript and the model streams a continuation of it.
"""

from threading import Thread

from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
import gradio as gr
import torch

# A message that already ends with one of these gets no extra '。' appended.
SENTENCE_ENDINGS = ('。', '!', '?')
# Only the most recent tokens of the transcript are used as the prompt.
MAX_PROMPT_TOKENS = 4096

config = PeftConfig.from_pretrained("Junity/Genshin-World-Model", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan-13B-Base",
    torch_dtype=torch.float32,
    device_map="auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(
    model,
    r"Junity/Genshin-World-Model",
    torch_dtype=torch.float32,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan-13B-Base", trust_remote_code=True)
history = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Compare the device *type*: `torch.device(...) == "cuda"` is always False, so
# the half-precision cast used to be silently skipped.
if device.type == "cuda":
    # Weight placement is already handled by device_map="auto" above, so the
    # model is only cast here and not moved with .cuda() (moving a model that
    # was dispatched across devices is not safe).
    model = model.half()


def respond(role_name, msg, textbox):
    """Append the user's line to the transcript and stream the model's continuation.

    This is a generator used as a Gradio event handler. Every yielded value is
    a two-item list ``["", transcript]``: the first item clears the input box,
    the second is the transcript text to display.

    Args:
        role_name: Name of the role the user is speaking as (``None`` is
            treated as an empty string).
        msg: The line the user typed. If it is empty or ``None`` the
            transcript is yielded unchanged and nothing is generated.
        textbox: The transcript so far (``None`` is treated as empty).

    Yields:
        ``["", transcript]`` once right after the user's line is appended, and
        again after every chunk of text received from the streamer.
    """
    # Normalise None so the string operations below cannot fail.
    role_name = role_name or ""
    msg = msg or ""
    textbox = textbox or ""

    if not msg:
        # Nothing was typed: keep the transcript as it is instead of failing
        # on the last-character check.
        yield ["", textbox]
        return

    ending = "" if msg.endswith(SENTENCE_ENDINGS) else "。"
    # A non-empty transcript gets a newline before the new line of dialogue.
    separator = "\n" if textbox else ""
    textbox = textbox + separator + role_name + ":" + msg + ending + "\n"
    yield ["", textbox]

    # Keep only the tail of the transcript as the prompt.
    input_ids = tokenizer.encode(textbox)[-MAX_PROMPT_TOKENS:]
    input_ids = torch.LongTensor([input_ids]).to(device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        input_ids=input_ids,
        temperature=1.0,
        top_p=0.75,
        repetition_penalty=1.2,
        max_new_tokens=256,
        streamer=streamer,
    )
    # generate() runs in a background thread; this generator consumes the
    # streamer and forwards each chunk to the UI.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    for new_text in streamer:
        textbox += new_text
        yield ["", textbox]


with gr.Blocks() as demo:
    gr.Markdown(
        """
        ## Genshin-World-Model

        - 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
        - 此模型不支持要求对方回答什么,只支持续写。
        """
    )
    with gr.Tab("聊天") as chat:
        role_name = gr.Textbox(label="你将扮演的角色")
        msg = gr.Textbox(label="输入")
        with gr.Row():
            clear = gr.Button("Clear")
            sub = gr.Button("Submit")
        textbox = gr.Textbox(interactive=False)

        # Submit appends the message and streams the continuation; the first
        # output clears the input box, the second updates the transcript.
        sub.click(fn=respond, inputs=[role_name, msg, textbox], outputs=[msg, textbox])
        # Clear resets the transcript without going through the queue.
        clear.click(lambda: None, None, textbox, queue=False)

demo.queue().launch(server_port=6006)