from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
from threading import Thread
import gradio as gr
import torch

# Model loading is currently disabled (the UI note says there is no rented GPU to run it).
# Uncomment this section to define the ``model``, ``tokenizer`` and ``device`` globals that
# ``respond`` relies on. An empty folder string falls back to the Hugging Face Hub repo id.
#
# lora_folder = ''
# model_folder = ''
#
# config = PeftConfig.from_pretrained(("Junity/Genshin-World-Model" if lora_folder == ''
#                                      else lora_folder),
#                                     trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
#                                               else model_folder),
#                                              torch_dtype=torch.float32, trust_remote_code=True)
# model = PeftModel.from_pretrained(model,
#                                   ("Junity/Genshin-World-Model" if lora_folder == ''
#                                    else lora_folder),
#                                   torch_dtype=torch.float32, trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
#                                            else model_folder),
#                                           trust_remote_code=True)
# history = []
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # Compare the device *type*: a torch.device does not compare equal to the string "cuda".
# if device.type == "cuda":
#     model.cuda()
#     model = model.half()


def respond(role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k):
    """Append the user's line to the transcript and stream the model's continuation.

    This is a Gradio generator callback. Every yielded value is ``[msg_value, textbox_value]``:
    the message box is always cleared (``""``) and the second item is the full transcript
    accumulated so far.

    Args:
        role_name: Name of the role the user plays; may be empty.
        character_name: Name the model should continue as; may be empty.
        msg: The user's new line; may be empty.
        textbox: The current full transcript (user-editable).
        temp: Sampling temperature; a value <= 0 switches to greedy decoding.
        rep: Repetition penalty passed to ``generate``.
        max_len: Maximum number of new tokens to generate.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Top-k cutoff (used only when sampling; 0 disables it).

    Relies on the module-level ``model``, ``tokenizer`` and ``device`` globals, which are
    defined by the (currently commented-out) loading section above.
    """
    role_prefix = role_name + ":" if role_name != '' else ''

    if textbox != '':
        # Continuing an existing transcript: put the user's line on its own line and
        # close it with "。\n" unless it already ends a sentence.
        textbox = textbox + "\n" + role_prefix + msg
        if msg and msg[-1] not in ['。', '!', '?']:
            textbox += '。\n'
    else:
        # First line of a fresh transcript.
        textbox = role_prefix + msg
        if msg and msg[-1] not in ['。', '!', '?', ')', '}', ':', ':', '(']:
            textbox += '。'
        if msg and msg[-1] in ['。', '!', '?', ')', '}']:
            textbox += '\n'
    yield ["", textbox]

    if character_name != '':
        # Start the other character's turn on a fresh line (no leading newline when the
        # transcript is still empty; indexing ``textbox[-1]`` would fail there).
        if textbox and not textbox.endswith('\n'):
            textbox += '\n'
        textbox += character_name + ':'

    if not textbox:
        # Nothing to continue from; avoid sending an empty prompt to the model.
        return

    # Keep only the most recent 4096 prompt tokens.
    input_ids = tokenizer.encode(textbox)[-4096:]
    input_ids = torch.LongTensor([input_ids]).to(device)

    # Use the UI slider values (previously ignored in favour of hard-coded constants).
    # Sliders can deliver floats, so cast to the types ``generate`` validates against.
    gen_kwargs = dict(
        input_ids=input_ids,
        repetition_penalty=float(rep),
        max_new_tokens=int(max_len),
    )
    if temp > 0:
        gen_kwargs.update(
            do_sample=True,
            temperature=float(temp),
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        # Sampling requires a strictly positive temperature; treat 0 as greedy decoding.
        gen_kwargs["do_sample"] = False

    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs["streamer"] = streamer

    # ``generate`` blocks, so run it on a worker thread and consume its output here.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    for new_text in streamer:
        textbox += new_text
        yield ["", textbox]
    thread.join()


with gr.Blocks() as demo:
    gr.Markdown(
        """
        ## Genshin-World-Model
        - 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
        - 此模型不支持要求对方回答什么,只支持续写。
        - 目前运行不了,因为没有钱租卡。
        """
    )
    with gr.Tab("创作") as chat:
        role_name = gr.Textbox(label="你将扮演的角色(可留空)")
        character_name = gr.Textbox(label="对方的角色(可留空)")
        msg = gr.Textbox(label="你说的话")
        with gr.Row():
            clear = gr.ClearButton()
            sub = gr.Button("Submit", variant="primary")
        with gr.Row():
            temp = gr.Slider(minimum=0, maximum=2.0, step=0.1, value=1.5,
                             label="温度(调大则更随机)", interactive=True)
            # The repetition penalty must be strictly positive, so the slider starts at 0.1.
            rep = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0,
                            label="对重复生成的惩罚", interactive=True)
            max_len = gr.Slider(minimum=4, maximum=512, step=4, value=256,
                                label="对方回答的最大长度", interactive=True)
            top_p = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.7,
                              label="Top-p(调大则更随机)", interactive=True)
            top_k = gr.Slider(minimum=0, maximum=100, step=1, value=50,
                              label="Top-k(调大则更随机)", interactive=True)
        textbox = gr.Textbox(interactive=True, label="全部文本(可修改)")
        clear.add([msg, role_name, textbox])
        sub.click(fn=respond,
                  inputs=[role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k],
                  outputs=[msg, textbox])

demo.queue().launch(server_port=6006)