import json
import datasets
import gradio as gr
import torch
from backend import (get_message_single, get_message_spam, send_single,
send_spam, tokenizer)
from defaults import (ADDRESS_BETTERTRANSFORMER, ADDRESS_VANILLA,
defaults_bt_single, defaults_bt_spam,
defaults_vanilla_single, defaults_vanilla_spam)
def dispatch_single(
    input_model_single, address_input_vanilla, address_input_bettertransformer
):
    """Send one text to both servers and collect their replies.

    The vanilla server is queried first, then the BetterTransformer one.

    Args:
        input_model_single: Text entered by the user.
        address_input_vanilla: Address of the vanilla TorchServe endpoint.
        address_input_bettertransformer: Address of the BetterTransformer
            TorchServe endpoint.

    Returns:
        A 2-tuple ``(vanilla_result, bettertransformer_result)`` holding
        whatever ``send_single`` returns for each server (it is displayed in
        a ``gr.Markdown`` output, so presumably a markdown string -- confirm
        in ``backend``).
    """
    endpoints = (address_input_vanilla, address_input_bettertransformer)
    # Tuple built in endpoint order, so the calls happen vanilla-first.
    return tuple(send_single(input_model_single, endpoint) for endpoint in endpoints)
def _build_artificial_input(sequence_length, padding_ratio, vocab_size):
    """Build one synthetic, right-padded, pre-tokenized sequence.

    Layout: ``[101, random ids..., 102, 0, ..., 0]``. Random ids are drawn
    uniformly from ``[1, vocab_size)``, so they never collide with the ``0``
    padding id. 101/102 are presumably the [CLS]/[SEP] ids of the BERT-uncased
    vocabulary used by the served DistilBERT model -- confirm if the model
    changes.

    Args:
        sequence_length: Total number of tokens, padding included. Must be at
            least 2 so the start and end tokens both fit.
        padding_ratio: Fraction of ``sequence_length`` filled with padding.
            The resulting pad count is clamped to ``[0, sequence_length - 2]``.
        vocab_size: Tokenizer vocabulary size; random ids stay below it.

    Returns:
        ``(input_ids, attention_mask)``: two int lists of length
        ``sequence_length``; the mask is 1 on real tokens, 0 on padding.

    Raises:
        ValueError: if ``sequence_length`` is smaller than 2.
    """
    if sequence_length < 2:
        raise ValueError(
            "sequence_length must be at least 2 to hold the start and end "
            f"tokens, got {sequence_length}"
        )
    # Clamp so a ratio >= 1 (or a very short sequence) can neither push the
    # end-token index out of range nor overwrite the start token.
    n_pads = min(max(int(padding_ratio * sequence_length), 0), sequence_length - 2)
    n_real = sequence_length - n_pads

    # Positive indices on purpose: negative slicing such as ``x[:-n_pads]``
    # silently yields an empty slice when n_pads == 0.
    input_ids = torch.randint(1, vocab_size, (sequence_length,))
    input_ids[n_real:] = 0
    input_ids[0] = 101
    input_ids[n_real - 1] = 102

    attention_mask = torch.zeros((sequence_length,), dtype=torch.int64)
    attention_mask[:n_real] = 1
    return input_ids.tolist(), attention_mask.tolist()


def dispatch_spam_artif(
    input_n_spam_artif,
    sequence_length,
    padding_ratio,
    address_input_vanilla,
    address_input_bettertransformer,
):
    """Send many identical artificial requests to both servers.

    One random pre-tokenized sequence is generated and serialized to JSON
    (with ``"pre_tokenized": True``), and that same payload is repeated
    ``input_n_spam_artif`` times in a ``datasets.Dataset`` column named
    ``"sentence"``. The dataset is then sent to the vanilla server and
    afterwards to the BetterTransformer server.

    Args:
        input_n_spam_artif: Number of requests to send (coerced to ``int``;
            ``gr.Number`` may deliver a float).
        sequence_length: Tokens per sequence, padding included (coerced to
            ``int``, must be >= 2).
        padding_ratio: Fraction of each sequence that is padding; clamped so
            the start/end tokens always fit.
        address_input_vanilla: Address of the vanilla TorchServe endpoint.
        address_input_bettertransformer: Address of the BetterTransformer
            TorchServe endpoint.

    Returns:
        ``(vanilla_result, bettertransformer_result)`` as returned by
        ``send_spam`` (displayed in ``gr.Markdown`` outputs -- presumably
        markdown strings; confirm in ``backend``).

    Raises:
        ValueError: if ``sequence_length`` is smaller than 2.
    """
    sequence_length = int(sequence_length)
    n_requests = int(input_n_spam_artif)

    input_ids, attention_mask = _build_artificial_input(
        sequence_length, padding_ratio, tokenizer.vocab_size
    )
    str_input = json.dumps(
        {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pre_tokenized": True,
        }
    )
    # Every request carries the same payload; strings are immutable, so list
    # repetition is safe (and a negative count still yields an empty list).
    input_dataset = datasets.Dataset.from_dict({"sentence": [str_input] * n_requests})

    result_vanilla = send_spam(input_dataset, address_input_vanilla)
    result_bettertransformer = send_spam(input_dataset, address_input_bettertransformer)
    return result_vanilla, result_bettertransformer
