File size: 4,835 Bytes
ae25e3a
47b54c6
ae25e3a
47b54c6
ae25e3a
057dc4f
310cea3
 
a78bf18
d498a70
a78bf18
47b54c6
 
310cea3
324a277
 
47b54c6
 
 
a78bf18
47b54c6
00b02de
47b54c6
00b02de
47b54c6
 
 
ae25e3a
47b54c6
ae25e3a
47b54c6
 
ae25e3a
 
 
 
 
 
 
 
 
47b54c6
 
ae25e3a
47b54c6
 
ae25e3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47b54c6
 
ae25e3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47b54c6
ae25e3a
47b54c6
ae25e3a
 
 
47b54c6
ae25e3a
47b54c6
ae25e3a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html

# Load the tokenizer and model from the local "models/chatglm-6b-int4" directory.
# trust_remote_code=True is passed to both loaders (the checkpoint ships its own
# model code); the model is then converted with .float() and moved with .cuda().
tokenizer = AutoTokenizer.from_pretrained("models/chatglm-6b-int4", trust_remote_code=True, revision="")
model = AutoModel.from_pretrained("models/chatglm-6b-int4", trust_remote_code=True, revision="").float().cuda()
# tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
# model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
# chatglm-6b-int4 on CUDA; runs successfully on a local machine
# tokenizer = AutoTokenizer.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="")
# model = AutoModel.from_pretrained(".\\models\\chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()


# chatglm-6b-int4 on CPU
# tokenizer = AutoTokenizer.from_pretrained("models/chatglm-6b-int4", trust_remote_code=True, revision="")
# model = AutoModel.from_pretrained("models/chatglm-6b-int4", trust_remote_code=True, revision="").float()



# chatglm-6b
# kernel_file = "./models/chatglm-6b-int4/quantization_kernels.so"
# tokenizer = AutoTokenizer.from_pretrained("./models/chatglm-6b-int4", trust_remote_code=True, revision="")
# model = AutoModel.from_pretrained("./models/chatglm-6b-int4", trust_remote_code=True, revision="").half().cuda()
# model = AutoModel.from_pretrained("./models/chatglm-6b-int4", trust_remote_code=True, revision="").float()



# Put the model into evaluation mode before serving requests.
model = model.eval()

"""Override Chatbot.postprocess"""


def postprocess(self, y):
    """Convert every chat pair in ``y`` from Markdown/TeX to HTML.

    Args:
        y: A list of ``(message, response)`` pairs, or ``None``.

    Returns:
        An empty list when ``y`` is ``None``; otherwise ``y`` itself, with each
        pair replaced in place by a tuple of ``mdtex2html.convert`` results.
        ``None`` entries inside a pair are kept as ``None``.
    """
    if y is None:
        return []

    def _render(text):
        # Leave missing sides of a pair untouched.
        if text is None:
            return None
        return mdtex2html.convert(text)

    for idx, pair in enumerate(y):
        user_msg, bot_msg = pair
        y[idx] = (_render(user_msg), _render(bot_msg))
    return y


# Replace gr.Chatbot.postprocess with the custom postprocess function defined in this file.
gr.Chatbot.postprocess = postprocess


# Characters escaped on lines inside a fenced code block, applied in one pass.
# None of the replacement strings contains another key, so a single translate
# gives the same result as applying the replacements one after another.
_CODE_LINE_ESCAPES = str.maketrans({
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Turn chat text with ``` fences into an HTML fragment.

    Adapted from https://github.com/GaiZhenbiao/ChuanhuChatGPT/

    Empty lines are dropped. A line containing ``` toggles a code block: an
    opening fence becomes ``<pre><code class="language-X">`` where ``X`` is the
    text after the last backtick on that line, and a closing fence becomes
    ``<br></code></pre>``. Every other line except the very first is prefixed
    with ``<br>``; lines inside a code block are additionally escaped so that
    later Markdown/HTML processing does not interpret them.

    Args:
        text: The raw text to convert.

    Returns:
        The converted lines concatenated without separators.
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0  # odd while inside a code block
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                language = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            # The first line is intentionally left as-is (no <br>, no escaping).
            if fence_count % 2 == 1:
                line = line.translate(_CODE_LINE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)


def predict(input, chatbot, max_length, top_p, temperature, history):
    """Stream the model's reply to ``input`` into the chat display.

    A new ``(user, reply)`` pair is appended to ``chatbot``; for every partial
    response produced by ``model.stream_chat`` that last pair is overwritten
    and the updated state is yielded.

    Args:
        input: Raw user text.
        chatbot: List of ``(message, response)`` pairs; mutated in place.
        max_length: Forwarded to ``stream_chat`` as ``max_length``.
        top_p: Forwarded to ``stream_chat`` as ``top_p``.
        temperature: Forwarded to ``stream_chat`` as ``temperature``.
        history: Conversation history passed to ``stream_chat``.

    Yields:
        ``(chatbot, history)`` after each partial response, where ``history``
        is the value returned by ``stream_chat`` for that step.
    """
    # The rendered user message is the same for every streamed chunk, so
    # convert it once instead of on each iteration.
    rendered_input = parse_text(input)
    chatbot.append((rendered_input, ""))
    for response, history in model.stream_chat(
        tokenizer,
        input,
        history,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
    ):
        chatbot[-1] = (rendered_input, parse_text(response))

        yield chatbot, history


def reset_user_input():
    """Return a Gradio update that sets the target component's value to an empty string."""
    cleared = gr.update(value="")
    return cleared


def reset_state():
    """Return two fresh empty lists: one for the chat display, one for the history."""
    empty_chat, empty_history = [], []
    return empty_chat, empty_history


# Build the UI: a title, the chat display, an input area with a submit button,
# and a side column holding the clear button and the generation sliders.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ChatGLM</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        # Wider column (scale=4): prompt textbox and the submit button.
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        # Narrower column (scale=1): history reset and the values passed to predict.
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    # Conversation history, fed into predict and updated from its output.
    history = gr.State([])

    # Two handlers run on submit: predict streams the reply into the chatbot and
    # history, and reset_user_input clears the textbox.
    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    # Clearing resets both the visible chat and the stored history.
    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

# Enable the request queue and start the app with share=False and inbrowser=True.
# NOTE(review): queue() is presumably needed because predict is a generator —
# confirm against the installed gradio version.
demo.queue().launch(share=False, inbrowser=True)