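"""Gradio demo for fast text-to-image generation with ByteDance's SDXL-Lightning
distilled UNet checkpoints (1, 2, 4, or 8 inference steps) loaded on top of
stabilityai/stable-diffusion-xl-base-1.0."""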
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
import os
from PIL import Image, ImageFilter
from typing import List, Tuple

# Enable the optional NSFW safety checker by setting SAFETY_CHECKER=1 in the environment.
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"

# Constants
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
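# Maps UI label -> [checkpoint filename on the Hub, inference step count].
# The 1-step checkpoint is the experimental x0-prediction variant, which is why
# generate_image switches the scheduler to prediction_type="sample" for it.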
checkpoints = {
    "1-Step" : ["sdxl_lightning_1step_unet_x0.safetensors", 1],
    "2-Step" : ["sdxl_lightning_2step_unet.safetensors", 2],
    "4-Step" : ["sdxl_lightning_4step_unet.safetensors", 4],
    "8-Step" : ["sdxl_lightning_8step_unet.safetensors", 8],
}
aspect_ratios = {
    "21:9": (21, 9),
    "2:1": (2, 1),
    "16:9": (16, 9),
    "5:4": (5, 4),
    "4:3": (4, 3),
    "3:2": (3, 2),
    "1:1": (1, 1),
}
# Compute a (width, height) pair for the given aspect ratio that stays within
# the total pixel budget and is divisible by `divisibility` on both sides.
def calculate_resolution(aspect_ratio, mode='landscape', total_pixels=1024*1024, divisibility=8):
    if aspect_ratio not in aspect_ratios:
        raise ValueError(f"Invalid aspect ratio: {aspect_ratio}")

    width_multiplier, height_multiplier = aspect_ratios[aspect_ratio]
    ratio = width_multiplier / height_multiplier
    if mode == 'portrait':
        # Swap the ratio for portrait mode
        ratio = 1 / ratio

    height = int((total_pixels / ratio) ** 0.5)
    height -= height % divisibility

    width = int(height * ratio)
    width -= width % divisibility

    while width * height > total_pixels:
        height -= divisibility
        width = int(height * ratio)
        width -= width % divisibility

    return width, height
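
# For example, calculate_resolution("16:9") returns (1360, 768): the height is
# snapped down to a multiple of 8 first, then the width is derived from it and
# snapped, shrinking further while the pixel budget is exceeded.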


# Example prompts with ckpt, aspect, and mode
examples = [
    {"prompt": "A futuristic cityscape at sunset", "negative_prompt": "Ugly", "ckpt": "4-Step", "aspect": "16:9", "mode": "landscape"},
    {"prompt": "pair of shoes made of dried fruit skins, 3d render, bright colours, clean composition, beautiful artwork, logo", "negative_prompt": "Ugly", "ckpt": "2-Step", "aspect": "1:1", "mode": "portrait"},
    {"prompt": "A portrait of a robot in the style of Renaissance art", "negative_prompt": "Ugly", "ckpt": "2-Step", "aspect": "1:1", "mode": "portrait"},
    {"prompt": "full body of alien shaped like woman, big golden eyes, mars planet, photo, digital art, fantasy", "negative_prompt": "Ugly", "ckpt": "4-Step", "aspect": "1:1", "mode": "portrait"},    
    {"prompt": "A serene landscape with mountains and a river", "negative_prompt": "Ugly", "ckpt": "8-Step", "aspect": "3:2", "mode": "landscape"},
    {"prompt": "post-apocalyptic wasteland, the most delicate beautiful flower with green leaves growing from dust and rubble, vibrant colours, cinematic", "negative_prompt": "Ugly", "ckpt": "8-Step", "aspect": "16:9", "mode": "landscape"}
]
# Populate the UI inputs from the example whose prompt matches the dropdown selection.
def set_example(selected_prompt):
    # Find the example that matches the selected prompt
    for example in examples:
        if example["prompt"] == selected_prompt:
            return example["prompt"], example["negative_prompt"], example["ckpt"], example["aspect"], example["mode"]
    return None, None, None, None, None  # Default values if not found

# Check if CUDA is available (GPU support), and set the appropriate device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pipeline for the specified device
# For GPU, use torch_dtype=torch.float16 for better performance
if device == "cuda":
    pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to(device)
else:
    pipe = StableDiffusionXLPipeline.from_pretrained(base).to(device)

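# Optionally load the safety checker and its CLIP feature extractor to screen
# generated images for NSFW content before returning them.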
if SAFETY_CHECKER:
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

def check_nsfw_images(
    images: List[Image.Image]
) -> Tuple[List[Image.Image], List[bool]]:
    # Preprocess each PIL image into the tensor format the safety checker expects.
    safety_checker_inputs = [feature_extractor(image, return_tensors="pt").to(device) for image in images]

    # Run the safety checker on each image to flag NSFW concepts.
    has_nsfw_concepts = [safety_checker(
        images=[image],
        clip_input=safety_checker_input.pixel_values.to(device)
    ) for image, safety_checker_input in zip(images, safety_checker_inputs)]

    # Flatten the per-image results into a single list of booleans.
    has_nsfw_concepts = [item for sublist in has_nsfw_concepts for item in sublist]

    return images, has_nsfw_concepts

# Generate an image with the selected SDXL-Lightning checkpoint and resolution.
@spaces.GPU(enable_queue=True)
def generate_image(prompt, negative_prompt, ckpt, aspect_ratio, mode):
    width, height = calculate_resolution(aspect_ratio, mode)  # Resolution from aspect ratio and orientation
    checkpoint, num_inference_steps = checkpoints[ckpt]

    if num_inference_steps == 1:
        # Ensure sampler uses "trailing" timesteps and "sample" prediction type for 1-step inference.
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
    else:
        # Ensure sampler uses "trailing" timesteps.
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
        
    pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device=device))
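    # Lightning checkpoints are distilled for guidance-free sampling, hence guidance_scale=0.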
    results = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=0, width=width, height=height)

    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            # Apply a blur filter to the first image in the results
            blurred_image = images[0].filter(ImageFilter.GaussianBlur(16))  # Adjust the radius as needed
            return blurred_image
        return images[0]
    return results.images[0]



# Gradio Interface
description = """
Demo of ByteDance's SDXL-Lightning, a fast text-to-image model. Model card: [ByteDance/SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)
"""

with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
    gr.Markdown(description)
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your image prompt:', scale=8)
        with gr.Row():
            negative_prompt = gr.Textbox(label='Optional negative prompt:', scale=8)
        with gr.Row():
            ckpt = gr.Dropdown(label='Select inference steps', choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
            aspect = gr.Dropdown(label='Aspect Ratio', choices=list(aspect_ratios.keys()), value='1:1', interactive=True)
            mode = gr.Dropdown(label='Mode', choices=['landscape', 'portrait'], value='landscape')  # Mode as a dropdown
            submit = gr.Button('Generate', scale=1, variant='primary')
            
    img = gr.Image(label='SDXL-Lightning Generated Image')

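    # Both submitting the prompt textbox and clicking the button run generation.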
    prompt.submit(fn=generate_image,
                 inputs=[prompt, negative_prompt, ckpt, aspect, mode],
                 outputs=img,
                 )
    submit.click(fn=generate_image,
                 inputs=[prompt, negative_prompt, ckpt, aspect, mode],
                 outputs=img,
                 )
    # Dropdown for selecting examples
    example_dropdown = gr.Dropdown(label='Select an Example', choices=[e["prompt"] for e in examples])
    example_dropdown.change(fn=set_example, inputs=example_dropdown, outputs=[prompt, negative_prompt, ckpt, aspect, mode])

demo.queue().launch()