from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import gradio as gr
import torch
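
# Gradio demo for Junity/Genshin-World-Model, a LoRA on top of baichuan-inc/Baichuan-13B-Base
# that only continues the running text (it is not an instruction-following chat model).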


# Optional local paths; leave them empty to download the weights from the Hugging Face Hub.
lora_folder = ''
model_folder = ''

config = PeftConfig.from_pretrained(("Junity/Genshin-World-Model" if lora_folder == ''
                                     else lora_folder),
                                    trust_remote_code=True)
# Load the Baichuan-13B base model in fp16, then attach the Genshin-World-Model LoRA adapter.
model = AutoModelForCausalLM.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
                                              else model_folder),
                                             torch_dtype=torch.float16,
                                             device_map="auto",
                                             trust_remote_code=True)
model = PeftModel.from_pretrained(model,
                                  ("Junity/Genshin-World-Model" if lora_folder == ''
                                   else lora_folder),
                                  device_map="auto",
                                  torch_dtype=torch.float32,
                                  trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
                                           else model_folder),
                                          trust_remote_code=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def respond(role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k):
    # Append the user's new line to the running text, then stream the model's continuation.
    if msg == '':
        yield ["", textbox]  # nothing new to append; continue from the existing text
    elif textbox != '':
        textbox = (textbox
                   + "\n"
                   + role_name
                   + (":" if role_name != '' else '')
                   + msg
                   + ('。\n' if msg[-1] not in ['。', '!', '?'] else ''))
        yield ["", textbox]
    else:
        textbox = (textbox
                   + role_name
                   + (":" if role_name != '' else '')
                   + msg
                   + ('。' if msg[-1] not in ['。', '!', '?', ')', '}', ':', ':', '('] else '')
                   + ('\n' if msg[-1] in ['。', '!', '?', ')', '}'] else ''))
        yield ["", textbox]
    if character_name != '':
        # Prime the model to continue in the other character's voice.
        textbox += ('\n' if textbox != '' and textbox[-1] != '\n' else '') + character_name + ':'
    # Keep only the last 3200 tokens so the prompt stays within the model's context window.
    input_ids = tokenizer.encode(textbox)[-3200:]
    input_ids = torch.LongTensor([input_ids]).to(device)
    # Sampling parameters come straight from the UI sliders.
    gen_kwargs = {}
    gen_kwargs.update(dict(
        input_ids=input_ids,
        temperature=temp,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=rep,
        max_new_tokens=max_len,
        do_sample=True,
    ))
    # Stream decoded tokens back to the UI as they are produced.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs["streamer"] = streamer

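    # model.generate blocks until generation completes, so run it in a background thread
    # and iterate over the streamer here to push partial text to the UI as it arrives.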
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    for new_text in streamer:
        textbox += new_text
        yield ["", textbox]

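# A minimal, hypothetical console check of respond() (values are illustrative only,
# and it assumes the model and tokenizer above loaded successfully):
#
#   for _, text in respond("", "派蒙", "今天天气真好。", "", 1.5, 1.0, 64, 0.7, 50):
#       print(text)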

with gr.Blocks() as demo:
    gr.Markdown(
        """
        ## Genshin-World-Model
        - Model: [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
        - This model cannot be asked to answer questions; it only continues the text you give it.
        - It is not running here at the moment, because there is no budget to rent a GPU.
        """
    )
    with gr.Tab("Compose") as chat:
        role_name = gr.Textbox(label="Character you play (optional)")
        character_name = gr.Textbox(label="The other character (optional)")
        msg = gr.Textbox(label="What you say")
    with gr.Row():
        clear = gr.ClearButton()
        sub = gr.Button("Submit", variant="primary")
    with gr.Row():
        temp = gr.Slider(minimum=0, maximum=2.0, step=0.1, value=1.5, label="Temperature (higher = more random)", interactive=True)
        rep = gr.Slider(minimum=0, maximum=2.0, step=0.1, value=1.0, label="Repetition penalty", interactive=True)
        max_len = gr.Slider(minimum=4, maximum=512, step=4, value=256, label="Maximum length of the reply", interactive=True)
        top_p = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.7, label="Top-p (higher = more random)", interactive=True)
        top_k = gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Top-k (higher = more random)", interactive=True)
    textbox = gr.Textbox(interactive=True, label="Full text (editable)")
    clear.add([msg, role_name, textbox])
    sub.click(fn=respond,
              inputs=[role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k],
              outputs=[msg, textbox])
    gr.Markdown(
        """
        #### Special thanks to XXXX
        """
    )

demo.queue().launch()