"""Gradio demo for one-step text-to-image generation with a DMD-distilled UNet.

A base diffusers pipeline is loaded, its UNet is replaced by a separately
loaded (DMD-distilled) UNet, its scheduler is replaced by ``DMDScheduler``,
and a single-step generator is served through a Gradio UI.
"""

import argparse
import time

import gradio as gr
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel

from scheduling_dmd import DMDScheduler

parser = argparse.ArgumentParser(description="DMD one-step text-to-image demo")
# argparse requires ``type`` to be a callable; the hub ids belong in ``default``.
# The UNet is loaded on its own (from_pretrained on a repo whose root holds the
# UNet config) and injected into the base pipeline loaded from --model-path.
parser.add_argument(
    "--unet-path",
    type=str,
    default="aaronb/dreamshaper-8-dmd-kl-only-6kstep",
    help="Hub id or local path of the distilled UNet2DConditionModel checkpoint.",
)
parser.add_argument(
    "--model-path",
    type=str,
    default="Lykon/dreamshaper-8",
    help="Hub id or local path of the base diffusers pipeline.",
)
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only used on GPU; on the CPU fallback many kernels lack
# float16 support, so stay in float32 there.
dtype = torch.float16 if device.type == "cuda" else torch.float32

unet = UNet2DConditionModel.from_pretrained(args.unet_path)
pipe = DiffusionPipeline.from_pretrained(args.model_path, unet=unet)
# Reuse the pipeline's existing scheduler config for the DMD scheduler.
pipe.scheduler = DMDScheduler.from_config(pipe.scheduler.config)
pipe.to(device=device, dtype=dtype)


def predict(prompt, seed=1231231):
    """Generate one image for ``prompt`` with a single inference step.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: RNG seed. Cast to ``int`` because UI sliders may deliver floats.

    Returns:
        The first image of the pipeline output.
    """
    # A dedicated generator keeps seeding local to this request instead of
    # reseeding torch's process-global RNG as torch.manual_seed would.
    generator = torch.Generator(device="cpu").manual_seed(int(seed))
    start = time.perf_counter()
    image = pipe(
        prompt,
        num_inference_steps=1,
        guidance_scale=0.0,  # 0.0 disables classifier-free guidance
        generator=generator,
    ).images[0]
    print(f"Pipe took {time.perf_counter() - start} seconds")
    return image


css = """
#container{
    margin: 0 auto;
    max-width: 40rem;
}
#intro{
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown(
            """# Distribution Matching Distillation
            """,
            elem_id="intro",
        )
        with gr.Row():
            with gr.Row():
                prompt = gr.Textbox(placeholder="Insert your prompt here:", scale=5, container=False)
                generate_bt = gr.Button("Generate", scale=1)

        image = gr.Image(type="filepath")
        with gr.Accordion("Advanced options", open=False):
            seed = gr.Slider(randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1)

        inputs = [prompt, seed]
        # Regenerate on button click, on every prompt edit, and on seed change.
        generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)

demo.queue(api_open=False)
demo.launch(show_api=False)