File size: 8,141 Bytes
4eea26f
 
 
 
 
 
 
873e677
 
 
 
 
4eea26f
 
 
 
b82f6a5
873e677
3bb8fb9
b82f6a5
81381bd
873e677
 
 
b82f6a5
 
 
 
 
873e677
 
0f357e9
81381bd
0f357e9
b82f6a5
 
873e677
 
4eea26f
 
 
 
 
 
 
 
 
 
 
0f357e9
4eea26f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
873e677
 
 
 
 
4eea26f
873e677
 
 
 
 
 
 
 
4eea26f
873e677
4eea26f
873e677
4eea26f
873e677
 
 
4eea26f
 
 
 
873e677
4eea26f
873e677
 
 
 
 
 
4eea26f
 
873e677
 
 
 
4eea26f
 
 
 
 
 
 
873e677
4eea26f
 
 
873e677
 
4eea26f
873e677
4eea26f
 
 
873e677
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4eea26f
 
 
873e677
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4eea26f
873e677
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import importlib
from typing import List

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

from image_utils import make_grid, numpy_to_pil
from metrics_utils import compute_main_metrics, compute_psnr_or_ssim
from report_utils import add_psnr_ssim_to_report, prepare_report

# Default seed used by `get_latents` so every scheduler starts from the same noise.
SEED: int = 0
# Precision used both when loading the pipeline weights and for the initial latents.
WEIGHT_DTYPE: torch.dtype = torch.float16

# Markdown snippets rendered in the UI: the page heading, the introduction
# shown above the controls, and the explanatory notes shown below them.
TITLE = "Evaluate Schedulers with StableDiffusionPipeline 🧨"
ABSTRACT = """
This Space allows you to quantitatively compare [different noise schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers) with a [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview).

One of the applications of this Space could be to evaluate different schedulers for a certain Stable Diffusion checkpoint for a fixed number of inference steps.
"""
DESCRIPTION = """
#### How does it work?
* The evaluator first sets a seed and then generates the initial noise which is passed as the initial latent to start the image generation process. It is done to ensure fair comparison.
* This initial latent is used every time the pipeline is run (with different schedulers).
* To quantify the quality of the generated images we use:
    * [Inception Score](https://en.wikipedia.org/wiki/Inception_score)
    * [Clip Score](https://arxiv.org/abs/2104.08718)
#### Notes
* When selecting a model checkpoint, if you select "Other" you will have the option to provide a custom Stable Diffusion checkpoint.
* The default scheduler associated with the provided checkpoint is always used for reporting the scores.
* Increasing both the number of images per prompt and the number of inference steps could quickly build up the inference queue and thus
resulting in slowdowns.
"""

# Metric objects built once at import time; `run` hands them to
# `compute_psnr_or_ssim` when the matching checkbox is ticked.
psnr_fn = PeakSignalNoiseRatio()
ssim_fn = StructuralSimilarityIndexMeasure()


def initialize_pipeline(checkpoint: str):
    """Load a Stable Diffusion pipeline onto the GPU.

    Args:
        checkpoint: Identifier passed to
            ``StableDiffusionPipeline.from_pretrained``.

    Returns:
        A ``(pipeline, scheduler_config)`` tuple, where ``scheduler_config``
        is the config of the scheduler the checkpoint was loaded with.
    """
    pipeline = StableDiffusionPipeline.from_pretrained(
        checkpoint, torch_dtype=WEIGHT_DTYPE
    ).to("cuda")
    return pipeline, pipeline.scheduler.config


def get_scheduler(scheduler_name: str):
    """Return the scheduler class called ``scheduler_name`` from ``diffusers``.

    The class is looked up as an attribute of the top-level ``diffusers``
    package.

    Args:
        scheduler_name: Name of the scheduler class, e.g. ``"DDIMScheduler"``.

    Returns:
        The attribute of the ``diffusers`` module with that name.

    Raises:
        AttributeError: If ``diffusers`` has no attribute ``scheduler_name``.
    """
    # ``import_module``'s ``package`` argument is only consulted for relative
    # names (leading dots), so it is not passed for this absolute import.
    diffusers_module = importlib.import_module("diffusers")
    return getattr(diffusers_module, scheduler_name)


