import subprocess

# Install flash-attn at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling the
# CUDA kernels during `pip install`, which would otherwise require a full CUDA build here.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

from threading import Thread
import spaces
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer
)

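# Load LongWriter-glm4-9b with its custom GLM modeling code (trust_remote_code=True);
# device_map='auto' lets Accelerate place the weights on the available device(s).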
model = AutoModelForCausalLM.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True)


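# Stop generation as soon as the last generated token is one of the model's EOS ids
# (model.config.eos_token_id is a list of ids for this GLM-4-based model).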
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = model.config.eos_token_id
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


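# Streaming chat handler. The @spaces.GPU decorator reserves a ZeroGPU slot for up to
# 280 seconds per call, which bounds how long a single generation may run.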
@spaces.GPU(duration=280)
def predict(history, prompt, max_length, top_p, temperature):
    stop = StopOnTokens()
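    # Rebuild the message list from the Gradio chat history; the final user turn
    # (which has no model reply yet) becomes the query passed to build_chat_input.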
    messages = []
    if prompt:
        messages.append({"role": "system", "content": prompt})
    for idx, (user_msg, model_msg) in enumerate(history):
        if prompt and idx == 0:
            continue
        if idx == len(history) - 1 and not model_msg:
            query = user_msg
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

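    # Apply the GLM chat template and move the input ids onto the model's device.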
    model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to(
        next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True)
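    # Besides the regular EOS token, treat the <|user|> and <|observation|> role tokens
    # as end-of-turn markers so generation stops when the model starts a new turn.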
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>")]
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1,
        "eos_token_id": eos_token_id,
    }
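    # Run generation in a background thread and stream partial output back to the UI,
    # filtering out the <|user|> token so it does not leak into the transcript.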
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    for new_token in streamer:
        if new_token and '<|user|>' not in new_token:
            history[-1][1] += new_token
        yield history


with gr.Blocks() as demo:
    gr.Markdown(
        """
        <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
            LongWriter-glm4-9b Hugging Face Space 🤗
        </div>
        <div style="text-align: center;">
            <a href="https://huggingface.co/THUDM/LongWriter-glm4-9b">🤗 Model Hub</a> |
            <a href="https://github.com/THUDM/LongWriter">🌐 Github</a> |
            <a href="https://arxiv.org/pdf/2408.07055">📜 arXiv</a>
        </div>
        <div style="text-align: center; font-size: 15px; font-weight: bold; margin-bottom: 20px; line-height: 1.5;">
    <div style="color: black;">
        ⚠️ Due to the limitations of Hugging Face ZeroGPU, generating 5K characters in one go
        requires reserving a 4-5 minute GPU quota per request,
        so you will only be able to use this Space once every 4 hours.
    </div>
    <br>
    <div style="color: red;">
        ⚠️ After 4-5 minutes the request will time out, regardless of whether the output is complete.
        This is not caused by the model.<br>
        If you plan to use it long-term, please consider deploying the model yourself or forking this Space.
    </div>
</div>
        """
    )
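    # Chat display plus input widgets: message box, optional system prompt, and sampling controls.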
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input... (Example: Write a 10000-word China travel guide)",
                                        lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
            pBtn = gr.Button("Set Prompt")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 128000, value=10240, step=1.0, label="Maximum new tokens (output)",
                                   interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)


    def user(query, history):
        return "", history + [[query, ""]]


    def set_prompt(prompt_text):
        return [[prompt_text, "Set prompt successfully"]]


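    # Wire up the UI events: "Set Prompt" seeds the chat with a system prompt, "Submit"
    # appends the user turn and streams the model reply, "Clear History" resets both panes.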
    pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)

    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
    )
    emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)

demo.queue()
demo.launch()