"""Gradio demo comparing vanilla Transformers and BetterTransformer on TorchServe.

The page sends the same request(s) to two server addresses (one "vanilla", one
"bettertransformer") and shows both results side by side.
"""

import json

import datasets
import gradio as gr
import torch

from backend import (
    get_message_single,
    get_message_spam,
    send_single,
    send_spam,
    tokenizer,
)
from defaults import (
    ADDRESS_BETTERTRANSFORMER,
    ADDRESS_VANILLA,
    defaults_bt_single,
    defaults_bt_spam,
    defaults_vanilla_single,
    defaults_vanilla_spam,
)

# Token ids written into the artificial sequences.
# NOTE(review): these look like the BERT-uncased [PAD]/[CLS]/[SEP] ids --
# confirm against `tokenizer` before reusing this with another model.
_PAD_TOKEN_ID = 0
_START_TOKEN_ID = 101
_END_TOKEN_ID = 102


def dispatch_single(
    input_model_single, address_input_vanilla, address_input_bettertransformer
):
    """Send one text to both servers and return both results.

    Args:
        input_model_single: The text to send.
        address_input_vanilla: Address of the vanilla server.
        address_input_bettertransformer: Address of the BetterTransformer
            server.

    Returns:
        A ``(vanilla, bettertransformer)`` tuple holding whatever
        ``send_single`` returned for each address.
    """
    result_vanilla = send_single(input_model_single, address_input_vanilla)
    result_bettertransformer = send_single(
        input_model_single, address_input_bettertransformer
    )
    return result_vanilla, result_bettertransformer


def _build_artificial_request(sequence_length, padding_ratio):
    """Return the JSON string describing one random, pre-tokenized request.

    The sequence is laid out as ``[start, random..., end, pad...]`` with a
    matching attention mask (1 on real tokens, 0 on padding).

    Raises:
        ValueError: If ``sequence_length`` cannot hold the start and end
            tokens.
    """
    if sequence_length < 2:
        raise ValueError(
            "sequence_length must be at least 2 to hold the start and end "
            f"tokens, got {sequence_length}"
        )

    # Random ids in [1, vocab_size - 1]: the padding id never appears among
    # the real tokens.
    inp_tokens = torch.randint(tokenizer.vocab_size - 1, (sequence_length,)) + 1

    # Always pad at least one position (historical behaviour), but never so
    # much that the start and end tokens no longer fit. Without the upper
    # bound, a padding ratio >= 1 indexed out of range and a ratio just below
    # 1 made the end token overwrite the start token.
    n_pads = max(int(padding_ratio * sequence_length), 1)
    n_pads = min(n_pads, sequence_length - 2)
    n_real = sequence_length - n_pads

    # Positive indices on purpose: `tensor[-0:]` would select everything.
    inp_tokens[n_real:] = _PAD_TOKEN_ID
    inp_tokens[0] = _START_TOKEN_ID
    inp_tokens[n_real - 1] = _END_TOKEN_ID

    attention_mask = torch.zeros((sequence_length,), dtype=torch.int64)
    attention_mask[:n_real] = 1

    return json.dumps(
        {
            "input_ids": inp_tokens.tolist(),
            "attention_mask": attention_mask.tolist(),
            "pre_tokenized": True,
        }
    )


def dispatch_spam_artif(
    input_n_spam_artif,
    sequence_length,
    padding_ratio,
    address_input_vanilla,
    address_input_bettertransformer,
):
    """Send many copies of one artificial request to both servers.

    A single random token sequence is generated and repeated
    ``input_n_spam_artif`` times in a dataset with one ``"sentence"`` column,
    which is then passed to ``send_spam`` for each address.

    Args:
        input_n_spam_artif: Number of requests to send (converted with
            ``int``).
        sequence_length: Total length of the sequence in tokens, padding
            included (converted with ``int``).
        padding_ratio: Fraction of the sequence that is padding. The number
            of padded positions is clamped so that the start and end tokens
            always fit.
        address_input_vanilla: Address of the vanilla server.
        address_input_bettertransformer: Address of the BetterTransformer
            server.

    Returns:
        A ``(vanilla, bettertransformer)`` tuple holding whatever
        ``send_spam`` returned for each address.

    Raises:
        ValueError: If ``sequence_length`` is smaller than 2.
    """
    sequence_length = int(sequence_length)
    input_n_spam_artif = int(input_n_spam_artif)

    str_input = _build_artificial_request(sequence_length, padding_ratio)
    input_dataset = datasets.Dataset.from_dict(
        {"sentence": [str_input] * input_n_spam_artif}
    )

    result_vanilla = send_spam(input_dataset, address_input_vanilla)
    result_bettertransformer = send_spam(input_dataset, address_input_bettertransformer)
    return result_vanilla, result_bettertransformer


TTILE_IMAGE = """
"""

# The Markdown texts live at module level, flush with column 0, so that their
# rendering does not depend on the component stripping leading indentation.
_OFFLINE_NOTICE_MD = """
**The two AWS instances powering this Space are offline (to save us the $$$). Feel free to reproduce using [this backend code](https://github.com/fxmarty/bettertransformer_demo). The example results are from an AWS EC2 g4dn.xlarge instance with a single NVIDIA T4 GPU.**
"""

