# Felix Marty
# style
# 5a87989
# raw
# history blame
# 5.34 kB
import gradio as gr
from backend import get_message_single, get_message_spam, send_single, send_spam
from defaults import (
ADDRESS_BETTERTRANSFORMER,
ADDRESS_VANILLA,
defaults_bt_single,
defaults_bt_spam,
defaults_vanilla_single,
defaults_vanilla_spam,
)
# HTML snippet for the Space's header image: a block-level div centered with
# automatic side margins and constrained to half of the available width.
TITLE_IMAGE = """
<div
style="
display: block;
margin-left: auto;
margin-right: auto;
width: 50%;
"
>
<img src="https://huggingface.co/spaces/fxmarty/bettertransformer-demo/resolve/main/header.webp"/>
</div>
"""
# Backward-compatible alias for the original, misspelled constant name.
TTILE_IMAGE = TITLE_IMAGE

# HTML title banner rendered below the header image (inline-flex container
# holding a single bold <h1>).
TITLE = """
<div
style="
display: inline-flex;
align-items: center;
text-align: center;
max-width: 1400px;
gap: 0.8rem;
font-size: 2.2rem;
"
>
<h1 style="font-weight: 700; margin-bottom: 10px; margin-top: 10px;">
Speed up your inference and support more workload with PyTorch's BetterTransformer 🤗
</h1>
</div>
"""
# Build the side-by-side comparison UI. The left panel talks to the vanilla
# Transformers TorchServe endpoint and the right panel to the BetterTransformer
# one. Both panels expose the same controls and call the same backend
# handlers (send_single / send_spam); they differ only in the hidden server
# address fed to those handlers.
with gr.Blocks() as demo:
    gr.HTML(TTILE_IMAGE)
    gr.HTML(TITLE)
    # The Markdown body is kept flush-left on purpose: source indentation
    # would otherwise become part of the string (and 4-space-indented lines
    # can render as code blocks in Markdown).
    gr.Markdown(
        """
Let's try out TorchServe + BetterTransformer!
BetterTransformer is a stable feature made available with [PyTorch 1.13](https://pytorch.org/blog/PyTorch-1.13-release/) allowing to use a fastpath execution for encoder attention blocks.
As a one-liner, you can convert your 🤗 Transformers models to use BetterTransformer thanks to the [🤗 Optimum](https://huggingface.co/docs/optimum/main/en/index) library:
```
from optimum.bettertransformer import BetterTransformer
better_model = BetterTransformer.transform(model)
```
This Space is a demo of an **end-to-end** deployment of PyTorch eager-mode models, both with and without BetterTransformer. The goal is to see what are the benefits server-side and client-side of using BetterTransformer.
## Inference using...
"""
    )
    with gr.Row():
        # ---- Left panel: vanilla Transformers served by TorchServe ----
        with gr.Column(scale=50):
            gr.Markdown("### Vanilla Transformers + TorchServe")
            # Hidden textbox whose only role is to pass the endpoint address
            # into the click handlers as an event input.
            address_input_vanilla = gr.Textbox(
                max_lines=1, label="ip vanilla", value=ADDRESS_VANILLA, visible=False
            )
            input_model_vanilla = gr.Textbox(
                max_lines=1,
                label="Text",
                value="Expectations were low, enjoyment was high",
            )
            btn_single_vanilla = gr.Button("Send single text request")
            # Initial content is produced by get_message_single applied to the
            # vanilla defaults, so the panel is not empty before any request.
            output_single_vanilla = gr.Markdown(
                label="Output single vanilla",
                value=get_message_single(**defaults_vanilla_single),
            )
            # A Row (not a Column) so the scale=40 / scale=60 children are laid
            # out side by side, matching the BetterTransformer panel.
            with gr.Row():
                with gr.Column(scale=40):
                    # NOTE(review): this textbox is not listed in any event's
                    # inputs, so its value never reaches send_spam — confirm
                    # whether it should be wired in.
                    input_n_inputs_vanilla = gr.Textbox(
                        max_lines=1,
                        label="Number of inputs",
                        value=8,
                    )
                with gr.Column(scale=60):
                    # Empty Markdown acts as a spacer above the button.
                    gr.Markdown("")
                    btn_spam_vanilla = gr.Button(
                        "Spam text requests (from sst2 validation set)"
                    )
            output_spam_vanilla = gr.Markdown(
                label="Output spam vanilla",
                value=get_message_spam(**defaults_vanilla_spam),
            )
            # send_single receives (text, address); send_spam receives only
            # the address.
            btn_single_vanilla.click(
                fn=send_single,
                inputs=[input_model_vanilla, address_input_vanilla],
                outputs=output_single_vanilla,
            )
            btn_spam_vanilla.click(
                fn=send_spam,
                inputs=[address_input_vanilla],
                outputs=output_spam_vanilla,
            )

        # ---- Right panel: BetterTransformer served by TorchServe ----
        with gr.Column(scale=50):
            gr.Markdown("### BetterTransformer + TorchServe")
            # Hidden textbox carrying the BetterTransformer endpoint address.
            address_input_bettertransformer = gr.Textbox(
                max_lines=1,
                label="ip bettertransformer",
                value=ADDRESS_BETTERTRANSFORMER,
                visible=False,
            )
            input_model_bettertransformer = gr.Textbox(
                max_lines=1,
                label="Text",
                value="Expectations were low, enjoyment was high",
            )
            btn_single_bt = gr.Button("Send single text request")
            output_single_bt = gr.Markdown(
                label="Output single bt", value=get_message_single(**defaults_bt_single)
            )
            with gr.Row():
                with gr.Column(scale=40):
                    # NOTE(review): not wired into any event inputs (see the
                    # vanilla panel) — confirm intended use.
                    input_n_inputs_bt = gr.Textbox(
                        max_lines=1,
                        label="Number of inputs",
                        value=8,
                    )
                with gr.Column(scale=60):
                    # Empty Markdown acts as a spacer above the button.
                    gr.Markdown("")
                    btn_spam_bt = gr.Button("Spam text requests (from sst2 validation set)")
            output_spam_bt = gr.Markdown(
                label="Output spam bt", value=get_message_spam(**defaults_bt_spam)
            )
            btn_single_bt.click(
                fn=send_single,
                inputs=[input_model_bettertransformer, address_input_bettertransformer],
                outputs=output_single_bt,
            )
            btn_spam_bt.click(
                fn=send_spam,
                inputs=[address_input_bettertransformer],
                outputs=output_spam_bt,
            )

# Route events through Gradio's queue with a single worker
# (concurrency_count=1), so requests are handled one at a time — presumably to
# keep the two endpoints' measurements from interfering; confirm intent.
# NOTE(review): `concurrency_count` is a Gradio 3.x parameter that was removed
# in Gradio 4.
demo.queue(concurrency_count=1)
demo.launch()