# SDXL-Lightning / app.py
# Hugging Face Space source; last commit 3152b9c (verified) by CyranoB:
# "Add Negative Prompt parameter (#1)". File size 8.07 kB; scan result: no virus.
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
import os
from PIL import Image, ImageFilter
from typing import List, Tuple
# --- Configuration ---------------------------------------------------------

# NSFW filtering is opt-in: it is enabled only when the SAFETY_CHECKER
# environment variable is exactly "1".
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"

# Base SDXL pipeline, and the repo that hosts the Lightning UNet checkpoints.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

# UI label -> [UNet checkpoint filename inside `repo`, number of inference steps].
checkpoints = {
    "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
    "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
    "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
    "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
}

# UI label -> (width multiplier, height multiplier) in landscape orientation.
aspect_ratios = {
    "21:9": (21, 9),
    "2:1": (2, 1),
    "16:9": (16, 9),
    "5:4": (5, 4),
    "4:3": (4, 3),
    "3:2": (3, 2),
    "1:1": (1, 1),
}


def calculate_resolution(aspect_ratio, mode='landscape', total_pixels=1024*1024, divisibility=8):
    """Choose a (width, height) for `aspect_ratio` that fits a pixel budget.

    Args:
        aspect_ratio: A key of `aspect_ratios`, e.g. "16:9".
        mode: 'portrait' inverts the ratio so the image is taller than wide;
            any other value keeps the landscape orientation.
        total_pixels: Upper bound on width * height.
        divisibility: Both dimensions are rounded down to a multiple of this.

    Returns:
        A (width, height) tuple of ints with width * height <= total_pixels.

    Raises:
        ValueError: If `aspect_ratio` is not a key of `aspect_ratios`.
    """
    if aspect_ratio not in aspect_ratios:
        raise ValueError(f"Invalid aspect ratio: {aspect_ratio}")

    def round_down(value):
        # Drop the remainder so `value` becomes a multiple of `divisibility`.
        return value - value % divisibility

    w_parts, h_parts = aspect_ratios[aspect_ratio]
    ratio = w_parts / h_parts  # width / height
    if mode == 'portrait':
        ratio = 1 / ratio

    # Start from the height of an exact-ratio rectangle covering the budget...
    height = round_down(int((total_pixels / ratio) ** 0.5))
    width = round_down(int(height * ratio))
    # ...then shrink one `divisibility` step at a time until it fits.
    while width * height > total_pixels:
        height -= divisibility
        width = round_down(int(height * ratio))
    return width, height
# Example prompts with ckpt, aspect, and mode
examples = [
{"prompt": "A futuristic cityscape at sunset", "negative_prompt": "Ugly", "ckpt": "4-Step", "aspect": "16:9", "mode": "landscape"},
{"prompt": "pair of shoes made of dried fruit skins, 3d render, bright colours, clean composition, beautiful artwork, logo", "negative_prompt": "Ugly", "ckpt": "2-Step", "aspect": "1:1", "mode": "portrait"},
{"prompt": "A portrait of a robot in the style of Renaissance art", "negative_prompt": "Ugly", "ckpt": "2-Step", "aspect": "1:1", "mode": "portrait"},
{"prompt": "full body of alien shaped like woman, big golden eyes, mars planet, photo, digital art, fantasy", "negative_prompt": "Ugly", "ckpt": "4-Step", "aspect": "1:1", "mode": "portrait"},
{"prompt": "A serene landscape with mountains and a river", "negative_prompt": "Ugly", "ckpt": "8-Step", "aspect": "3:2", "mode": "landscape"},
{"prompt": "post-apocalyptic wasteland, the most delicate beautiful flower with green leaves growing from dust and rubble, vibrant colours, cinematic", "negative_prompt": "Ugly", "ckpt": "8-Step", "aspect": "16:9", "mode": "landscape"}
]
# Define a function to set the example inputs
def set_example(selected_prompt):
# Find the example that matches the selected prompt
for example in examples:
if example["prompt"] == selected_prompt:
return example["prompt"], example["negative_prompt"], example["ckpt"], example["aspect"], example["mode"]
return None, None, None, None, None # Default values if not found
# --- Model loading ---------------------------------------------------------

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the SDXL base pipeline. On GPU, use the "fp16" weights variant with
# half-precision compute; on CPU, keep from_pretrained's default dtype.
if device == "cuda":
    pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to(device)
else:
    pipe = StableDiffusionXLPipeline.from_pretrained(base).to(device)

# Optional NSFW filter: the extra modules are imported, and the extra models
# downloaded, only when SAFETY_CHECKER is enabled.
if SAFETY_CHECKER:
    # `safety_checker` is a module shipped alongside this app (not shown here).
    from safety_checker import StableDiffusionSafetyChecker
    # CLIPImageProcessor supersedes the deprecated CLIPFeatureExtractor alias.
    from transformers import CLIPImageProcessor

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    # NOTE(review): preprocessing config comes from a different repo than the
    # checker weights; presumably compatible — confirm.
    feature_extractor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