def get_latents(num_images_per_prompt: int, seed=SEED):
    """Create the fixed initial latents shared by every scheduler run.

    Args:
        num_images_per_prompt: Number of latent samples to draw (the size of
            the first dimension of the result).
        seed: Seed for both torch's global RNG and the NumPy ``RandomState``
            the latents are drawn from.

    Returns:
        A CUDA tensor of shape ``(num_images_per_prompt, 4, 64, 64)`` with
        dtype ``WEIGHT_DTYPE``.
    """
    # Called only for its side effect of seeding torch's global RNG; the
    # returned generator is not needed.
    torch.manual_seed(seed)
    # The latents themselves come from a NumPy RandomState seeded with `seed`.
    noise = np.random.RandomState(seed).standard_normal(
        (num_images_per_prompt, 4, 64, 64)
    )
    return torch.from_numpy(noise).to(device="cuda", dtype=WEIGHT_DTYPE)


def run(
    prompt: str,
    num_images_per_prompt: int,
    num_inference_steps: int,
    checkpoint: str,
    other_finedtuned_checkpoints: str = None,
    schedulers_to_test: List[str] = None,
    ssim: bool = False,
    psnr: bool = False,
    progress=gr.Progress(),
):
    """Generate images with several schedulers and build a comparison report.

    Every scheduler is run with the same prompt, the same initial latents and
    the same number of inference steps. The scheduler that ships with the
    checkpoint is always evaluated (last) and is the reference passed to the
    PSNR/SSIM helpers.

    Args:
        prompt: Text prompt; repeated ``num_images_per_prompt`` times.
        num_images_per_prompt: Number of images generated per scheduler.
        num_inference_steps: Number of denoising steps per generation.
        checkpoint: Checkpoint identifier, or ``"Other"`` to use
            ``other_finedtuned_checkpoints`` instead.
        other_finedtuned_checkpoints: Custom checkpoint identifier, only used
            when ``checkpoint == "Other"``.
        schedulers_to_test: Names of scheduler classes to compare against the
            checkpoint's own scheduler. ``None`` is treated as an empty list.
            The list is not modified.
        ssim: Whether to compute SSIM scores against the original scheduler.
        psnr: Whether to compute PSNR scores against the original scheduler.
        progress: Gradio progress tracker.

    Returns:
        The report string (concatenated outputs of ``prepare_report`` and,
        when applicable, ``add_psnr_ssim_to_report``), or an error message if
        ``"Other"`` was selected without a custom checkpoint.
    """
    progress(0, desc="Starting...")

    if checkpoint == "Other":
        # `not` covers both an empty textbox ("") and a missing value (None).
        if not other_finedtuned_checkpoints:
            return "❌ No legit checkpoint provided ❌"
        checkpoint = other_finedtuned_checkpoints

    all_images = {}
    scheduler_images = {}

    # Set up the pipeline
    sd_pipeline, original_scheduler_config = initialize_pipeline(checkpoint)
    sd_pipeline.set_progress_bar_config(disable=True)
    # Keep the checkpoint's own scheduler so it can be put back after other
    # schedulers have been swapped into the pipeline.
    original_scheduler = sd_pipeline.scheduler

    # Prepare latents to start generation and the prompts.
    latents = get_latents(num_images_per_prompt)
    prompts = [prompt] * num_images_per_prompt

    original_scheduler_name = original_scheduler_config._class_name

    # Build a fresh list instead of appending to the caller's list: drop
    # duplicates (keeping order) and make sure the original scheduler is
    # evaluated exactly once, at the end.
    scheduler_names = [
        name
        for name in dict.fromkeys(schedulers_to_test or [])
        if name != original_scheduler_name
    ]
    scheduler_names.append(original_scheduler_name)

    # Start generating the images and computing their scores.
    for scheduler_name in progress.tqdm(scheduler_names):
        if scheduler_name == original_scheduler_name:
            # Restore the original scheduler explicitly; otherwise the
            # pipeline would still hold the scheduler from the previous
            # iteration and the baseline would be generated with it.
            sd_pipeline.scheduler = original_scheduler
        else:
            scheduler_cls = get_scheduler(scheduler_name)
            sd_pipeline.scheduler = scheduler_cls.from_config(
                original_scheduler_config
            )

        cur_scheduler_images = sd_pipeline(
            prompts,
            latents=latents,
            num_inference_steps=num_inference_steps,
            output_type="numpy",
        ).images
        all_images[scheduler_name] = {
            "images": make_grid(
                numpy_to_pil(cur_scheduler_images), 1, num_images_per_prompt
            ),
            "scores": compute_main_metrics(cur_scheduler_images, prompts),
        }
        scheduler_images[scheduler_name] = cur_scheduler_images
        # Release cached GPU memory before the next scheduler runs.
        torch.cuda.empty_cache()

    # Prepare output report.
    output_str = "".join(
        prepare_report(scheduler_name, result)
        for scheduler_name, result in all_images.items()
    )

    # Append PSNR or SSIM if needed. These compare against the original
    # scheduler, so they only apply when at least one other scheduler ran.
    if len(scheduler_names) > 1:
        ssim_scores = psnr_scores = None
        if ssim:
            ssim_scores = compute_psnr_or_ssim(
                ssim_fn, scheduler_images, original_scheduler_name
            )
        if psnr:
            psnr_scores = compute_psnr_or_ssim(
                psnr_fn, scheduler_images, original_scheduler_name
            )

        ssim_psnr_str = add_psnr_ssim_to_report(
            original_scheduler_name, ssim_scores, psnr_scores
        )
        if ssim_psnr_str != "":
            output_str += ssim_psnr_str

    return output_str


