Hei-Ha committed on
Commit
76cc201
1 Parent(s): d78e76c
Files changed (2)
  1. app.py +106 -25
  2. requirements.txt +5 -0
app.py CHANGED
@@ -1,30 +1,111 @@
- # import gradio as gr
- # import torch
- # from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
- # from huggingface_hub import hf_hub_download
- # from safetensors.torch import load_file
- #
- # base = "stabilityai/stable-diffusion-xl-base-1.0"
- # repo = "ByteDance/SDXL-Lightning"
- # ckpt = "sdxl_lightning_4step_unet.safetensors"  # Use the correct ckpt for your step setting!
- #
- # # Load model.
- # unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
- # unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
- # pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
- #
- # # Ensure sampler uses "trailing" timesteps.
- # pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
- #
- # # Ensure using the same inference steps as the loaded model and CFG set to 0.
- # pipe("A girl smiling", num_inference_steps=4, guidance_scale=0).images[0].save("output.png")

- import gradio as gr
- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import gradio as gr
+ import torch
+ from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ import spaces
+ import os
+ from PIL import Image
+
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
+
+ # Constants
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
+ repo = "ByteDance/SDXL-Lightning"
+ checkpoints = {
+     "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
+     "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
+     "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
+     "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
+ }
+
+
+ # Ensure model and scheduler are initialized in GPU-enabled function
+ if torch.cuda.is_available():
+     pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
+
+ if SAFETY_CHECKER:
+     from safety_checker import StableDiffusionSafetyChecker
+     from transformers import CLIPFeatureExtractor
+
+     safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+         "CompVis/stable-diffusion-safety-checker"
+     ).to("cuda")
+     feature_extractor = CLIPFeatureExtractor.from_pretrained(
+         "openai/clip-vit-base-patch32"
+     )
+
+     def check_nsfw_images(
+         images: list[Image.Image],
+     ) -> tuple[list[Image.Image], list[bool]]:
+         safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
+         has_nsfw_concepts = safety_checker(
+             images=[images],
+             clip_input=safety_checker_input.pixel_values.to("cuda")
+         )
+         return images, has_nsfw_concepts
+
+ # Function
+ @spaces.GPU(enable_queue=True)
+ def generate_image(prompt, ckpt):
+     checkpoint = checkpoints[ckpt][0]
+     num_inference_steps = checkpoints[ckpt][1]
+
+     if num_inference_steps == 1:
+         # Ensure sampler uses "trailing" timesteps and "sample" prediction type for 1-step inference.
+         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
+     else:
+         # Ensure sampler uses "trailing" timesteps.
+         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
+     pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
+     results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
+
+     if SAFETY_CHECKER:
+         images, has_nsfw_concepts = check_nsfw_images(results.images)
+         if any(has_nsfw_concepts):
+             gr.Warning("NSFW content detected.")
+             return Image.new("RGB", (512, 512))
+         return images[0]
+     return results.images[0]
+
+
+ # Gradio Interface
+ description = """
+ This demo utilizes the SDXL-Lightning model by ByteDance, a lightning-fast text-to-image generative model capable of producing high-quality images in 1 to 8 inference steps.
+ As a community effort, this demo was put together by AngryPenguin. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
+ """
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
+     gr.Markdown(description)
+     with gr.Group():
+         with gr.Row():
+             prompt = gr.Textbox(label='Enter your image prompt:', scale=8)
+             ckpt = gr.Dropdown(label='Select inference steps', choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
+             submit = gr.Button(scale=1, variant='primary')
+     img = gr.Image(label='SDXL-Lightning Generated Image')
+
+     prompt.submit(fn=generate_image,
+                   inputs=[prompt, ckpt],
+                   outputs=img,
+                   )
+     submit.click(fn=generate_image,
+                  inputs=[prompt, ckpt],
+                  outputs=img,
+                  )
+
+ demo.queue().launch()
+
+
+ # import gradio as gr
+ # def greet(name):
+ #     return "Hello " + name + "!!"
+ #
+ # iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+ # iface.launch()
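The heart of this change: app.py now keeps a single base SDXL pipeline resident and hot-swaps the distilled SDXL-Lightning UNet weights (plus the matching scheduler config) whenever the user picks a different step count, instead of hard-coding the 4-step checkpoint as the old commented-out version did. Below is a minimal standalone sketch of that pattern with the Gradio and Spaces wiring stripped out; it assumes a CUDA device and reuses the repo, checkpoint names, and sample prompt from this diff.

# Minimal sketch of the checkpoint-swapping pattern introduced in this commit.
# Assumes a CUDA device; model and checkpoint names are taken from app.py above.
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

# Load the base pipeline once; only the UNet weights change per step setting.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base, torch_dtype=torch.float16, variant="fp16"
).to("cuda")

def set_lightning_steps(num_steps: int) -> None:
    """Swap in the SDXL-Lightning UNet and scheduler config for num_steps."""
    if num_steps == 1:
        # The 1-step checkpoint predicts x0 directly, hence prediction_type="sample".
        ckpt = "sdxl_lightning_1step_unet_x0.safetensors"
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample"
        )
    else:
        ckpt = f"sdxl_lightning_{num_steps}step_unet.safetensors"
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
    pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))

set_lightning_steps(4)
# Lightning checkpoints are distilled without classifier-free guidance,
# so guidance_scale stays at 0.
image = pipe("A girl smiling", num_inference_steps=4, guidance_scale=0).images[0]
image.save("output.png")

Note that only the 1-step checkpoint needs prediction_type="sample" (it is an x0-prediction model, per the comment in app.py); all step settings share the "trailing" timestep spacing and guidance_scale=0.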
requirements.txt CHANGED
@@ -0,0 +1,5 @@
+ transformers
+ diffusers
+ torch
+ accelerate
+ gradio
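The new requirements.txt pins only top-level packages; the other modules app.py imports are expected to arrive transitively or from the Spaces runtime (diffusers pulls in huggingface_hub and safetensors, gradio pulls in Pillow, and the spaces package ships preinstalled on Hugging Face ZeroGPU hardware). For a local checkout, a hypothetical smoke test like the following, run after pip install -r requirements.txt, checks that every import resolves; it is not part of the commit.

# Hypothetical local smoke test (not part of this commit): confirm that
# every module app.py imports can be resolved. `spaces` will typically
# report missing outside Hugging Face Spaces unless installed explicitly.
import importlib

for mod in ["gradio", "torch", "diffusers", "huggingface_hub",
            "safetensors.torch", "PIL", "spaces"]:
    try:
        importlib.import_module(mod)
        print(f"ok: {mod}")
    except ImportError as err:
        print(f"missing: {mod} ({err})")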