import gradio as gr

from backend import get_message_single, get_message_spam, send_single, send_spam
from defaults import (
    ADDRESS_BETTERTRANSFORMER,
    ADDRESS_VANILLA,
    defaults_bt_single,
    defaults_bt_spam,
    defaults_vanilla_single,
    defaults_vanilla_spam,
)

TITLE_IMAGE = """
"""

TITLE = """

Speed up your inference and support more workloads with PyTorch's BetterTransformer 🤗

""" with gr.Blocks() as demo: gr.HTML(TTILE_IMAGE) gr.HTML(TITLE) gr.Markdown( """ Let's try out TorchServe + BetterTransformer! BetterTransformer is a stable feature made available with [PyTorch 1.13](https://pytorch.org/blog/PyTorch-1.13-release/) allowing to use a fastpath execution for encoder attention blocks. As a one-liner, you can convert your 🤗 Transformers models to use BetterTransformer thanks to the [🤗 Optimum](https://huggingface.co/docs/optimum/main/en/index) library: ``` from optimum.bettertransformer import BetterTransformer better_model = BetterTransformer.transform(model) ``` This Space is a demo of an **end-to-end** deployement of PyTorch eager-mode models, both with and without BetterTransformer. The goal is to see what are the benefits server-side and client-side of using BetterTransformer. ## Inference using... """ ) with gr.Row(): with gr.Column(scale=50): gr.Markdown("### Vanilla Transformers + TorchServe") address_input_vanilla = gr.Textbox( max_lines=1, label="ip vanilla", value=ADDRESS_VANILLA, visible=False ) input_model_vanilla = gr.Textbox( max_lines=1, label="Text", value="Expectations were low, enjoyment was high", ) btn_single_vanilla = gr.Button("Send single text request") output_single_vanilla = gr.Markdown( label="Output single vanilla", value=get_message_single(**defaults_vanilla_single), ) with gr.Column(): with gr.Column(scale=40): input_n_inputs_vanilla = gr.Textbox( max_lines=1, label="Number of inputs", value=8, ) with gr.Column(scale=60): gr.Markdown("") btn_spam_vanilla = gr.Button( "Spam text requests (from sst2 validation set)" ) output_spam_vanilla = gr.Markdown( label="Output spam vanilla", value=get_message_spam(**defaults_vanilla_spam), ) btn_single_vanilla.click( fn=send_single, inputs=[input_model_vanilla, address_input_vanilla], outputs=output_single_vanilla, ) btn_spam_vanilla.click( fn=send_spam, inputs=[address_input_vanilla], outputs=output_spam_vanilla, ) with gr.Column(scale=50): gr.Markdown("### BetterTransformer + TorchServe") address_input_bettertransformer = gr.Textbox( max_lines=1, label="ip bettertransformer", value=ADDRESS_BETTERTRANSFORMER, visible=False, ) input_model_bettertransformer = gr.Textbox( max_lines=1, label="Text", value="Expectations were low, enjoyment was high", ) btn_single_bt = gr.Button("Send single text request") output_single_bt = gr.Markdown( label="Output single bt", value=get_message_single(**defaults_bt_single) ) with gr.Row(): with gr.Column(scale=40): input_n_inputs_bt = gr.Textbox( max_lines=1, label="Number of inputs", value=8, ) with gr.Column(scale=60): gr.Markdown("") btn_spam_bt = gr.Button("Spam text requests (from sst2 validation set)") output_spam_bt = gr.Markdown( label="Output spam bt", value=get_message_spam(**defaults_bt_spam) ) btn_single_bt.click( fn=send_single, inputs=[input_model_bettertransformer, address_input_bettertransformer], outputs=output_single_bt, ) btn_spam_bt.click( fn=send_spam, inputs=[address_input_bettertransformer], outputs=output_spam_bt, ) demo.queue(concurrency_count=1) demo.launch()