Spaces:
Running
Running
import gradio as gr | |
import json | |
import math | |
from backend import get_message_single, get_message_spam, send_single, send_spam, tokenizer | |
from defaults import ( | |
ADDRESS_BETTERTRANSFORMER, | |
ADDRESS_VANILLA, | |
defaults_bt_single, | |
defaults_bt_spam, | |
defaults_vanilla_single, | |
defaults_vanilla_spam, | |
BATCH_SIZE, | |
) | |
import datasets | |
import torch | |
def dispatch_single(input_model_single, address_input_vanilla, address_input_bettertransformer):
    """Send one text to the vanilla server, then to the BetterTransformer server.

    Args:
        input_model_single: The text to classify.
        address_input_vanilla: Address of the vanilla Transformers server.
        address_input_bettertransformer: Address of the BetterTransformer server.

    Returns:
        A 2-tuple ``(result_vanilla, result_bettertransformer)`` holding whatever
        ``send_single`` returned for each server.
    """
    targets = (address_input_vanilla, address_input_bettertransformer)
    # The generator is consumed in order, so the vanilla server is queried first.
    return tuple(send_single(input_model_single, target) for target in targets)
def dispatch_spam(input_n_spam, address_input_vanilla, address_input_bettertransformer):
    """Send the same random sample of ``data`` to both servers as a burst of requests.

    Args:
        input_n_spam: Number of samples to send; converted with ``int()``.
        address_input_vanilla: Address of the vanilla Transformers server.
        address_input_bettertransformer: Address of the BetterTransformer server.

    Returns:
        ``(result_vanilla, result_bettertransformer)``, the two ``send_spam`` results.

    Raises:
        ValueError: If ``input_n_spam`` is negative or larger than ``len(data)``.
    """
    # NOTE(review): ``data`` is not defined anywhere in this file's visible source,
    # so calling this function as-is raises NameError. Presumably a module-level
    # dataset (the ``.shuffle().select()`` calls match the ``datasets`` API) was
    # meant to be created here -- TODO confirm.
    input_n_spam = int(input_n_spam)

    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    n_available = len(data)
    if not 0 <= input_n_spam <= n_available:
        raise ValueError(
            f"input_n_spam must be between 0 and {n_available}, got {input_n_spam}"
        )

    # Draw the sample once so both servers receive exactly the same inputs.
    inp = data.shuffle().select(range(input_n_spam))
    result_vanilla = send_spam(inp, address_input_vanilla)
    result_bettertransformer = send_spam(inp, address_input_bettertransformer)
    return result_vanilla, result_bettertransformer
def _build_artificial_request(sequence_length, padding_ratio, vocab_size):
    """Return one JSON-encoded, pre-tokenized request of random tokens with right padding.

    Position 0 holds id 101 and the last non-padding position holds id 102
    (presumably BERT's [CLS]/[SEP] ids -- TODO confirm against the served
    tokenizer). The other real positions hold random ids in ``[1, vocab_size - 1]``.
    Padding positions hold id 0 and have ``attention_mask`` 0.

    Args:
        sequence_length: Total number of tokens, padding included.
        padding_ratio: Fraction of ``sequence_length`` to pad. Clamped so that
            the two special tokens always stay unpadded.
        vocab_size: Size of the vocabulary to draw random ids from.

    Returns:
        A JSON string with keys ``input_ids``, ``attention_mask`` and
        ``pre_tokenized`` (always ``True``).

    Raises:
        ValueError: If ``sequence_length`` is smaller than 2.
    """
    if sequence_length < 2:
        raise ValueError(
            f"sequence_length must be at least 2 to hold the two special tokens, "
            f"got {sequence_length}"
        )

    # Clamp to [0, sequence_length - 2]. Without this, ratios close to 1 either
    # raised IndexError or overwrote the leading 101 token with 102.
    n_pads = min(max(int(padding_ratio * sequence_length), 0), sequence_length - 2)
    n_real = sequence_length - n_pads

    # Random ids in [1, vocab_size - 1]; id 0 is kept for padding.
    input_ids = torch.randint(vocab_size - 1, (sequence_length,)) + 1
    input_ids[n_real:] = 0
    input_ids[0] = 101
    input_ids[n_real - 1] = 102

    # Positive indexing, so n_pads == 0 marks every token as real.
    # A negative slice ``[:-0]`` would select nothing.
    attention_mask = torch.zeros((sequence_length,), dtype=torch.int64)
    attention_mask[:n_real] = 1

    return json.dumps({
        "input_ids": input_ids.tolist(),
        "attention_mask": attention_mask.tolist(),
        "pre_tokenized": True,
    })


def dispatch_spam_artif(input_n_spam_artif, sequence_length, padding_ratio, address_input_vanilla, address_input_bettertransformer):
    """Send a burst of identical artificial, pre-tokenized requests to both servers.

    Args:
        input_n_spam_artif: Number of requests in the burst; converted with ``int()``.
        sequence_length: Tokens per request, padding included; converted with
            ``int()``. Must be at least 2.
        padding_ratio: Fraction of ``sequence_length`` filled with padding. Clamped
            so the two special tokens are never padded away.
        address_input_vanilla: Address of the vanilla Transformers server.
        address_input_bettertransformer: Address of the BetterTransformer server.

    Returns:
        ``(result_vanilla, result_bettertransformer)``, the two ``send_spam`` results.

    Raises:
        ValueError: If ``sequence_length`` is smaller than 2.
    """
    sequence_length = int(sequence_length)
    input_n_spam_artif = int(input_n_spam_artif)

    str_input = _build_artificial_request(
        sequence_length, padding_ratio, tokenizer.vocab_size
    )
    # Every request in the burst carries the same payload, so both servers
    # see the same workload.
    input_dataset = datasets.Dataset.from_dict(
        {"sentence": [str_input] * input_n_spam_artif}
    )
    result_vanilla = send_spam(input_dataset, address_input_vanilla)
    result_bettertransformer = send_spam(input_dataset, address_input_bettertransformer)
    return result_vanilla, result_bettertransformer
