Spaces (status: Runtime error)

Commit: add files

Files changed:
- app.py +261 -102
- constants.py +0 -123
- requirements.txt +5 -0
app.py
CHANGED
@@ -1,103 +1,262 @@
[103 lines removed — the previous contents of app.py are not captured in this view. The new file follows.]
```python
import importlib
from functools import partial
from typing import List

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from torchmetrics.functional.multimodal import clip_score
from torchmetrics.image.inception import InceptionScore

SEED = 0
WEIGHT_DTYPE = torch.float16

TITLE = "Evaluate Schedulers with StableDiffusionPipeline 🧨"
DESCRIPTION = """
This Space allows you to quantitatively compare [different noise schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers) with a [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview).

One application of this Space is to evaluate different schedulers for a given Stable Diffusion checkpoint at a fixed number of inference steps.

Here's how it works:

* The evaluator first sets a seed and generates the initial noise, which is passed as the initial latent to start the image generation process. This ensures a fair comparison.
* The same initial latent is reused every time the pipeline is run (with different schedulers).
* To quantify the quality of the generated images we use:
    * [Inception Score](https://en.wikipedia.org/wiki/Inception_score)
    * [CLIP Score](https://arxiv.org/abs/2104.08718)

**Notes**:

* The default scheduler associated with the provided checkpoint is always used for reporting the scores.
* Increasing both the number of images per prompt and the number of inference steps can quickly build up the inference queue and thus result in slowdowns.
"""


inception_score_fn = InceptionScore(normalize=True)
torch.manual_seed(SEED)
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")


def make_grid(images, rows, cols):
    """Paste a list of PIL images into a single rows x cols grid."""
    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_utils.py#L814
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images


def prepare_report(scheduler_name: str, results: dict):
    """Render one scheduler's image grid and scores as a Markdown section."""
    image_grid = results["images"]
    scores = results["scores"]

    image_name = f"{scheduler_name}_images.png"
    image_grid.save(image_name)
    img_str = f"![img_grid_{scheduler_name}](/file=./{image_name})\n"

    report_str = f"""
\n\n## {scheduler_name}

### Sample images

{img_str}

### Scores

{scores}
\n\n
"""

    return report_str


def initialize_pipeline(checkpoint: str):
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        checkpoint, torch_dtype=WEIGHT_DTYPE
    )
    sd_pipe = sd_pipe.to("cuda")
    original_scheduler_config = sd_pipe.scheduler.config
    return sd_pipe, original_scheduler_config


def get_scheduler(scheduler_name: str):
    """Resolve a scheduler class from the `diffusers` namespace by name."""
    schedulers_lib = importlib.import_module("diffusers", package="schedulers")
    scheduler_cls = getattr(schedulers_lib, scheduler_name)
    return scheduler_cls


def get_latents(num_images_per_prompt: int, seed=SEED):
    """Draw the fixed initial latents that every scheduler run starts from."""
    latents = np.random.RandomState(seed).standard_normal(
        (num_images_per_prompt, 4, 64, 64)
    )
    latents = torch.from_numpy(latents).to(device="cuda", dtype=WEIGHT_DTYPE)
    return latents


def compute_metrics(images: np.ndarray, prompts: List[str]):
    # Reset the metric state so the score reflects only this batch of images,
    # not everything accumulated across previous runs and schedulers.
    inception_score_fn.reset()
    inception_score_fn.update(torch.from_numpy(images).permute(0, 3, 1, 2))
    inception_score = inception_score_fn.compute()

    images_int = (images * 255).astype("uint8")
    clip_score_value = clip_score_fn(
        torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts
    ).detach()
    return {
        "inception_score (⬆️)": {
            "mean": round(float(inception_score[0]), 4),
            "std": round(float(inception_score[1]), 4),
        },
        "clip_score (⬆️)": round(float(clip_score_value), 4),
    }


def run(
    prompt: str,
    num_images_per_prompt: int,
    num_inference_steps: int,
    checkpoint: str,
    schedulers_to_test: List[str],
):
    all_images = {}

    sd_pipeline, original_scheduler_config = initialize_pipeline(checkpoint)
    latents = get_latents(num_images_per_prompt)
    prompts = [prompt] * num_images_per_prompt

    images = sd_pipeline(
        prompts,
        latents=latents,
        num_inference_steps=num_inference_steps,
        output_type="numpy",
    ).images
    original_scheduler_name = original_scheduler_config._class_name

    all_images.update(
        {
            original_scheduler_name: {
                "images": make_grid(numpy_to_pil(images), 1, num_images_per_prompt),
                "scores": compute_metrics(images, prompts),
            }
        }
    )

    for scheduler_name in schedulers_to_test:
        if scheduler_name == original_scheduler_name:
            continue
        scheduler_cls = get_scheduler(scheduler_name)
        sd_pipeline.scheduler = scheduler_cls.from_config(original_scheduler_config)

        # Reuse the same initial latents so every scheduler starts from
        # identical noise, as the description promises.
        cur_scheduler_images = sd_pipeline(
            prompts,
            latents=latents,
            num_inference_steps=num_inference_steps,
            output_type="numpy",
        ).images
        all_images.update(
            {
                scheduler_name: {
                    "images": make_grid(
                        numpy_to_pil(cur_scheduler_images), 1, num_images_per_prompt
                    ),
                    "scores": compute_metrics(cur_scheduler_images, prompts),
                }
            }
        )

    output_str = ""
    for scheduler_name in all_images:
        output_str += prepare_report(scheduler_name, all_images[scheduler_name])
    return output_str


demo = gr.Interface(
    run,
    inputs=[
        gr.Text(max_lines=1, placeholder="a painting of a dog"),
        gr.Slider(3, 10, value=3, step=1),
        gr.Slider(10, 100, value=50, step=1),
        gr.Dropdown(
            [
                "CompVis/stable-diffusion-v1-4",
                "runwayml/stable-diffusion-v1-5",
                "stabilityai/stable-diffusion-2-base",
            ],
            value="CompVis/stable-diffusion-v1-4",
            multiselect=False,
            interactive=True,
        ),
        gr.Dropdown(
            [
                "EulerDiscreteScheduler",
                "PNDMScheduler",
                "LMSDiscreteScheduler",
                "DPMSolverMultistepScheduler",
                "DDIMScheduler",
            ],
            value=["LMSDiscreteScheduler"],
            multiselect=True,
        ),
    ],
    outputs=[gr.Markdown()],
    title=TITLE,
    description=DESCRIPTION,
    allow_flagging="never",
)

# Work-in-progress Blocks UI: it collects the same inputs plus an "Other"
# checkpoint textbox, but nothing is wired up to `run` yet. It is kept under
# its own name so the working Interface above is what gets launched.
with gr.Blocks() as blocks_demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Text(max_lines=1, placeholder="a painting of a dog")
            num_images_per_prompt = gr.Slider(3, 10, value=3, step=1)
            num_inference_steps = gr.Slider(10, 100, value=50, step=1)
            model_ckpt = gr.Dropdown(
                [
                    "CompVis/stable-diffusion-v1-4",
                    "runwayml/stable-diffusion-v1-5",
                    "stabilityai/stable-diffusion-2-base",
                    "Other",
                ],
                value="CompVis/stable-diffusion-v1-4",
                multiselect=False,
                interactive=True,
            )
            other_finetuned_checkpoints = gr.Textbox(visible=False)
            # Reveal the free-form checkpoint textbox only when "Other" is picked.
            model_ckpt.change(
                lambda x: gr.Textbox.update(visible=x == "Other"),
                model_ckpt,
                other_finetuned_checkpoints,
            )
            schedulers_to_test = gr.Dropdown(
                [
                    "EulerDiscreteScheduler",
                    "PNDMScheduler",
                    "LMSDiscreteScheduler",
                    "DPMSolverMultistepScheduler",
                    "DDIMScheduler",
                ],
                value=["LMSDiscreteScheduler"],
                multiselect=True,
            )

demo.launch()
```
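The app's comparison hinges on two things: the same initial latents are fed to every run, and `diffusers` lets any scheduler be rebuilt from another scheduler's config. A minimal standalone sketch of that pattern follows; the checkpoint, prompt, and latent shape are illustrative placeholders, and it assumes a CUDA machine with `diffusers` installed:

```python
# Sketch only: fixed latents + config-based scheduler swap.
import numpy as np
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16  # placeholder checkpoint
).to("cuda")

# Same seed -> same starting noise for every scheduler under test.
latents = torch.from_numpy(
    np.random.RandomState(0).standard_normal((1, 4, 64, 64))
).to(device="cuda", dtype=torch.float16)

image_a = pipe("a painting of a dog", latents=latents).images[0]

# Swap only the sampling algorithm; from_config reuses the original
# scheduler's configuration (beta schedule, train timesteps, and so on).
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
image_b = pipe("a painting of a dog", latents=latents).images[0]

# With the starting noise pinned, any difference between image_a and
# image_b is attributable to the scheduler alone.
```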
constants.py
DELETED
@@ -1,123 +0,0 @@
```python
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
    color: white;
    border-color: black;
    background: black;
}
input[type='range'] {
    accent-color: black;
}
.dark input[type='range'] {
    accent-color: #dfdfdf;
}
.container {
    max-width: 730px;
    margin: auto;
    padding-top: 1.5rem;
}
#gallery {
    min-height: 22rem;
    margin-bottom: 15px;
    margin-left: auto;
    margin-right: auto;
    border-bottom-right-radius: .5rem !important;
    border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
    min-height: 20rem;
}
.details:hover {
    text-decoration: underline;
}
.gr-button {
    white-space: nowrap;
}
.gr-button:focus {
    border-color: rgb(147 197 253 / var(--tw-border-opacity));
    outline: none;
    box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
    --tw-border-opacity: 1;
    --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
    --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
    --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
    --tw-ring-opacity: .5;
}
#advanced-btn {
    font-size: .7rem !important;
    line-height: 19px;
    margin-top: 12px;
    margin-bottom: 12px;
    padding: 2px 8px;
    border-radius: 14px !important;
}
#advanced-options {
    display: none;
    margin-bottom: 20px;
}
.footer {
    margin-bottom: 45px;
    margin-top: 35px;
    text-align: center;
    border-bottom: 1px solid #e5e5e5;
}
.footer>p {
    font-size: .8rem;
    display: inline-block;
    padding: 0 10px;
    transform: translateY(10px);
    background: white;
}
.dark .footer {
    border-color: #303030;
}
.dark .footer>p {
    background: #0b0f19;
}
.acknowledgments h4 {
    margin: 1.25em 0 .25em 0;
    font-weight: bold;
    font-size: 115%;
}
.animate-spin {
    animation: spin 1s linear infinite;
}
@keyframes spin {
    from {
        transform: rotate(0deg);
    }
    to {
        transform: rotate(360deg);
    }
}
#share-btn-container {
    display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
    margin-top: 10px;
    margin-left: auto;
}
#share-btn {
    all: initial; color: #ffffff; font-weight: 600; cursor: pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important; right: 0;
}
#share-btn * {
    all: unset;
}
#share-btn-container div:nth-child(-n+2) {
    width: auto !important;
    min-height: 0px !important;
}
#share-btn-container .wrap {
    display: none !important;
}

.gr-form {
    flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
}
#prompt-container {
    gap: 0;
}
#prompt-text-input, #negative-prompt-text-input { padding: .45rem 0.625rem }
#component-16 { border-top-width: 1px !important; margin-top: 1em }
.image_duplication { position: absolute; width: 100px; left: 50px }
"""
```
requirements.txt
ADDED
@@ -0,0 +1,5 @@
```
torchmetrics[image]
transformers
diffusers
accelerate
numpy
```
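One detail worth noting: plain `torchmetrics` would likely not be enough here. As far as I can tell, the `[image]` extra pulls in the optional dependencies (such as `torch-fidelity`) that `InceptionScore` needs at runtime, which is presumably why the extra is requested in this commit; `torch` itself arrives transitively through these packages.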