# Build the UI: input controls in the left column, the markdown report in the
# right column, and the explanatory notes underneath.
with gr.Blocks(title="Scheduler Evaluation") as demo:
    gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")

    with gr.Row():
        with gr.Column():
            prompt = gr.Text(
                max_lines=1, placeholder="a painting of a dog", label="prompt"
            )
            num_images_per_prompt = gr.Slider(
                3, 10, value=3, step=1, label="num_images_per_prompt"
            )
            num_inference_steps = gr.Slider(
                10, 100, value=50, step=1, label="num_inference_steps"
            )
            model_ckpt = gr.Dropdown(
                [
                    "CompVis/stable-diffusion-v1-4",
                    "runwayml/stable-diffusion-v1-5",
                    "stabilityai/stable-diffusion-2-base",
                    "Other",
                ],
                value="CompVis/stable-diffusion-v1-4",
                multiselect=False,
                interactive=True,
                label="model_ckpt",
            )
            # Hidden until "Other" is picked in the checkpoint dropdown.
            other_finedtuned_checkpoints = gr.Textbox(
                visible=False,
                interactive=True,
                placeholder="valhalla/sd-pokemon-model",
                label="custom_checkpoint",
            )
            # The target of this update is a Textbox, so use the
            # component-agnostic `gr.update` rather than `gr.Dropdown.update`.
            model_ckpt.change(
                lambda x: gr.update(visible=x == "Other"),
                model_ckpt,
                other_finedtuned_checkpoints,
            )
            schedulers_to_test = gr.Dropdown(
                [
                    "EulerDiscreteScheduler",
                    "PNDMScheduler",
                    "LMSDiscreteScheduler",
                    "DPMSolverMultistepScheduler",
                    "DDIMScheduler",
                ],
                value=["LMSDiscreteScheduler"],
                multiselect=True,
                label="schedulers_to_test",
            )
            ssim = gr.Checkbox(label="Compute SSIM")
            psnr = gr.Checkbox(label="Compute PSNR")
            evaluation_button = gr.Button(value="Submit")

        with gr.Column():
            # The argument-less `.style()` call was a no-op and is omitted.
            report = gr.Markdown(label="Evaluation Report")

        # The order of `inputs` must match the positional parameters of `run`.
        evaluation_button.click(
            run,
            inputs=[
                prompt,
                num_images_per_prompt,
                num_inference_steps,
                model_ckpt,
                other_finedtuned_checkpoints,
                schedulers_to_test,
                ssim,
                psnr,
            ],
            outputs=report,
        )

    gr.Markdown(f"{DESCRIPTION}")

# Enable request queuing (needed for the progress tracker) and start the app.
demo.queue().launch(debug=True)