# image-to-image / app.py
# dgoot's picture
# Remove width and height params
# 9c58be1
import os
import gradio as gr
import spaces
import torch
from diffusers import AutoPipelineForImage2Image, StableDiffusionInstructPix2PixPipeline
from loguru import logger
from PIL import Image
# Model ID used when the MODEL_ID environment variable is not set.
DEFAULT_MODEL = "stabilityai/stable-diffusion-xl-refiner-1.0"

# Model IDs listed as supported. NOTE: this list is not referenced anywhere
# else in the visible code, so it is informational only.
SUPPORTED_MODELS = [
    "stabilityai/sdxl-turbo",
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    "timbrooks/instruct-pix2pix",
]

# Model ID to load; overridable through the environment.
model = os.getenv("MODEL_ID", DEFAULT_MODEL)

# Value forwarded to `spaces.GPU(duration=...)` on the inference function.
gpu_duration = int(os.getenv("GPU_DURATION", "60"))
def load_pipeline(model):
    """Load the diffusers pipeline for ``model`` with fp16 weights.

    ``timbrooks/instruct-pix2pix`` is loaded through
    ``StableDiffusionInstructPix2PixPipeline``; every other model ID is
    loaded through ``AutoPipelineForImage2Image``.

    Args:
        model: Model ID handed to ``from_pretrained``.

    Returns:
        The pipeline object returned by ``from_pretrained``. It is not moved
        to any device here.
    """
    if model == "timbrooks/instruct-pix2pix":
        pipeline_cls = StableDiffusionInstructPix2PixPipeline
    else:
        pipeline_cls = AutoPipelineForImage2Image

    # Half-precision safetensors weights, using the "fp16" variant files.
    load_options = {
        "torch_dtype": torch.float16,
        "use_safetensors": True,
        "variant": "fp16",
    }
    return pipeline_cls.from_pretrained(model, **load_options)
# Load the pipeline once at import time and move it to the CUDA device;
# `infer` calls this module-level `pipe` on every request.
logger.debug(f"Loading pipeline: {dict(model=model)}")
pipe = load_pipeline(model)
pipe = pipe.to("cuda")
@logger.catch(reraise=True)
@spaces.GPU(duration=gpu_duration)
def infer(
    prompt: str,
    init_image: Image.Image,
    negative_prompt: str,
    strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the module-level pipeline ``pipe`` on ``init_image``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        init_image: Source image. It is ``None`` when the user has not
            provided one, which is rejected with a ``gr.Error``.
        negative_prompt: Negative prompt forwarded to the pipeline.
        strength: Forwarded as ``strength`` only when truthy (non-zero).
        num_inference_steps: Forwarded only when truthy (non-zero).
        guidance_scale: Forwarded only when truthy (non-zero).
        progress: Gradio progress tracker created with ``track_tqdm=True``.
            It is not used directly in the body.

    Returns:
        The first image in the pipeline output's ``images``.

    Raises:
        gr.Error: If ``init_image`` is ``None``.
    """
    logger.info(
        f"Starting image generation: {dict(model=model, prompt=prompt, image=init_image)}"
    )

    # Without this guard a missing image fails below with an opaque
    # AttributeError on `None.thumbnail`; surface a readable message instead.
    if init_image is None:
        raise gr.Error("Please provide an initial image.")

    # Downscale the image in place so neither side exceeds 1024 pixels.
    init_image.thumbnail((1024, 1024))

    # Falsy values (0 / 0.0 / None) are dropped so the pipeline falls back to
    # its own defaults for those parameters. The UI sliders default to 0.
    tunable_args = {
        "strength": strength,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    additional_args = {k: v for k, v in tunable_args.items() if v}

    logger.debug(f"Generating image: {dict(prompt=prompt, **additional_args)}")

    images = pipe(
        prompt=prompt,
        image=init_image,
        negative_prompt=negative_prompt,
        **additional_args,
    ).images

    return images[0]
# Custom CSS handed to `gr.Blocks(css=...)`: when the viewport is at most
# 1280px wide, the element with id "images-container" uses a column flex
# direction.
css = """
@media (max-width: 1280px) {
#images-container {
flex-direction: column;
}
}
"""
# Gradio UI: prompt row, input/result images side by side, and an accordion
# of advanced settings, all wired to `infer`.
with gr.Blocks(css=css) as demo:
    with gr.Column():
        gr.Markdown("# Image-to-Image")
        gr.Markdown(f"## Model: `{model}`")

        # Prompt box and Run button.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                placeholder="Enter your prompt",
                max_lines=1,
                show_label=False,
                container=False,
            )
            run_button = gr.Button("Run", variant="primary", scale=0)

        # Input image and generated result; the row id is targeted by `css`.
        with gr.Row(elem_id="images-container"):
            init_image = gr.Image(label="Initial image", type="pil")
            result = gr.Image(label="Result")

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                placeholder="Enter a negative prompt",
                max_lines=1,
            )

            # All three sliders start at 0; `infer` drops zero values before
            # calling the pipeline.
            with gr.Row():
                strength = gr.Slider(
                    label="Strength",
                    value=0.0,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    value=0,
                    minimum=0,
                    maximum=100,
                    step=1,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    value=0.0,
                    minimum=0.0,
                    maximum=100.0,
                    step=0.1,
                )

    # Clicking Run or submitting the prompt box both start inference.
    run_triggers = [run_button.click, prompt.submit]
    # Order must match the positional parameters of `infer`.
    infer_inputs = [
        prompt,
        init_image,
        negative_prompt,
        strength,
        num_inference_steps,
        guidance_scale,
    ]
    gr.on(
        triggers=run_triggers,
        fn=infer,
        inputs=infer_inputs,
        outputs=[result],
    )
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()