_INTRO_MD = """
Let's try out [BetterTransformer](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/) + [TorchServe](https://pytorch.org/serve/)!

BetterTransformer is a stable feature made available with [PyTorch 1.13](https://pytorch.org/blog/PyTorch-1.13-release/) allowing to use a fastpath execution for encoder attention blocks. Depending on your hardware, batch size, sequence length, padding ratio, it can bring large speedups at inference **at no cost in prediction quality**.

As a one-liner, you can convert your 🤗 Transformers models to use BetterTransformer thanks to the integration in the [🤗 Optimum](https://github.com/huggingface/optimum) library:

```
from optimum.bettertransformer import BetterTransformer

better_model = BetterTransformer.transform(model)
```

This Space is a demo of an **end-to-end** deployement of PyTorch eager-mode models, both with and without BetterTransformer. The goal is to see what are the benefits server-side and client-side of using BetterTransformer.

The model used is [`distilbert-base-uncased-finetuned-sst-2-english`](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english), and TorchServe is parametrized to use a maximum batch size of 8.

**Beware:** you may be queued in case several persons use the Space at the same time.

For more details on the TorchServe implementation and to reproduce, see [this reference code](https://github.com/fxmarty/bettertransformer_demo). 
For more details on BetterTransformer, check out the [blog post on PyTorch's Medium](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2), and [the Optimum documentation](https://huggingface.co/docs/optimum/bettertransformer/overview)!"""

_SINGLE_SCENARIO_MD = """## Single input scenario
Note: BetterTransformer normally shines with batch size > 1 and some padding. So this is not the best case here. Check out the heavy workload case below as well.
"""

_HEAVY_SCENARIO_MD = """
**Beware that the end-to-end latency can be impacted by a different ping time between the two servers.**

## Heavy workload scenario
"""

with gr.Blocks() as demo:
    gr.HTML(TTILE_IMAGE)

    gr.Markdown(
        "## Speed up inference and support more workload with PyTorch's BetterTransformer 🤗"
    )
    gr.Markdown(_OFFLINE_NOTICE_MD)
    gr.Markdown(_INTRO_MD)

    # ---- Single input scenario -------------------------------------------
    gr.Markdown(_SINGLE_SCENARIO_MD)

    # Hidden textboxes: they only carry the server addresses into the
    # callbacks as regular inputs.
    address_input_vanilla = gr.Textbox(
        max_lines=1, label="ip vanilla", value=ADDRESS_VANILLA, visible=False
    )
    address_input_bettertransformer = gr.Textbox(
        max_lines=1,
        label="ip bettertransformer",
        value=ADDRESS_BETTERTRANSFORMER,
        visible=False,
    )

    input_model_single = gr.Textbox(
        max_lines=1,
        label="Text",
        value="Expectations were low, enjoyment was high. Although the music was not top level, the story was well-paced.",
    )

    btn_single = gr.Button("Send single text request")

    with gr.Row():
        with gr.Column(scale=50):
            gr.Markdown("### Vanilla Transformers + TorchServe")
            # Pre-filled with the stored example result until a request is sent.
            output_single_vanilla = gr.Markdown(
                label="Output single vanilla",
                value=get_message_single(**defaults_vanilla_single),
            )
        with gr.Column(scale=50):
            gr.Markdown("### BetterTransformer + TorchServe")
            output_single_bt = gr.Markdown(
                label="Output single bt", value=get_message_single(**defaults_bt_single)
            )

    btn_single.click(
        fn=dispatch_single,
        inputs=[
            input_model_single,
            address_input_vanilla,
            address_input_bettertransformer,
        ],
        outputs=[output_single_vanilla, output_single_bt],
    )

    # ---- Heavy workload scenario -----------------------------------------
    gr.Markdown(_HEAVY_SCENARIO_MD)

    input_n_spam_artif = gr.Number(
        label="Number of sequences to send",
        value=80,
    )
    sequence_length = gr.Number(
        label="Sequence length (in tokens)",
        value=128,
    )
    padding_ratio = gr.Number(
        label=(
            "Padding ratio (i.e. how much of the input is padding. \n"
            "In the real world when batch size > 1, the token sequence is padded with 0 to have all inputs with the same length.)"
        ),
        value=0.7,
    )

    btn_spam_artif = gr.Button("Spam text requests (using artificial data)")

    with gr.Row():
        with gr.Column(scale=50):
            gr.Markdown("### Vanilla Transformers + TorchServe")
            output_spam_vanilla_artif = gr.Markdown(
                label="Output spam vanilla",
                value=get_message_spam(**defaults_vanilla_spam),
            )
        with gr.Column(scale=50):
            gr.Markdown("### BetterTransformer + TorchServe")
            output_spam_bt_artif = gr.Markdown(
                label="Output spam bt", value=get_message_spam(**defaults_bt_spam)
            )

    btn_spam_artif.click(
        fn=dispatch_spam_artif,
        inputs=[
            input_n_spam_artif,
            sequence_length,
            padding_ratio,
            address_input_vanilla,
            address_input_bettertransformer,
        ],
        outputs=[output_spam_vanilla_artif, output_spam_bt_artif],
    )

# One request at a time, so concurrent visitors queue up.
demo.queue(concurrency_count=1)
demo.launch()