from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
import gradio as gr
import torch

config = PeftConfig.from_pretrained("Junity/Genshin-World-Model", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "../Baichuan/models--baichuan-inc--Baichuan-13B-Base/snapshots/Baichuan-13B-Base",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(
    model,
    "../Baichuan/r64alpha32dropout0.5loss0.007/checkpoint-5000",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("Junity/Genshin-World-Model", trust_remote_code=True)

# Patch the model class so model.generate can yield tokens one at a time;
# this is how transformers_stream_generator hooks in (the import above is
# otherwise unused).
model.__class__.generate = NewGenerationMixin.generate
model.__class__.sample_stream = NewGenerationMixin.sample_stream

history = []
device = "cpu"


def respond(role_name, msg, chatbot):
    global history
    if role_name:
        history.append(role_name + ":" + msg)
    else:
        history.append(msg)
    # Rebuild the prompt from the newest messages backwards, keeping at most
    # 4096 tokens of context.
    total_input = []
    for message in reversed(history):
        content_tokens = tokenizer.encode(message + "\n")
        total_input = content_tokens + total_input
        if len(total_input) > 4096:
            break
    total_input = total_input[-4096:]
    input_ids = torch.LongTensor([total_input]).to(device)
    generation_config = model.generation_config
    stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)

    def stream_generator():
        outputs = []
        for token in model.generate(input_ids, generation_config=stream_config):
            outputs.append(token.item())
            # Clear the input box and show the partial completion in the chatbot.
            yield "", chatbot + [[msg, tokenizer.decode(outputs, skip_special_tokens=True)]]

    return stream_generator()


with gr.Blocks() as demo:
    gr.Markdown(
        """
        ## Genshin-World-Model
        - Model: [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
        - This model cannot answer questions; it only continues the text you give it.
        """
    )
    with gr.Tab("Chat") as chat:
        role_name = gr.Textbox(label="Character you play")
        msg = gr.Textbox(label="Input")
        with gr.Row():
            clear = gr.Button("Clear")
            sub = gr.Button("Submit")
        chatbot = gr.Chatbot()
        sub.click(fn=respond, inputs=[role_name, msg, chatbot], outputs=[msg, chatbot])
        clear.click(lambda: None, None, chatbot, queue=False)

demo.queue().launch()
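# Running the demo (a sketch, not part of the original script; the filename
# "app.py" is an assumption, and the relative checkpoint paths above must
# exist on your machine):
#
#   pip install torch transformers peft gradio transformers_stream_generator
#   python app.py
#
# Gradio serves on http://127.0.0.1:7860 by default. Note that running a 13B
# model on CPU (device = "cpu" above) is very slow; if a GPU is available,
# one untested alternative is to load the base model with
#
#   AutoModelForCausalLM.from_pretrained(..., torch_dtype=torch.float16, device_map="auto")
#
# (requires the `accelerate` package) and set device = "cuda" so input_ids
# land on the same device as the model.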