import os

import gradio as gr
import spaces
import torch
from diffusers import AutoPipelineForImage2Image, StableDiffusionInstructPix2PixPipeline
from loguru import logger
from PIL import Image

# Model checkpoints known to work with this app.
SUPPORTED_MODELS = [
    "stabilityai/sdxl-turbo",
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    "timbrooks/instruct-pix2pix",
]
DEFAULT_MODEL = "stabilityai/stable-diffusion-xl-refiner-1.0"

# Runtime configuration from environment variables.
model = os.environ.get("MODEL_ID", DEFAULT_MODEL)
gpu_duration = int(os.environ.get("GPU_DURATION", 60))


def load_pipeline(model):
    # instruct-pix2pix needs its dedicated pipeline; everything else loads
    # through AutoPipelineForImage2Image.
    pipeline_type = (
        StableDiffusionInstructPix2PixPipeline
        if model == "timbrooks/instruct-pix2pix"
        else AutoPipelineForImage2Image
    )
    return pipeline_type.from_pretrained(
        model, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
    )


# Load the pipeline once at startup and move it to the GPU.
logger.debug(f"Loading pipeline: {dict(model=model)}")
pipe = load_pipeline(model).to("cuda")


@logger.catch(reraise=True)
@spaces.GPU(duration=gpu_duration)  # request a GPU for up to gpu_duration seconds
def infer(
    prompt: str,
    init_image: Image.Image,
    negative_prompt: str,
    strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    logger.info(
        f"Starting image generation: {dict(model=model, prompt=prompt, image=init_image)}"
    )

    # Downscale the input image in place to at most 1024x1024, preserving aspect ratio.
    init_image.thumbnail((1024, 1024))

    # Drop zero/empty settings so the pipeline falls back to its own defaults.
    additional_args = {
        k: v
        for k, v in dict(
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).items()
        if v
    }

    logger.debug(f"Generating image: {dict(prompt=prompt, **additional_args)}")
    images = pipe(
        prompt=prompt,
        image=init_image,
        negative_prompt=negative_prompt,
        **additional_args,
    ).images
    return images[0]


# Stack the two image panels vertically on narrow screens.
css = """
@media (max-width: 1280px) {
    #images-container {
        flex-direction: column;
    }
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column():
        gr.Markdown("# Image-to-Image")
        gr.Markdown(f"## Model: `{model}`")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")

        with gr.Row(elem_id="images-container"):
            init_image = gr.Image(label="Initial image", type="pil")
            result = gr.Image(label="Result")

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )

            with gr.Row():
                strength = gr.Slider(
                    label="Strength",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=0,
                    maximum=100,
                    step=1,
                    value=0,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=100.0,
                    step=0.1,
                    value=0.0,
                )

    # Run inference on button click or when the prompt is submitted.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            init_image,
            negative_prompt,
            strength,
            num_inference_steps,
            guidance_scale,
        ],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()