File size: 2,528 Bytes
3cd2d35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from typing import cast

import gradio as gr
import torch
from transformers import BertTokenizerFast, ErnieForCausalLM


def load_model(model_name: str = "wybxc/new-yiri"):
    """Load the tokenizer and causal language model for the chat demo.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint.
            Both the tokenizer and the model weights are loaded from it.
            Defaults to the checkpoint this demo was written for.

    Returns:
        A ``(tokenizer, model)`` tuple holding a ``BertTokenizerFast`` and
        an ``ErnieForCausalLM``.
    """
    tokenizer = BertTokenizerFast.from_pretrained(model_name)
    # The isinstance asserts narrow the loosely typed ``from_pretrained``
    # results for static type checkers.
    assert isinstance(tokenizer, BertTokenizerFast)
    model = ErnieForCausalLM.from_pretrained(model_name)
    assert isinstance(model, ErnieForCausalLM)

    return tokenizer, model


def generate(
    tokenizer: BertTokenizerFast,
    model: ErnieForCausalLM,
    input_str: str,
    alpha: float,
    topk: int,
):
    """Generate the bot's reply to a single user message.

    Args:
        tokenizer: Tokenizer used to encode the prompt and decode the reply.
        model: Causal language model that produces the continuation.
        input_str: The user's message.
        alpha: Forwarded to ``model.generate`` as ``penalty_alpha``.
        topk: Forwarded to ``model.generate`` as ``top_k``. Coerced to
            ``int`` because UI sliders may deliver the value as a float.

    Returns:
        The decoded continuation with special tokens skipped and all spaces
        removed.
    """
    input_ids = tokenizer.encode(input_str, return_tensors="pt")
    input_ids = cast(torch.Tensor, input_ids)
    outputs = model.generate(
        input_ids,
        max_new_tokens=100,
        penalty_alpha=alpha,
        top_k=int(topk),
        early_stopping=True,
        decoder_start_token_id=tokenizer.sep_token_id,
        eos_token_id=tokenizer.sep_token_id,
    )
    # The returned sequence starts with the prompt, so the reply begins right
    # after it. Slicing by prompt length (rather than searching for the first
    # [SEP]) stays correct when the user's text itself contains a [SEP] token
    # and cannot fail when no [SEP] is present at all.
    prompt_length = input_ids.shape[-1]
    output = tokenizer.decode(
        outputs[0, prompt_length:],
        skip_special_tokens=True,
    ).replace(" ", "")  # presumably undoes the spaces decode puts between tokens
    return output


with gr.Blocks() as demo:
    # Layout: a wide chat column (history, input box, buttons) on the left and
    # a narrow column with the generation settings on the right.
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot().style(height=500)
            with gr.Row():
                with gr.Column(scale=4):
                    # cast() only narrows the type for static checkers.
                    msg = cast(
                        gr.Textbox,
                        gr.Textbox(
                            show_label=False,
                            placeholder="Enter text and press enter",
                        ).style(container=False),
                    )
                with gr.Column(scale=1):
                    button = gr.Button("Generate")
                with gr.Column(scale=1):
                    clear = gr.Button("Clear")
        with gr.Column(scale=1):
            alpha = gr.Slider(0, 1, 0.5, step=0.01, label="Penalty Alpha")
            topk = gr.Slider(1, 50, 5, step=1, label="Top K")

    # Loaded once while the UI is being built; the handler below closes over
    # these objects and reuses them for every request.
    tokenizer, model = load_model()

    def on_message(
        user_message: str, history: list[list[str]], alpha: float, topk: int
    ):
        """Answer ``user_message`` and append the exchange to the chat history.

        Returns an empty string (which clears the textbox) together with a new
        history list ending in the ``[user_message, bot_message]`` pair.
        """
        reply = generate(tokenizer, model, user_message, alpha=alpha, topk=topk)
        updated_history = list(history)
        updated_history.append([user_message, reply])
        return "", updated_history

    # Pressing Enter in the textbox and clicking "Generate" run the same handler.
    msg.submit(
        on_message,
        inputs=[msg, chatbot, alpha, topk],
        outputs=[msg, chatbot],
    )
    button.click(
        on_message,
        inputs=[msg, chatbot, alpha, topk],
        outputs=[msg, chatbot],
    )

    # "Clear" resets the chat history by sending None to the chatbot component.
    clear.click(lambda: None, None, chatbot)

if __name__ == "__main__":
    # Turn on request queuing (concurrency_count=3), then start the app.
    demo.queue(concurrency_count=3).launch()