File size: 2,812 Bytes
a7d37a3
a6636f6
c2769e9
a6636f6
a7d37a3
c2769e9
 
b2dcf43
a6636f6
 
 
c2769e9
a6636f6
 
 
 
c2769e9
a6636f6
 
 
 
c2769e9
a6636f6
 
 
 
c2769e9
 
a6636f6
 
 
 
 
 
 
 
 
 
 
 
c2769e9
a6636f6
 
a7d37a3
a6636f6
 
 
a7d37a3
 
 
 
 
 
 
 
 
c2769e9
 
 
 
 
 
a6636f6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
from threading import Thread
import gradio as gr
import torch

# Resolve the target device first so the load precision can depend on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Test ``device.type`` rather than ``device == "cuda"``: a ``torch.device``
# does not compare equal to a plain string, so the string comparison was
# always False and the half-precision branch below never executed.
_use_cuda = device.type == "cuda"
# Half precision on GPU, full precision on CPU.
_load_dtype = torch.float16 if _use_cuda else torch.float32

# Adapter configuration for the LoRA weights published on the Hub.
config = PeftConfig.from_pretrained("Junity/Genshin-World-Model", trust_remote_code=True)
# Base causal LM; ``device_map="auto"`` lets the loader decide weight placement.
model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan-13B-Base",
    torch_dtype=_load_dtype,
    device_map="auto",
    trust_remote_code=True,
)
# Wrap the base model with the PEFT adapter weights.
model = PeftModel.from_pretrained(
    model,
    r"Junity/Genshin-World-Model",
    torch_dtype=_load_dtype,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan-13B-Base", trust_remote_code=True)
history = []
if _use_cuda:
    # Placement is already handled by ``device_map="auto"``, so no explicit
    # ``model.cuda()`` here (it would collapse a multi-device placement).
    # The cast makes sure the adapter weights match the fp16 base weights.
    model = model.half()

def respond(role_name, msg, textbox):
    """Append the user's line to the transcript and stream the model's continuation.

    Generator intended as a Gradio event handler. Every yielded value is a
    two-item list ``[input_value, transcript]``; the first item is always
    ``""`` and the second is the transcript accumulated so far.

    Args:
        role_name: Speaker name placed before the message as ``role_name:msg``.
        msg: The message text. A ``。`` is appended unless the message already
            ends with terminal punctuation. An empty message is ignored.
        textbox: Transcript so far; ``None`` is treated as an empty transcript.

    Yields:
        ``["", transcript]`` once after the user's line is appended, then once
        per text chunk produced by the model.
    """
    # NOTE(review): a cleared Textbox may hand back None instead of "" —
    # normalise so the concatenation below cannot raise TypeError.
    textbox = textbox or ''

    if not msg:
        # Nothing was typed. The previous implementation indexed ``msg[-1]``
        # and raised IndexError here; leave the transcript as it is instead.
        yield ["", textbox]
        return

    # Accept both ASCII and full-width marks as "already punctuated" so a
    # message typed with a Chinese IME does not get a stray trailing '。'.
    closing = '' if msg.endswith(('。', '!', '?', '!', '?')) else '。'
    # A non-empty transcript gets a separating newline before the new turn.
    separator = "\n" if textbox != '' else ''
    textbox = textbox + separator + role_name + ":" + msg + closing + '\n'
    yield ["", textbox]

    # Keep only the most recent 4096 tokens of the transcript as the prompt.
    input_ids = tokenizer.encode(textbox)[-4096:]
    input_ids = torch.LongTensor([input_ids]).to(device)

    # skip_prompt=True: only newly generated text is streamed back.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    # NOTE(review): temperature/top_p only take effect when sampling is
    # enabled; confirm the model's generation config sets do_sample.
    gen_kwargs = dict(
        input_ids=input_ids,
        temperature=1.0,
        top_p=0.75,
        repetition_penalty=1.2,
        max_new_tokens=256,
        streamer=streamer,
    )

    # generate() blocks until done, so run it on a worker thread and consume
    # the streamer from this generator.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    for new_text in streamer:
        textbox += new_text
        yield ["", textbox]

    # The stream is exhausted, so generation has finished; reap the worker.
    thread.join()

# Build the Gradio UI and start the web server.
with gr.Blocks() as demo:
    # Header text (Chinese): model link, plus a note that the model only
    # continues text and cannot be instructed to answer.
    gr.Markdown(
        """
        ## Genshin-World-Model
        - 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
        - 此模型不支持要求对方回答什么,只支持续写。
        """
    )
    # "聊天" (chat) tab holding the two user inputs: the role the user plays
    # and the message to send.
    with gr.Tab("聊天") as chat:
        role_name = gr.Textbox(label="你将扮演的角色")
        msg = gr.Textbox(label="输入")
    # The button row and the transcript box are created outside the Tab context.
    with gr.Row():
        clear = gr.Button("Clear")
        sub = gr.Button("Submit")
    # Read-only transcript; also fed back into respond() as the prompt.
    textbox = gr.Textbox(interactive=False)
    # respond() yields [input_value, transcript] pairs: the first element
    # resets the message box to "", the second updates the transcript.
    sub.click(fn=respond, inputs=[role_name, msg, textbox], outputs=[msg, textbox])
    # Clear resets the transcript by returning None; this event bypasses the queue.
    clear.click(lambda: None, None, textbox, queue=False)
    # Enable the request queue and serve on port 6006. Note this runs at
    # import time and is invoked while still inside the Blocks context.
    demo.queue().launch(server_port=6006)