squaadai / app.py
amazonaws-la's picture
Update app.py
07c0e5d verified
raw
history blame
10.8 kB
#!/usr/bin/env python
from __future__ import annotations
import os
import random
import gradio as gr
import numpy as np
import spaces
import requests
import torch
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionImg2ImgPipeline, AutoencoderKL, DiffusionPipeline
from diffusers.utils import load_image
from safety_checker import StableDiffusionSafetyChecker
# Markdown heading rendered at the top of the demo page.
DESCRIPTION = "# SDXL"

# Select the torch device once; when no GPU is present, append a warning to
# the page heading instead.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
    device = torch.device("cpu")

# Upper bound for the seed slider and for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders, overridable via the environment.
MAX_IMAGE_SIZE = int(os.environ.get("MAX_IMAGE_SIZE", "1824"))

# Feature switches read from the environment; a switch is on when its
# variable equals "1". The ENABLE_REFINER/LORA/VAE switches default to on.
CACHE_EXAMPLES = os.environ.get("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()
USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE") == "1"
ENABLE_CPU_OFFLOAD = os.environ.get("ENABLE_CPU_OFFLOAD") == "1"
ENABLE_REFINER = os.environ.get("ENABLE_REFINER", "1") == "1"
ENABLE_USE_LORA = os.environ.get("ENABLE_USE_LORA", "1") == "1"
ENABLE_USE_VAE = os.environ.get("ENABLE_USE_VAE", "1") == "1"
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for the next generation.

    Args:
        seed: The seed currently selected by the caller.
        randomize_seed: When truthy, ignore ``seed`` and draw a new one.

    Returns:
        A random integer in ``[0, MAX_SEED]`` if ``randomize_seed`` is
        truthy, otherwise ``seed`` unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = "",
    prompt_2: str = "",
    negative_prompt_2: str = "",
    use_negative_prompt: bool = False,
    use_prompt_2: bool = False,
    use_negative_prompt_2: bool = False,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale_base: float = 5.0,
    guidance_scale_refiner: float = 5.0,
    num_inference_steps_base: int = 25,
    num_inference_steps_refiner: int = 25,
    use_vae: bool = False,
    use_lora: bool = False,
    apply_refiner: bool = False,
    model: str = 'SG161222/Realistic_Vision_V6.0_B1_noVAE',
    vaecall: str = 'stabilityai/sd-vae-ft-mse',
    lora: str = 'amazonaws-la/juliette',
    url: str = "https://m.media-amazon.com/images/I/81zPcrN6m+L.jpg",
    lora_scale: float = 0.7,
):
    """Load a diffusion pipeline and generate one image from an init image.

    The pipeline is loaded on every call from ``model`` (optionally with a
    separate VAE and fused LoRA weights), the init image is downloaded from
    ``url``, and the pipeline is run once.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt; only used if ``use_negative_prompt``.
        prompt_2: Second prompt; only used if ``use_prompt_2``.
        negative_prompt_2: Second negative prompt; only used if
            ``use_negative_prompt_2``.
        use_negative_prompt: Whether to forward ``negative_prompt``.
        use_prompt_2: Whether to forward ``prompt_2``.
        use_negative_prompt_2: Whether to forward ``negative_prompt_2``.
        seed: Seed for the torch random generator.
        width: Requested width; the init image is resized to it.
        height: Requested height; the init image is resized to it.
        guidance_scale_base: Guidance scale passed to the pipeline.
        guidance_scale_refiner: Currently unused (no refiner is loaded).
        num_inference_steps_base: Number of inference steps for the pipeline.
        num_inference_steps_refiner: Currently unused (no refiner is loaded).
        use_vae: Load ``vaecall`` and build the pipeline via
            ``DiffusionPipeline`` instead of the img2img pipeline class.
        use_lora: Load and fuse the LoRA weights named by ``lora``.
        apply_refiner: Must be false; no refiner pipeline exists in this file.
        model: Model repo id or path for the pipeline.
        vaecall: Model repo id or path for the VAE.
        lora: Repo id or path of the LoRA weights.
        url: HTTP(S) address of the init image.
        lora_scale: Scale applied when fusing the LoRA weights.

    Returns:
        The first image of the pipeline output (requested as ``"pil"``).

    Raises:
        gr.Error: If CUDA is unavailable, if ``apply_refiner`` is set, or if
            the init image cannot be downloaded.
    """
    # Fail with a readable message instead of a NameError on an unbound `pipe`.
    if not torch.cuda.is_available():
        raise gr.Error("This demo does not work on CPU.")
    # No `refiner` pipeline is defined anywhere, so this path could only end
    # in a NameError after a wasted base run; reject it up front.
    if apply_refiner:
        raise gr.Error("The refiner is not available: no refiner pipeline is loaded.")

    # Fetch the init image first so a bad URL fails before any model download.
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
    except requests.RequestException as err:
        raise gr.Error(f"Could not download the init image: {err}") from err
    init_image = Image.open(BytesIO(response.content)).convert("RGB")
    # Follow the requested size rather than a fixed 1024x1024 (the defaults
    # still give 1024x1024).
    init_image = init_image.resize((int(width), int(height)))

    # NOTE(review): a separately loaded safety checker was previously created
    # here but never attached to the pipeline; it has been dropped. Presumably
    # it would need to be passed to `from_pretrained` to take effect - confirm.
    if use_vae:
        vae = AutoencoderKL.from_pretrained(vaecall, torch_dtype=torch.float16)
        pipe = DiffusionPipeline.from_pretrained(model, vae=vae, torch_dtype=torch.float16)
    else:
        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model, torch_dtype=torch.float16)

    if use_lora:
        pipe.load_lora_weights(lora)
        # Honour the caller-supplied scale instead of a hard-coded 0.7.
        pipe.fuse_lora(lora_scale=lora_scale)

    if ENABLE_CPU_OFFLOAD:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(device)
    if USE_TORCH_COMPILE:
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

    generator = torch.Generator().manual_seed(int(seed))

    # Optional prompts are only forwarded when their checkbox is ticked.
    if not use_negative_prompt:
        negative_prompt = None  # type: ignore
    if not use_prompt_2:
        prompt_2 = None  # type: ignore
    if not use_negative_prompt_2:
        negative_prompt_2 = None  # type: ignore

    # NOTE(review): prompt_2 / negative_prompt_2 / width / height look like
    # SDXL-style arguments; confirm the loaded pipeline class accepts them.
    return pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        prompt_2=prompt_2,
        negative_prompt_2=negative_prompt_2,
        width=width,
        height=height,
        guidance_scale=guidance_scale_base,
        num_inference_steps=num_inference_steps_base,
        generator=generator,
        image=init_image,
        output_type="pil",
    ).images[0]