# HTML shown at the very top of the page through gr.HTML. It currently holds
# only a newline, so nothing visible is rendered. (The name is a misspelling
# of "TITLE"; it is kept as-is because the UI code refers to it by this name.)
TTILE_IMAGE = "\n"
with gr.Blocks() as demo:
    # ---- Header: title image, title and introduction -----------------------
    gr.HTML(TTILE_IMAGE)
    gr.Markdown(
        "## Speed up inference and support heavier workloads with PyTorch's BetterTransformer 🤗"
    )
    gr.Markdown(
        """
**The two AWS instances powering this Space are offline (to save us the $$$). Feel free to reproduce using [this backend code](https://github.com/fxmarty/bettertransformer_demo). The example results are from an AWS EC2 g4dn.xlarge instance with a single NVIDIA T4 GPU.**
"""
    )
    # Blank lines separate the paragraphs: in (Common)Markdown, consecutive
    # lines without one are merged into a single paragraph.
    gr.Markdown(
        """
Let's try out [BetterTransformer](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/) + [TorchServe](https://pytorch.org/serve/)!

BetterTransformer is a stable feature made available with [PyTorch 1.13](https://pytorch.org/blog/PyTorch-1.13-release/) that enables a fastpath execution for encoder attention blocks. Depending on your hardware, batch size, sequence length and padding ratio, it can bring large speedups at inference **at no cost in prediction quality**. As a one-liner, you can convert your 🤗 Transformers models to use BetterTransformer thanks to the integration in the [🤗 Optimum](https://github.com/huggingface/optimum) library:

```
from optimum.bettertransformer import BetterTransformer
better_model = BetterTransformer.transform(model)
```

This Space is a demo of an **end-to-end** deployment of PyTorch eager-mode models, both with and without BetterTransformer. The goal is to see what the server-side and client-side benefits of using BetterTransformer are. The model used is [`distilbert-base-uncased-finetuned-sst-2-english`](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english), and TorchServe is parametrized to use a maximum batch size of 8. **Beware:** you may be queued if several people use the Space at the same time.

For more details on the TorchServe implementation and to reproduce, see [this reference code](https://github.com/fxmarty/bettertransformer_demo). For more details on BetterTransformer, check out the [blog post on PyTorch's Medium](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2), and [the Optimum documentation](https://huggingface.co/docs/optimum/bettertransformer/overview)!"""
    )

    # ---- Single input scenario ---------------------------------------------
    gr.Markdown("""## Single input scenario
Note: BetterTransformer normally shines with batch size > 1 and some padding. So this is not the best case here. Check out the heavy workload case below as well.
""")
    # Hidden textboxes holding the two server addresses, so they can be
    # passed as inputs to the click handlers below.
    address_input_vanilla = gr.Textbox(
        max_lines=1, label="ip vanilla", value=ADDRESS_VANILLA, visible=False
    )
    address_input_bettertransformer = gr.Textbox(
        max_lines=1,
        label="ip bettertransformer",
        value=ADDRESS_BETTERTRANSFORMER,
        visible=False,
    )
    input_model_single = gr.Textbox(
        max_lines=1,
        label="Text",
        value="Expectations were low, enjoyment was high. Although the music was not top level, the story was well-paced.",
    )
    btn_single = gr.Button("Send single text request")
    # Side-by-side outputs, pre-filled from the `defaults` module (presumably
    # the recorded example results mentioned in the introduction).
    with gr.Row():
        with gr.Column(scale=50):
            gr.Markdown("### Vanilla Transformers + TorchServe")
            output_single_vanilla = gr.Markdown(
                label="Output single vanilla",
                value=get_message_single(**defaults_vanilla_single),
            )
        with gr.Column(scale=50):
            gr.Markdown("### BetterTransformer + TorchServe")
            output_single_bt = gr.Markdown(
                label="Output single bt", value=get_message_single(**defaults_bt_single)
            )
    btn_single.click(
        fn=dispatch_single,
        inputs=[
            input_model_single,
            address_input_vanilla,
            address_input_bettertransformer,
        ],
        outputs=[output_single_vanilla, output_single_bt],
    )

    # ---- Heavy workload scenario (artificial, pre-tokenized data) ----------
    gr.Markdown(
        """
**Beware that the end-to-end latency can be impacted by a different ping time between the two servers.**
## Heavy workload scenario
"""
    )
    input_n_spam_artif = gr.Number(
        label="Number of sequences to send",
        value=80,
    )
    sequence_length = gr.Number(
        label="Sequence length (in tokens)",
        value=128,
    )
    padding_ratio = gr.Number(
        label="Padding ratio (i.e. how much of the input is padding. In the real world, when batch size > 1, token sequences are padded with 0 so that all inputs have the same length.)",
        value=0.7,
    )
    btn_spam_artif = gr.Button("Spam text requests (using artificial data)")
    with gr.Row():
        with gr.Column(scale=50):
            gr.Markdown("### Vanilla Transformers + TorchServe")
            output_spam_vanilla_artif = gr.Markdown(
                label="Output spam vanilla",
                value=get_message_spam(**defaults_vanilla_spam),
            )
        with gr.Column(scale=50):
            gr.Markdown("### BetterTransformer + TorchServe")
            output_spam_bt_artif = gr.Markdown(
                label="Output spam bt", value=get_message_spam(**defaults_bt_spam)
            )
    btn_spam_artif.click(
        fn=dispatch_spam_artif,
        inputs=[
            input_n_spam_artif,
            sequence_length,
            padding_ratio,
            address_input_vanilla,
            address_input_bettertransformer,
        ],
        outputs=[output_spam_vanilla_artif, output_spam_bt_artif],
    )
# Handle one queued event at a time -- presumably so that concurrent users'
# benchmark runs do not overlap and skew each other's timings.
# NOTE(review): `concurrency_count` is Gradio 3.x API (removed in Gradio 4).
demo.queue(concurrency_count=1)

# Only start the server when run as a script (as `python app.py` does), so
# importing this module to reuse `demo` does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()