def check_nsfw_images(
    images: List[Image.Image]
) -> Tuple[List[Image.Image], List[bool]]:
    """Run the NSFW safety checker on each image, one image at a time.

    Relies on the module-level ``feature_extractor`` and ``safety_checker``
    (created only when SAFETY_CHECKER is enabled) and on ``device``, the
    device the checker was moved to.

    Args:
        images: Generated PIL images to screen.

    Returns:
        ``(images, has_nsfw_concepts)``: the input list unchanged, plus the
        safety checker's per-image results flattened into one list.
    """
    has_nsfw_concepts = []
    for image in images:
        # Preprocess into tensors and place them on the checker's device.
        # Previously this was hard-coded to "cuda", which failed on CPU hosts.
        checker_input = feature_extractor(image, return_tensors="pt").to(device)
        # NOTE(review): assumes the local safety_checker returns an iterable of
        # per-image flags for the single image passed in — confirm against
        # safety_checker.py.
        flags = safety_checker(
            images=[image],
            clip_input=checker_input.pixel_values,
        )
        has_nsfw_concepts.extend(flags)
    return images, has_nsfw_concepts
# Filename of the Lightning UNet checkpoint currently loaded into `pipe.unet`
# (None before the first load); lets repeat requests skip re-reading weights.
_loaded_checkpoint = None


@spaces.GPU(enable_queue=True)
def generate_image(prompt, negative_prompt, ckpt, aspect_ratio, mode):
    """Generate one image with SDXL-Lightning.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt forwarded to the pipeline.
        ckpt: Key of `checkpoints` selecting the UNet weights and step count.
        aspect_ratio: Key of `aspect_ratios`.
        mode: Orientation passed to calculate_resolution
            ('landscape' or 'portrait' in the UI).

    Returns:
        The first generated PIL image. When the safety checker is enabled and
        flags NSFW content, a Gaussian-blurred copy is returned instead.
    """
    global _loaded_checkpoint

    width, height = calculate_resolution(aspect_ratio, mode)
    checkpoint, num_inference_steps = checkpoints[ckpt]

    # Lightning needs "trailing" timestep spacing. The 1-step (x0) checkpoint
    # needs "sample" prediction; the multi-step ones use "epsilon". Always set
    # prediction_type explicitly: from_config copies the *current* scheduler
    # config, so a "sample" left by an earlier 1-step run would otherwise leak
    # into later multi-step runs.
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config,
        timestep_spacing="trailing",
        prediction_type="sample" if num_inference_steps == 1 else "epsilon",
    )

    # Swapping UNet weights reads a full safetensors file; do it only when the
    # requested checkpoint differs from the one already loaded.
    if checkpoint != _loaded_checkpoint:
        _loaded_checkpoint = None  # invalidate first in case the load fails midway
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device=device))
        _loaded_checkpoint = checkpoint

    # NOTE(review): diffusers applies negative_prompt through classifier-free
    # guidance (guidance_scale > 1); with guidance_scale=0 it likely has no
    # effect — confirm.
    results = pipe(
        prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=0,
        width=width,
        height=height,
    )

    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            # Return a blurred version of the first image instead of the raw one.
            return images[0].filter(ImageFilter.GaussianBlur(16))  # adjust radius as needed
        return images[0]
    return results.images[0]
# --- Gradio interface ------------------------------------------------------
description = """
SDXL-Lightning ByteDance model demo. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
"""

with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
    gr.Markdown(description)
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your image prompt:', scale=8)
        with gr.Row():
            negative_prompt = gr.Textbox(label='Optional negative prompt:', scale=8)
        with gr.Row():
            # Step choices come from `checkpoints`, so the UI cannot drift from
            # the checkpoints generate_image knows how to load.
            ckpt = gr.Dropdown(label='Select inference steps', choices=list(checkpoints.keys()), value='4-Step', interactive=True)
            aspect = gr.Dropdown(label='Aspect Ratio', choices=list(aspect_ratios.keys()), value='1:1', interactive=True)
            mode = gr.Dropdown(label='Mode', choices=['landscape', 'portrait'], value='landscape')
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='SDXL-Lightning Generated Image')

    # Pressing Enter in the prompt box and clicking the button both generate,
    # with the same inputs in generate_image's parameter order.
    generate_inputs = [prompt, negative_prompt, ckpt, aspect, mode]
    prompt.submit(fn=generate_image, inputs=generate_inputs, outputs=img)
    submit.click(fn=generate_image, inputs=generate_inputs, outputs=img)

    # Picking an example fills every input field from `examples`.
    example_dropdown = gr.Dropdown(label='Select an Example', choices=[e["prompt"] for e in examples])
    example_dropdown.change(fn=set_example, inputs=example_dropdown, outputs=[prompt, negative_prompt, ckpt, aspect, mode])

demo.queue().launch()