# HTML for the header banner shown at the top of the page.
TITLE_IMAGE = """
<div
style="
display: block;
margin-left: auto;
margin-right: auto;
width: 50%;
"
>
<img src="https://huggingface.co/spaces/fxmarty/bettertransformer-demo/resolve/main/header.webp" alt="BetterTransformer demo header"/>
</div>
"""
# Backward-compatible alias for the original, misspelled name.
TTILE_IMAGE = TITLE_IMAGE

# HTML for the page title shown under the banner.
TITLE = """
<div
style="
display: inline-flex;
align-items: center;
text-align: center;
max-width: 1400px;
gap: 0.8rem;
font-size: 2.2rem;
"
>
<h1 style="font-weight: 500; margin-bottom: 10px; margin-top: 10px;">
Speed up your inference and support more workload with PyTorch's BetterTransformer 🤗
</h1>
</div>
"""
with gr.Blocks() as demo:
    gr.HTML(TTILE_IMAGE)
    gr.HTML(TITLE)

    # Intro text. Blank lines separate the Markdown paragraphs; without them the
    # intro renders as one run-on paragraph.
    gr.Markdown(
        """
Let's try out TorchServe + BetterTransformer!

BetterTransformer is a stable feature made available with [PyTorch 1.13](https://pytorch.org/blog/PyTorch-1.13-release/) which enables a fastpath execution for encoder attention blocks.

As a one-liner, you can convert your 🤗 Transformers models to use BetterTransformer thanks to the [🤗 Optimum](https://huggingface.co/docs/optimum/main/en/index) library:

```
from optimum.bettertransformer import BetterTransformer
better_model = BetterTransformer.transform(model)
```

This Space is a demo of an **end-to-end** deployment of PyTorch eager-mode models, both with and without BetterTransformer. The goal is to see the server-side and client-side benefits of using BetterTransformer.

## Inference using...
"""
    )

    # Column headers for the side-by-side comparison.
    with gr.Row():
        with gr.Column(scale=50):
            gr.Markdown("### Vanilla Transformers + TorchServe")
        with gr.Column(scale=50):
            gr.Markdown("### BetterTransformer + TorchServe")

    # Hidden textboxes. They are never displayed (visible=False); they only
    # pass the server addresses as inputs to the click handlers.
    address_input_vanilla = gr.Textbox(
        max_lines=1, label="ip vanilla", value=ADDRESS_VANILLA, visible=False
    )
    address_input_bettertransformer = gr.Textbox(
        max_lines=1,
        label="ip bettertransformer",
        value=ADDRESS_BETTERTRANSFORMER,
        visible=False,
    )

    # --- Single text request ---
    input_model_single = gr.Textbox(
        max_lines=1,
        label="Text",
        value="Expectations were low, enjoyment was high",
    )
    btn_single = gr.Button("Send single text request")
    with gr.Row():
        with gr.Column(scale=50):
            output_single_vanilla = gr.Markdown(
                label="Output single vanilla",
                value=get_message_single(**defaults_vanilla_single),
            )
        with gr.Column(scale=50):
            output_single_bt = gr.Markdown(
                label="Output single bt", value=get_message_single(**defaults_bt_single)
            )
    btn_single.click(
        fn=dispatch_single,
        inputs=[input_model_single, address_input_vanilla, address_input_bettertransformer],
        outputs=[output_single_vanilla, output_single_bt],
    )

    # --- Burst of artificial, pre-tokenized requests ---
    input_n_spam_artif = gr.Number(
        label="Number of inputs to send",
        value=8,
    )
    sequence_length = gr.Number(
        label="Sequence length (in tokens)",
        value=128,
    )
    padding_ratio = gr.Number(
        label="Padding ratio",
        value=0.5,
    )
    btn_spam_artif = gr.Button(
        "Spam text requests (using artificial data)"
    )
    with gr.Row():
        with gr.Column(scale=50):
            output_spam_vanilla_artif = gr.Markdown(
                label="Output spam vanilla",
                value=get_message_spam(**defaults_vanilla_spam),
            )
        with gr.Column(scale=50):
            output_spam_bt_artif = gr.Markdown(
                label="Output spam bt", value=get_message_spam(**defaults_bt_spam)
            )
    btn_spam_artif.click(
        fn=dispatch_spam_artif,
        inputs=[input_n_spam_artif, sequence_length, padding_ratio, address_input_vanilla, address_input_bettertransformer],
        outputs=[output_spam_vanilla_artif, output_spam_bt_artif],
    )
# Enable the request queue with a single worker, so events are processed one
# at a time. Presumably this keeps concurrent users' benchmark bursts from
# overlapping -- TODO confirm intent.
demo.queue(concurrency_count=1)

# Start the (blocking) server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()