# Sample prompts offered by the gr.Examples widget below the result image.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
]
# Build the Gradio UI. The component variables created here are reused further
# down as the inputs/outputs of the event handlers.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        # Pre-filled with the same defaults `generate` declares, so an
        # untouched textbox does not submit an empty repo id.
        model = gr.Text(label='Modelo', value='SG161222/Realistic_Vision_V6.0_B1_noVAE')
        # The VAE / LoRA fields start hidden because their checkboxes start
        # unchecked; the `.change` handlers below reveal them.
        vaecall = gr.Text(label='VAE', value='stabilityai/sd-vae-ft-mse', visible=False)
        lora = gr.Text(label='LoRA', value='amazonaws-la/juliette', visible=False)
        lora_scale = gr.Slider(
            label="Lora Scale",
            minimum=0.01,
            maximum=1,
            step=0.01,
            value=0.7,
        )
        # `generate` takes `url` positionally between `lora` and `lora_scale`;
        # without this input the slider value above was received as the URL.
        url = gr.Text(
            label="Init image URL",
            max_lines=1,
            value="https://m.media-amazon.com/images/I/81zPcrN6m+L.jpg",
        )
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
            use_prompt_2 = gr.Checkbox(label="Use prompt 2", value=False)
            use_negative_prompt_2 = gr.Checkbox(label="Use negative prompt 2", value=False)
        negative_prompt = gr.Text(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter a negative prompt",
            visible=False,
        )
        prompt_2 = gr.Text(
            label="Prompt 2",
            max_lines=1,
            placeholder="Enter your prompt",
            visible=False,
        )
        negative_prompt_2 = gr.Text(
            label="Negative prompt 2",
            max_lines=1,
            placeholder="Enter a negative prompt",
            visible=False,
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
        use_vae = gr.Checkbox(label='Use VAE', value=False, visible=ENABLE_USE_VAE)
        use_lora = gr.Checkbox(label='Use Lora', value=False, visible=ENABLE_USE_LORA)
        apply_refiner = gr.Checkbox(label="Apply refiner", value=False, visible=ENABLE_REFINER)
        with gr.Row():
            guidance_scale_base = gr.Slider(
                label="Guidance scale for base",
                minimum=1,
                maximum=20,
                step=0.1,
                value=5.0,
            )
            num_inference_steps_base = gr.Slider(
                label="Number of inference steps for base",
                minimum=10,
                maximum=100,
                step=1,
                value=25,
            )
        with gr.Row(visible=False) as refiner_params:
            guidance_scale_refiner = gr.Slider(
                label="Guidance scale for refiner",
                minimum=1,
                maximum=20,
                step=0.1,
                value=5.0,
            )
            num_inference_steps_refiner = gr.Slider(
                label="Number of inference steps for refiner",
                minimum=10,
                maximum=100,
                step=1,
                value=25,
            )

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=result,
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    # Each checkbox shows or hides the field (or row) it controls.
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
        queue=False,
        api_name=False,
    )
    use_prompt_2.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_prompt_2,
        outputs=prompt_2,
        queue=False,
        api_name=False,
    )
    use_negative_prompt_2.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt_2,
        outputs=negative_prompt_2,
        queue=False,
        api_name=False,
    )
    use_vae.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_vae,
        outputs=vaecall,
        queue=False,
        api_name=False,
    )
    use_lora.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_lora,
        outputs=lora,
        queue=False,
        api_name=False,
    )
    apply_refiner.change(
        fn=lambda x: gr.update(visible=x),
        inputs=apply_refiner,
        outputs=refiner_params,
        queue=False,
        api_name=False,
    )

    # Submitting any prompt field or clicking Run first settles the seed, then
    # calls `generate`. The inputs are passed positionally, so their order
    # must match the parameter order of `generate`.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            prompt_2.submit,
            negative_prompt_2.submit,
            run_button.click,
        ],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            prompt_2,
            negative_prompt_2,
            use_negative_prompt,
            use_prompt_2,
            use_negative_prompt_2,
            seed,
            width,
            height,
            guidance_scale_base,
            guidance_scale_refiner,
            num_inference_steps_base,
            num_inference_steps_refiner,
            use_vae,
            use_lora,
            apply_refiner,
            model,
            vaecall,
            lora,
            url,
            lora_scale,
        ],
        outputs=result,
        api_name="run",
    )
# Script entry point: enable the request queue (at most 20 waiting jobs) and
# start the Gradio server.
if __name__ == "__main__":
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()