File size: 2,528 Bytes
3cd2d35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
from typing import cast
import gradio as gr
import torch
from transformers import BertTokenizerFast, ErnieForCausalLM
def load_model():
    """Load the tokenizer and the causal language model used by the demo.

    Both objects are created with ``from_pretrained`` from the same
    ``wybxc/new-yiri`` checkpoint identifier.

    Returns:
        A ``(tokenizer, model)`` tuple holding a ``BertTokenizerFast`` and an
        ``ErnieForCausalLM`` instance.
    """
    checkpoint = "wybxc/new-yiri"

    tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
    # The assertion narrows the type for static checkers and fails fast if
    # the checkpoint resolves to a different tokenizer class.
    assert isinstance(tokenizer, BertTokenizerFast)

    model = ErnieForCausalLM.from_pretrained(checkpoint)
    assert isinstance(model, ErnieForCausalLM)

    return tokenizer, model
def generate(
    tokenizer: BertTokenizerFast,
    model: ErnieForCausalLM,
    input_str: str,
    alpha: float,
    topk: int,
):
    """Generate a reply to ``input_str`` and return it as plain text.

    Args:
        tokenizer: Tokenizer used to encode the prompt and decode the reply.
        model: Causal language model whose ``generate`` method is called.
        input_str: The user's message, used as the prompt.
        alpha: Forwarded to ``model.generate`` as ``penalty_alpha``.
        topk: Forwarded to ``model.generate`` as ``top_k``.

    Returns:
        The decoded continuation with special tokens skipped and every space
        character removed.
    """
    input_ids = tokenizer.encode(input_str, return_tensors="pt")
    input_ids = cast(torch.Tensor, input_ids)
    outputs = model.generate(
        input_ids,
        max_new_tokens=100,
        penalty_alpha=alpha,
        top_k=topk,
        # NOTE(review): kept for parity with the original call; believed to
        # matter only for beam search -- confirm against the transformers docs.
        early_stopping=True,
        decoder_start_token_id=tokenizer.sep_token_id,
        # Generation stops once the separator token is produced.
        eos_token_id=tokenizer.sep_token_id,
    )
    # The returned sequence begins with the prompt tokens, so the reply is
    # everything after them. Slicing by the known prompt length (instead of
    # searching for the first separator id) stays correct when the user's
    # text itself contains the separator token, and cannot fail with an
    # unpacking error when no separator is found.
    prompt_length = input_ids.shape[-1]
    reply_ids = outputs[0, prompt_length:]
    output = tokenizer.decode(
        reply_ids,
        skip_special_tokens=True,
    ).replace(" ", "")  # strip the spaces that decode() puts between tokens
    return output
with gr.Blocks() as demo:
    # Layout: a wide column with the chat window and the input row, next to
    # a narrow column holding the two generation sliders.
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot().style(height=500)
            with gr.Row():
                with gr.Column(scale=4):
                    # cast() only informs type checkers that the styled
                    # component is still a Textbox.
                    msg = cast(
                        gr.Textbox,
                        gr.Textbox(
                            show_label=False,
                            placeholder="Enter text and press enter",
                        ).style(container=False),
                    )
                with gr.Column(scale=1):
                    button = gr.Button("Generate")
                with gr.Column(scale=1):
                    clear = gr.Button("Clear")
        with gr.Column(scale=1):
            alpha = gr.Slider(
                minimum=0, maximum=1, value=0.5, step=0.01, label="Penalty Alpha"
            )
            topk = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Top K")

    # Loaded once at start-up and shared by every request handler call.
    tokenizer, model = load_model()

    def on_message(
        user_message: str, history: list[list[str]], alpha: float, topk: int
    ):
        """Answer ``user_message`` and append the exchange to the chat history.

        Returns:
            A pair ``("", new_history)``: the empty string clears the input
            box, and ``new_history`` is a fresh list made of ``history``
            followed by ``[user_message, bot_message]``.
        """
        reply = generate(tokenizer, model, user_message, alpha=alpha, topk=topk)
        updated_history = list(history)
        updated_history.append([user_message, reply])
        return "", updated_history

    # Pressing Enter in the textbox and clicking "Generate" run the same
    # handler with the same inputs and outputs.
    _chat_inputs = [msg, chatbot, alpha, topk]
    _chat_outputs = [msg, chatbot]
    for _register in (msg.submit, button.click):
        _register(on_message, inputs=_chat_inputs, outputs=_chat_outputs)

    # "Clear" resets the chat window by sending None to the chatbot.
    clear.click(lambda: None, None, chatbot)
if __name__ == "__main__":
    # Turn on the request queue (concurrency_count=3), then start the app.
    demo.queue(concurrency_count=3).launch()
|