erfan-yahoo committed on
Commit 16f6e14
1 Parent(s): ec74248

Upload pipeline.py

Files changed (1)
  1. pipeline.py +1352 -0
pipeline.py ADDED
@@ -0,0 +1,1352 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024, Yahoo Research
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/
18
+
19
+ import inspect
20
+ import warnings
21
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import PIL.Image
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
28
+
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
31
+ from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
32
+ from diffusers.schedulers import KarrasDiffusionSchedulers
33
+ from diffusers.utils import (
34
+ is_accelerate_available,
35
+ is_accelerate_version,
36
+ logging,
37
+ replace_example_docstring,
38
+ )
39
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
40
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
41
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
42
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
43
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+
48
+ EXAMPLE_DOC_STRING = """
49
+ Examples:
50
+ ```py
51
+ >>> # !pip install transformers accelerate
52
+ >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
53
+ >>> from diffusers.utils import load_image
54
+ >>> import numpy as np
55
+ >>> import torch
56
+
57
+ >>> init_image = load_image(
58
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
59
+ ... )
60
+ >>> init_image = init_image.resize((512, 512))
61
+
62
+ >>> generator = torch.Generator(device="cpu").manual_seed(1)
63
+
64
+ >>> mask_image = load_image(
65
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
66
+ ... )
67
+ >>> mask_image = mask_image.resize((512, 512))
68
+
69
+
70
+ >>> def make_inpaint_condition(image, image_mask):
71
+ ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
72
+ ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
73
+
74
+ ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
75
+ ... image[image_mask > 0.5] = -1.0 # set as masked pixel
76
+ ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
77
+ ... image = torch.from_numpy(image)
78
+ ... return image
79
+
80
+
81
+ >>> control_image = make_inpaint_condition(init_image, mask_image)
82
+
83
+ >>> controlnet = ControlNetModel.from_pretrained(
84
+ ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
85
+ ... )
86
+ >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
87
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
88
+ ... )
89
+
90
+ >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
91
+ >>> pipe.enable_model_cpu_offload()
92
+
93
+ >>> # generate image
94
+ >>> image = pipe(
95
+ ... "a handsome man with ray-ban sunglasses",
96
+ ... num_inference_steps=20,
97
+ ... generator=generator,
98
+ ... eta=1.0,
99
+ ... image=init_image,
100
+ ... mask_image=mask_image,
101
+ ... control_image=control_image,
102
+ ... ).images[0]
103
+ ```
104
+ """
105
+
106
+
107
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
108
+ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
109
+ """
110
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
111
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
112
+ ``image`` and ``1`` for the ``mask``.
113
+
114
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
115
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
116
+
117
+ Args:
118
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
119
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
120
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
121
 + mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
122
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
123
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
124
+
125
+
126
+ Raises:
127
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
128
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
129
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
130
 + (or the other way around).
131
+
132
+ Returns:
133
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
134
+ dimensions: ``batch x channels x height x width``.
135
+ """
136
+
137
+ if image is None:
138
+ raise ValueError("`image` input cannot be undefined.")
139
+
140
+ if mask is None:
141
+ raise ValueError("`mask_image` input cannot be undefined.")
142
+
143
+ if isinstance(image, torch.Tensor):
144
+ if not isinstance(mask, torch.Tensor):
145
 + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")
146
+
147
+ # Batch single image
148
+ if image.ndim == 3:
149
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
150
+ image = image.unsqueeze(0)
151
+
152
+ # Batch and add channel dim for single mask
153
+ if mask.ndim == 2:
154
+ mask = mask.unsqueeze(0).unsqueeze(0)
155
+
156
+ # Batch single mask or add channel dim
157
+ if mask.ndim == 3:
158
+ # Single batched mask, no channel dim or single mask not batched but channel dim
159
+ if mask.shape[0] == 1:
160
+ mask = mask.unsqueeze(0)
161
+
162
+ # Batched masks no channel dim
163
+ else:
164
+ mask = mask.unsqueeze(1)
165
+
166
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
167
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
168
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
169
+
170
+ # Check image is in [-1, 1]
171
+ if image.min() < -1 or image.max() > 1:
172
+ raise ValueError("Image should be in [-1, 1] range")
173
+
174
+ # Check mask is in [0, 1]
175
+ if mask.min() < 0 or mask.max() > 1:
176
+ raise ValueError("Mask should be in [0, 1] range")
177
+
178
+ # Binarize mask
179
+ mask[mask < 0.5] = 0
180
+ mask[mask >= 0.5] = 1
181
+
182
+ # Image as float32
183
+ image = image.to(dtype=torch.float32)
184
+ elif isinstance(mask, torch.Tensor):
185
 + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
186
+ else:
187
+ # preprocess image
188
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
189
+ image = [image]
190
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
191
 + # resize all images w.r.t. the passed height and width
192
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
193
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
194
+ image = np.concatenate(image, axis=0)
195
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
196
+ image = np.concatenate([i[None, :] for i in image], axis=0)
197
+
198
+ image = image.transpose(0, 3, 1, 2)
199
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
200
+
201
+ # preprocess mask
202
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
203
+ mask = [mask]
204
+
205
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
206
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
207
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
208
+ mask = mask.astype(np.float32) / 255.0
209
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
210
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
211
+
212
+ mask[mask < 0.5] = 0
213
+ mask[mask >= 0.5] = 1
214
+ mask = torch.from_numpy(mask)
215
+
216
+ masked_image = image * (mask < 0.5)
217
+
218
+ # n.b. ensure backwards compatibility as old function does not return image
219
+ if return_image:
220
+ return mask, masked_image, image
221
+
222
+ return mask, masked_image
223
+
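 + # Illustrative usage sketch for prepare_mask_and_masked_image (file names are
 + # hypothetical; shapes follow the conversion rules described in the docstring):
 + #
 + #   from PIL import Image
 + #   init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
 + #   mask_image = Image.open("photo_mask.png").convert("L").resize((512, 512))
 + #   mask, masked, image = prepare_mask_and_masked_image(
 + #       init_image, mask_image, 512, 512, return_image=True
 + #   )
 + #   # mask:   (1, 1, 512, 512) float32 in {0, 1}
 + #   # masked: (1, 3, 512, 512) float32 in [-1, 1], zeroed where mask >= 0.5
 + #   # image:  (1, 3, 512, 512) float32 in [-1, 1]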
224
+
225
+ class StableDiffusionControlNetInpaintPipeline(
226
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
227
+ ):
228
+ r"""
229
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
230
+
231
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
232
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
233
+
234
+ In addition the pipeline inherits the following loading methods:
235
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
236
+
237
+ <Tip>
238
+
239
+ This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as
240
+ [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)
241
+ as well as default text-to-image stable diffusion checkpoints, such as
242
+ [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
243
+ Default text-to-image stable diffusion checkpoints might be preferable for controlnets that have been fine-tuned on
244
+ those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
245
+
246
+ </Tip>
247
+
248
+ Args:
249
+ vae ([`AutoencoderKL`]):
250
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
251
+ text_encoder ([`CLIPTextModel`]):
252
+ Frozen text-encoder. Stable Diffusion uses the text portion of
253
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
254
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
255
+ tokenizer (`CLIPTokenizer`):
256
+ Tokenizer of class
257
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
258
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
259
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
260
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
261
+ as a list, the outputs from each ControlNet are added together to create one combined additional
262
+ conditioning.
263
+ scheduler ([`SchedulerMixin`]):
264
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
265
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
266
+ safety_checker ([`StableDiffusionSafetyChecker`]):
267
+ Classification module that estimates whether generated images could be considered offensive or harmful.
268
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
269
+ feature_extractor ([`CLIPImageProcessor`]):
270
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
271
+ """
272
+ _optional_components = ["safety_checker", "feature_extractor"]
273
+
274
+ def __init__(
275
+ self,
276
+ vae: AutoencoderKL,
277
+ text_encoder: CLIPTextModel,
278
+ tokenizer: CLIPTokenizer,
279
+ unet: UNet2DConditionModel,
280
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
281
+ scheduler: KarrasDiffusionSchedulers,
282
+ safety_checker: StableDiffusionSafetyChecker,
283
+ feature_extractor: CLIPImageProcessor,
284
+ requires_safety_checker: bool = True,
285
+ ):
286
+ super().__init__()
287
+
288
+ if safety_checker is None and requires_safety_checker:
289
+ logger.warning(
290
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
291
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
292
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
293
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
294
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
295
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
296
+ )
297
+
298
+ if safety_checker is not None and feature_extractor is None:
299
+ raise ValueError(
300
 + f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
301
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
302
+ )
303
+
304
+ if isinstance(controlnet, (list, tuple)):
305
+ controlnet = MultiControlNetModel(controlnet)
306
+
307
+ self.register_modules(
308
+ vae=vae,
309
+ text_encoder=text_encoder,
310
+ tokenizer=tokenizer,
311
+ unet=unet,
312
+ controlnet=controlnet,
313
+ scheduler=scheduler,
314
+ safety_checker=safety_checker,
315
+ feature_extractor=feature_extractor,
316
+ )
317
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
318
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
319
+ self.control_image_processor = VaeImageProcessor(
320
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
321
+ )
322
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
323
+
324
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
325
+ def enable_vae_slicing(self):
326
+ r"""
327
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
328
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
329
+ """
330
+ self.vae.enable_slicing()
331
+
332
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
333
+ def disable_vae_slicing(self):
334
+ r"""
335
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
336
+ computing decoding in one step.
337
+ """
338
+ self.vae.disable_slicing()
339
+
340
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
341
+ def enable_vae_tiling(self):
342
+ r"""
343
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
344
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
345
+ processing larger images.
346
+ """
347
+ self.vae.enable_tiling()
348
+
349
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
350
+ def disable_vae_tiling(self):
351
+ r"""
352
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
353
+ computing decoding in one step.
354
+ """
355
+ self.vae.disable_tiling()
356
+
357
+ def enable_model_cpu_offload(self, gpu_id=0):
358
+ r"""
359
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
360
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
361
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
362
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
363
+ """
364
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
365
+ from accelerate import cpu_offload_with_hook
366
+ else:
367
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
368
+
369
+ device = torch.device(f"cuda:{gpu_id}")
370
+
371
+ hook = None
372
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
373
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
374
+
375
+ if self.safety_checker is not None:
376
+ # the safety checker can offload the vae again
377
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
378
+
379
 + # control net hook has to be manually offloaded as it alternates with unet
380
+ cpu_offload_with_hook(self.controlnet, device)
381
+
382
+ # We'll offload the last model manually.
383
+ self.final_offload_hook = hook
384
+
385
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
386
+ def _encode_prompt(
387
+ self,
388
+ prompt,
389
+ device,
390
+ num_images_per_prompt,
391
+ do_classifier_free_guidance,
392
+ negative_prompt=None,
393
+ prompt_embeds: Optional[torch.FloatTensor] = None,
394
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
395
+ lora_scale: Optional[float] = None,
396
+ ):
397
+ r"""
398
+ Encodes the prompt into text encoder hidden states.
399
+
400
+ Args:
401
+ prompt (`str` or `List[str]`, *optional*):
402
+ prompt to be encoded
403
+ device: (`torch.device`):
404
+ torch device
405
+ num_images_per_prompt (`int`):
406
+ number of images that should be generated per prompt
407
+ do_classifier_free_guidance (`bool`):
408
+ whether to use classifier free guidance or not
409
+ negative_prompt (`str` or `List[str]`, *optional*):
410
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
411
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
412
+ less than `1`).
413
+ prompt_embeds (`torch.FloatTensor`, *optional*):
414
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
415
+ provided, text embeddings will be generated from `prompt` input argument.
416
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
417
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
418
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
419
+ argument.
420
+ lora_scale (`float`, *optional*):
421
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
422
+ """
423
+ # set lora scale so that monkey patched LoRA
424
+ # function of text encoder can correctly access it
425
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
426
+ self._lora_scale = lora_scale
427
+
428
+ if prompt is not None and isinstance(prompt, str):
429
+ batch_size = 1
430
+ elif prompt is not None and isinstance(prompt, list):
431
+ batch_size = len(prompt)
432
+ else:
433
+ batch_size = prompt_embeds.shape[0]
434
+
435
+ if prompt_embeds is None:
436
 + # textual inversion: process multi-vector tokens if necessary
437
+ if isinstance(self, TextualInversionLoaderMixin):
438
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
439
+
440
+ text_inputs = self.tokenizer(
441
+ prompt,
442
+ padding="max_length",
443
+ max_length=self.tokenizer.model_max_length,
444
+ truncation=True,
445
+ return_tensors="pt",
446
+ )
447
+ text_input_ids = text_inputs.input_ids
448
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
449
+
450
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
451
+ text_input_ids, untruncated_ids
452
+ ):
453
+ removed_text = self.tokenizer.batch_decode(
454
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
455
+ )
456
+ logger.warning(
457
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
458
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
459
+ )
460
+
461
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
462
+ attention_mask = text_inputs.attention_mask.to(device)
463
+ else:
464
+ attention_mask = None
465
+
466
+ prompt_embeds = self.text_encoder(
467
+ text_input_ids.to(device),
468
+ attention_mask=attention_mask,
469
+ )
470
+ prompt_embeds = prompt_embeds[0]
471
+
472
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
473
+
474
+ bs_embed, seq_len, _ = prompt_embeds.shape
475
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
476
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
477
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
478
+
479
+ # get unconditional embeddings for classifier free guidance
480
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
481
+ uncond_tokens: List[str]
482
+ if negative_prompt is None:
483
+ uncond_tokens = [""] * batch_size
484
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
485
+ raise TypeError(
486
 + f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
487
+ f" {type(prompt)}."
488
+ )
489
+ elif isinstance(negative_prompt, str):
490
+ uncond_tokens = [negative_prompt]
491
+ elif batch_size != len(negative_prompt):
492
+ raise ValueError(
493
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
494
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
495
+ " the batch size of `prompt`."
496
+ )
497
+ else:
498
+ uncond_tokens = negative_prompt
499
+
500
 + # textual inversion: process multi-vector tokens if necessary
501
+ if isinstance(self, TextualInversionLoaderMixin):
502
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
503
+
504
+ max_length = prompt_embeds.shape[1]
505
+ uncond_input = self.tokenizer(
506
+ uncond_tokens,
507
+ padding="max_length",
508
+ max_length=max_length,
509
+ truncation=True,
510
+ return_tensors="pt",
511
+ )
512
+
513
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
514
+ attention_mask = uncond_input.attention_mask.to(device)
515
+ else:
516
+ attention_mask = None
517
+
518
+ negative_prompt_embeds = self.text_encoder(
519
+ uncond_input.input_ids.to(device),
520
+ attention_mask=attention_mask,
521
+ )
522
+ negative_prompt_embeds = negative_prompt_embeds[0]
523
+
524
+ if do_classifier_free_guidance:
525
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
526
+ seq_len = negative_prompt_embeds.shape[1]
527
+
528
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
529
+
530
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
531
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
532
+
533
+ # For classifier free guidance, we need to do two forward passes.
534
+ # Here we concatenate the unconditional and text embeddings into a single batch
535
+ # to avoid doing two forward passes
536
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
537
+
538
+ return prompt_embeds
539
+
540
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
541
+ def run_safety_checker(self, image, device, dtype):
542
+ if self.safety_checker is None:
543
+ has_nsfw_concept = None
544
+ else:
545
+ if torch.is_tensor(image):
546
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
547
+ else:
548
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
549
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
550
+ image, has_nsfw_concept = self.safety_checker(
551
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
552
+ )
553
+ return image, has_nsfw_concept
554
+
555
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
556
+ def decode_latents(self, latents):
557
+ warnings.warn(
558
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
559
+ " use VaeImageProcessor instead",
560
+ FutureWarning,
561
+ )
562
+ latents = 1 / self.vae.config.scaling_factor * latents
563
+ image = self.vae.decode(latents, return_dict=False)[0]
564
+ image = (image / 2 + 0.5).clamp(0, 1)
565
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
566
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
567
+ return image
568
+
569
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
570
+ def prepare_extra_step_kwargs(self, generator, eta):
571
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
572
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
573
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
574
+ # and should be between [0, 1]
575
+
576
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
577
+ extra_step_kwargs = {}
578
+ if accepts_eta:
579
+ extra_step_kwargs["eta"] = eta
580
+
581
+ # check if the scheduler accepts generator
582
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
583
+ if accepts_generator:
584
+ extra_step_kwargs["generator"] = generator
585
+ return extra_step_kwargs
586
+
587
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
588
+ def get_timesteps(self, num_inference_steps, strength, device):
589
+ # get the original timestep using init_timestep
590
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
591
+
592
+ t_start = max(num_inference_steps - init_timestep, 0)
593
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
594
+
595
+ return timesteps, num_inference_steps - t_start
596
+
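 + # Worked example for get_timesteps (assuming scheduler.order == 1): with
 + # num_inference_steps=50 and strength=0.6, init_timestep = min(int(50 * 0.6), 50) = 30
 + # and t_start = max(50 - 30, 0) = 20, so only the last 30 timesteps are run; with
 + # strength=1.0 the full 50-step schedule is returned unchanged.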
597
+ def check_inputs(
598
+ self,
599
+ prompt,
600
+ image,
601
+ height,
602
+ width,
603
+ callback_steps,
604
+ negative_prompt=None,
605
+ prompt_embeds=None,
606
+ negative_prompt_embeds=None,
607
+ controlnet_conditioning_scale=1.0,
608
+ control_guidance_start=0.0,
609
+ control_guidance_end=1.0,
610
+ ):
611
+ if height % 8 != 0 or width % 8 != 0:
612
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
613
+
614
+ if (callback_steps is None) or (
615
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
616
+ ):
617
+ raise ValueError(
618
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
619
+ f" {type(callback_steps)}."
620
+ )
621
+
622
+ if prompt is not None and prompt_embeds is not None:
623
+ raise ValueError(
624
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
625
+ " only forward one of the two."
626
+ )
627
+ elif prompt is None and prompt_embeds is None:
628
+ raise ValueError(
629
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
630
+ )
631
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
632
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
633
+
634
+ if negative_prompt is not None and negative_prompt_embeds is not None:
635
+ raise ValueError(
636
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
637
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
638
+ )
639
+
640
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
641
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
642
+ raise ValueError(
643
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
644
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
645
+ f" {negative_prompt_embeds.shape}."
646
+ )
647
+
648
+ # `prompt` needs more sophisticated handling when there are multiple
649
+ # conditionings.
650
+ if isinstance(self.controlnet, MultiControlNetModel):
651
+ if isinstance(prompt, list):
652
+ logger.warning(
653
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
654
+ " prompts. The conditionings will be fixed across the prompts."
655
+ )
656
+
657
+ # Check `image`
658
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
659
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
660
+ )
661
+ if (
662
+ isinstance(self.controlnet, ControlNetModel)
663
+ or is_compiled
664
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
665
+ ):
666
+ self.check_image(image, prompt, prompt_embeds)
667
+ elif (
668
+ isinstance(self.controlnet, MultiControlNetModel)
669
+ or is_compiled
670
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
671
+ ):
672
+ if not isinstance(image, list):
673
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
674
+
675
+ # When `image` is a nested list:
676
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
677
+ elif any(isinstance(i, list) for i in image):
678
 + raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
679
+ elif len(image) != len(self.controlnet.nets):
680
+ raise ValueError(
681
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
682
+ )
683
+
684
+ for image_ in image:
685
+ self.check_image(image_, prompt, prompt_embeds)
686
+ else:
687
+ assert False
688
+
689
+ # Check `controlnet_conditioning_scale`
690
+ if (
691
+ isinstance(self.controlnet, ControlNetModel)
692
+ or is_compiled
693
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
694
+ ):
695
+ if not isinstance(controlnet_conditioning_scale, float):
696
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
697
+ elif (
698
+ isinstance(self.controlnet, MultiControlNetModel)
699
+ or is_compiled
700
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
701
+ ):
702
+ if isinstance(controlnet_conditioning_scale, list):
703
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
704
 + raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
705
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
706
+ self.controlnet.nets
707
+ ):
708
+ raise ValueError(
709
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
710
+ " the same length as the number of controlnets"
711
+ )
712
+ else:
713
+ assert False
714
+
715
+ if len(control_guidance_start) != len(control_guidance_end):
716
+ raise ValueError(
717
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
718
+ )
719
+
720
+ if isinstance(self.controlnet, MultiControlNetModel):
721
+ if len(control_guidance_start) != len(self.controlnet.nets):
722
+ raise ValueError(
723
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
724
+ )
725
+
726
+ for start, end in zip(control_guidance_start, control_guidance_end):
727
+ if start >= end:
728
+ raise ValueError(
729
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
730
+ )
731
+ if start < 0.0:
732
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
733
+ if end > 1.0:
734
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
735
+
736
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
737
+ def check_image(self, image, prompt, prompt_embeds):
738
+ image_is_pil = isinstance(image, PIL.Image.Image)
739
+ image_is_tensor = isinstance(image, torch.Tensor)
740
+ image_is_np = isinstance(image, np.ndarray)
741
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
742
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
743
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
744
+
745
+ if (
746
+ not image_is_pil
747
+ and not image_is_tensor
748
+ and not image_is_np
749
+ and not image_is_pil_list
750
+ and not image_is_tensor_list
751
+ and not image_is_np_list
752
+ ):
753
+ raise TypeError(
754
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
755
+ )
756
+
757
+ if image_is_pil:
758
+ image_batch_size = 1
759
+ else:
760
+ image_batch_size = len(image)
761
+
762
+ if prompt is not None and isinstance(prompt, str):
763
+ prompt_batch_size = 1
764
+ elif prompt is not None and isinstance(prompt, list):
765
+ prompt_batch_size = len(prompt)
766
+ elif prompt_embeds is not None:
767
+ prompt_batch_size = prompt_embeds.shape[0]
768
+
769
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
770
+ raise ValueError(
771
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
772
+ )
773
+
774
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
775
+ def prepare_control_image(
776
+ self,
777
+ image,
778
+ width,
779
+ height,
780
+ batch_size,
781
+ num_images_per_prompt,
782
+ device,
783
+ dtype,
784
+ do_classifier_free_guidance=False,
785
+ guess_mode=False,
786
+ ):
787
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
788
+ image_batch_size = image.shape[0]
789
+
790
+ if image_batch_size == 1:
791
+ repeat_by = batch_size
792
+ else:
793
+ # image batch size is the same as prompt batch size
794
+ repeat_by = num_images_per_prompt
795
+
796
+ image = image.repeat_interleave(repeat_by, dim=0)
797
+
798
+ image = image.to(device=device, dtype=dtype)
799
+
800
+ if do_classifier_free_guidance and not guess_mode:
801
+ image = torch.cat([image] * 2)
802
+
803
+ return image
804
+
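 + # Note on prepare_control_image: with batch_size=2, num_images_per_prompt=1 and a
 + # single conditioning image, preprocessing yields (1, 3, H, W), repeat_interleave
 + # expands it to (2, 3, H, W), and with classifier-free guidance (guess_mode=False)
 + # it is concatenated with itself to (4, 3, H, W) so both guidance halves share it.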
805
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents
806
+ def prepare_latents(
807
+ self,
808
+ batch_size,
809
+ num_channels_latents,
810
+ height,
811
+ width,
812
+ dtype,
813
+ device,
814
+ generator,
815
+ latents=None,
816
+ image=None,
817
+ timestep=None,
818
+ is_strength_max=True,
819
+ return_noise=False,
820
+ return_image_latents=False,
821
+ ):
822
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
823
+ if isinstance(generator, list) and len(generator) != batch_size:
824
+ raise ValueError(
825
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
826
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
827
+ )
828
+
829
+ if (image is None or timestep is None) and not is_strength_max:
830
+ raise ValueError(
831
 + "Since strength < 1, initial latents are to be initialised as a combination of Image + Noise. "
832
+ "However, either the image or the noise timestep has not been provided."
833
+ )
834
+
835
+ if return_image_latents or (latents is None and not is_strength_max):
836
+ image = image.to(device=device, dtype=dtype)
837
+ image_latents = self._encode_vae_image(image=image, generator=generator)
838
+
839
+ if latents is None:
840
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
841
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
842
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
843
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
844
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
845
+ else:
846
+ noise = latents.to(device)
847
+ latents = noise * self.scheduler.init_noise_sigma
848
+
849
+ outputs = (latents,)
850
+
851
+ if return_noise:
852
+ outputs += (noise,)
853
+
854
+ if return_image_latents:
855
+ outputs += (image_latents,)
856
+
857
+ return outputs
858
+
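 + # Illustrative call of prepare_latents (tensor arguments are hypothetical): when
 + # strength == 1.0 the latents start as pure noise scaled by scheduler.init_noise_sigma,
 + # otherwise they start from the encoded image noised to `timestep`, e.g.
 + #
 + #   latents, noise, image_latents = pipe.prepare_latents(
 + #       1, 4, 512, 512, torch.float16, device, generator,
 + #       image=init_image_tensor, timestep=latent_timestep,
 + #       is_strength_max=False, return_noise=True, return_image_latents=True,
 + #   )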
859
+ def _default_height_width(self, height, width, image):
860
+ # NOTE: It is possible that a list of images have different
861
+ # dimensions for each image, so just checking the first image
862
+ # is not _exactly_ correct, but it is simple.
863
+ while isinstance(image, list):
864
+ image = image[0]
865
+
866
+ if height is None:
867
+ if isinstance(image, PIL.Image.Image):
868
+ height = image.height
869
+ elif isinstance(image, torch.Tensor):
870
+ height = image.shape[2]
871
+
872
+ height = (height // 8) * 8 # round down to nearest multiple of 8
873
+
874
+ if width is None:
875
+ if isinstance(image, PIL.Image.Image):
876
+ width = image.width
877
+ elif isinstance(image, torch.Tensor):
878
+ width = image.shape[3]
879
+
880
+ width = (width // 8) * 8 # round down to nearest multiple of 8
881
+
882
+ return height, width
883
+
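 + # Example for _default_height_width: an input image with height 513 and width 769
 + # yields height = (513 // 8) * 8 = 512 and width = (769 // 8) * 8 = 768, i.e. both
 + # dimensions are rounded down to a multiple of 8 so they pass the divisibility
 + # check in check_inputs.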
884
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
885
+ def prepare_mask_latents(
886
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
887
+ ):
888
+ # resize the mask to latents shape as we concatenate the mask to the latents
889
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
890
+ # and half precision
891
+ mask = torch.nn.functional.interpolate(
892
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
893
+ )
894
+ mask = mask.to(device=device, dtype=dtype)
895
+
896
+ masked_image = masked_image.to(device=device, dtype=dtype)
897
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
898
+
899
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
900
+ if mask.shape[0] < batch_size:
901
+ if not batch_size % mask.shape[0] == 0:
902
+ raise ValueError(
903
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
904
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
905
+ " of masks that you pass is divisible by the total requested batch size."
906
+ )
907
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
908
+ if masked_image_latents.shape[0] < batch_size:
909
+ if not batch_size % masked_image_latents.shape[0] == 0:
910
+ raise ValueError(
911
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
912
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
913
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
914
+ )
915
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
916
+
917
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
918
+ masked_image_latents = (
919
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
920
+ )
921
+
922
 + # aligning device to prevent device errors when concatenating it with the latent model input
923
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
924
+ return mask, masked_image_latents
925
+
926
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
927
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
928
+ if isinstance(generator, list):
929
+ image_latents = [
930
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
931
+ for i in range(image.shape[0])
932
+ ]
933
+ image_latents = torch.cat(image_latents, dim=0)
934
+ else:
935
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
936
+
937
+ image_latents = self.vae.config.scaling_factor * image_latents
938
+
939
+ return image_latents
940
+
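 + # Note on _encode_vae_image: a (B, 3, H, W) image in [-1, 1] is encoded to latents of
 + # shape (B, vae.config.latent_channels, H // vae_scale_factor, W // vae_scale_factor)
 + # and scaled by vae.config.scaling_factor (commonly 0.18215 for Stable Diffusion v1
 + # checkpoints); passing a list of generators samples each batch item independently.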
941
+ @torch.no_grad()
942
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
943
+ def __call__(
944
+ self,
945
+ prompt: Union[str, List[str]] = None,
946
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
947
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
948
+ control_image: Union[
949
+ torch.FloatTensor,
950
+ PIL.Image.Image,
951
+ np.ndarray,
952
+ List[torch.FloatTensor],
953
+ List[PIL.Image.Image],
954
+ List[np.ndarray],
955
+ ] = None,
956
+ height: Optional[int] = None,
957
+ width: Optional[int] = None,
958
+ strength: float = 1.0,
959
+ num_inference_steps: int = 50,
960
+ guidance_scale: float = 7.5,
961
+ negative_prompt: Optional[Union[str, List[str]]] = None,
962
+ num_images_per_prompt: Optional[int] = 1,
963
+ eta: float = 0.0,
964
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
965
+ latents: Optional[torch.FloatTensor] = None,
966
+ prompt_embeds: Optional[torch.FloatTensor] = None,
967
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
968
+ output_type: Optional[str] = "pil",
969
+ return_dict: bool = True,
970
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
971
+ callback_steps: int = 1,
972
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
973
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
974
+ guess_mode: bool = False,
975
+ control_guidance_start: Union[float, List[float]] = 0.0,
976
+ control_guidance_end: Union[float, List[float]] = 1.0,
977
+ ):
978
+ r"""
979
+ Function invoked when calling the pipeline for generation.
980
+
981
+ Args:
982
+ prompt (`str` or `List[str]`, *optional*):
983
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
984
+ instead.
985
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
986
+ `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
987
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
988
+ the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
989
+ also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
990
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
991
+ specified in init, images must be passed as a list such that each element of the list can be correctly
992
+ batched for input to a single controlnet.
993
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
994
+ The height in pixels of the generated image.
995
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
996
+ The width in pixels of the generated image.
997
+ strength (`float`, *optional*, defaults to 1.):
998
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
999
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
1000
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
1001
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
1002
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
1003
+ portion of the reference `image`.
1004
+ num_inference_steps (`int`, *optional*, defaults to 50):
1005
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1006
+ expense of slower inference.
1007
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1008
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1009
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1010
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1011
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1012
+ usually at the expense of lower image quality.
1013
+ negative_prompt (`str` or `List[str]`, *optional*):
1014
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1015
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1016
+ less than `1`).
1017
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1018
+ The number of images to generate per prompt.
1019
+ eta (`float`, *optional*, defaults to 0.0):
1020
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1021
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1022
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1023
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1024
+ to make generation deterministic.
1025
+ latents (`torch.FloatTensor`, *optional*):
1026
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1027
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1028
 + tensor will be generated by sampling using the supplied random `generator`.
1029
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1030
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1031
+ provided, text embeddings will be generated from `prompt` input argument.
1032
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1033
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1034
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1035
+ argument.
1036
+ output_type (`str`, *optional*, defaults to `"pil"`):
1037
 + The output format of the generated image. Choose between
1038
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1039
+ return_dict (`bool`, *optional*, defaults to `True`):
1040
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1041
+ plain tuple.
1042
+ callback (`Callable`, *optional*):
1043
+ A function that will be called every `callback_steps` steps during inference. The function will be
1044
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1045
+ callback_steps (`int`, *optional*, defaults to 1):
1046
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1047
+ called at every step.
1048
+ cross_attention_kwargs (`dict`, *optional*):
1049
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1050
+ `self.processor` in
1051
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1052
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5):
1053
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
1054
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
1055
+ corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting
1056
+ than for [`~StableDiffusionControlNetPipeline.__call__`].
1057
+ guess_mode (`bool`, *optional*, defaults to `False`):
1058
 + In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
1059
+ you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
1060
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1061
+ The percentage of total steps at which the controlnet starts applying.
1062
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1063
+ The percentage of total steps at which the controlnet stops applying.
1064
+
1065
+ Examples:
1066
+
1067
+ Returns:
1068
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1069
 + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1070
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1071
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1072
+ (nsfw) content, according to the `safety_checker`.
1073
+ """
1074
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1075
+
1076
+ # 0. Default height and width to unet
1077
+ height, width = self._default_height_width(height, width, image)
1078
+
1079
+ # align format for control guidance
1080
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1081
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1082
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1083
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1084
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1085
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1086
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
1087
+ control_guidance_end
1088
+ ]
1089
+
1090
+ # 1. Check inputs. Raise error if not correct
1091
+ self.check_inputs(
1092
+ prompt,
1093
+ control_image,
1094
+ height,
1095
+ width,
1096
+ callback_steps,
1097
+ negative_prompt,
1098
+ prompt_embeds,
1099
+ negative_prompt_embeds,
1100
+ controlnet_conditioning_scale,
1101
+ control_guidance_start,
1102
+ control_guidance_end,
1103
+ )
1104
+
1105
+ # 2. Define call parameters
1106
+ if prompt is not None and isinstance(prompt, str):
1107
+ batch_size = 1
1108
+ elif prompt is not None and isinstance(prompt, list):
1109
+ batch_size = len(prompt)
1110
+ else:
1111
+ batch_size = prompt_embeds.shape[0]
1112
+
1113
+ device = self._execution_device
1114
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1115
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1116
+ # corresponds to doing no classifier free guidance.
1117
+ do_classifier_free_guidance = guidance_scale > 1.0
1118
+
1119
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1120
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1121
+
1122
+ global_pool_conditions = (
1123
+ controlnet.config.global_pool_conditions
1124
+ if isinstance(controlnet, ControlNetModel)
1125
+ else controlnet.nets[0].config.global_pool_conditions
1126
+ )
1127
+ guess_mode = guess_mode or global_pool_conditions
1128
+ 
+         # 3. Encode input prompt
+         text_encoder_lora_scale = (
+             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+         )
+         prompt_embeds = self._encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             do_classifier_free_guidance,
+             negative_prompt,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             lora_scale=text_encoder_lora_scale,
+         )
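+         # when classifier-free guidance is enabled, `_encode_prompt` returns the negative and
+         # positive embeddings concatenated along the batch dimension (as in the standard diffusers
+         # pipelines), which is why `prompt_embeds.chunk(2)[1]` below selects the conditional half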
+ 
+         # 4. Prepare image
+         if isinstance(controlnet, ControlNetModel):
+             control_image = self.prepare_control_image(
+                 image=control_image,
+                 width=width,
+                 height=height,
+                 batch_size=batch_size * num_images_per_prompt,
+                 num_images_per_prompt=num_images_per_prompt,
+                 device=device,
+                 dtype=controlnet.dtype,
+                 do_classifier_free_guidance=do_classifier_free_guidance,
+                 guess_mode=guess_mode,
+             )
+         elif isinstance(controlnet, MultiControlNetModel):
+             control_images = []
+ 
+             for control_image_ in control_image:
+                 control_image_ = self.prepare_control_image(
+                     image=control_image_,
+                     width=width,
+                     height=height,
+                     batch_size=batch_size * num_images_per_prompt,
+                     num_images_per_prompt=num_images_per_prompt,
+                     device=device,
+                     dtype=controlnet.dtype,
+                     do_classifier_free_guidance=do_classifier_free_guidance,
+                     guess_mode=guess_mode,
+                 )
+ 
+                 control_images.append(control_image_)
+ 
+             control_image = control_images
+         else:
+             assert False
+ 
+         # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
+         mask, masked_image, init_image = prepare_mask_and_masked_image(
+             image, mask_image, height, width, return_image=True
+         )
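+         # `mask` is the binarized inpainting mask, `masked_image` is `image` with the masked region
+         # blanked out, and `init_image` is the resized original image (used further below when a
+         # standard 4-channel UNet is employed); this follows the usual diffusers helper semantics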
+ 
+         # 5. Prepare timesteps
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps, num_inference_steps = self.get_timesteps(
+             num_inference_steps=num_inference_steps, strength=strength, device=device
+         )
+         # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+         # if the strength is set to 1, the latents are initialised with pure noise
+         is_strength_max = strength == 1.0
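+         # e.g. num_inference_steps=50 with strength=0.5 keeps only the last 25 timesteps, so
+         # denoising starts from a half-noised version of the input image (assuming the usual
+         # diffusers `get_timesteps` truncation behaviour)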
+ 
+         # 6. Prepare latent variables
+         num_channels_latents = self.vae.config.latent_channels
+         num_channels_unet = self.unet.config.in_channels
+         return_image_latents = num_channels_unet == 4
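+         # a 4-channel UNet is a standard text-to-image UNet (no extra mask / masked-image channels),
+         # so the clean image latents are returned here and blended back in at every denoising step
+         # below; a 9-channel inpainting UNet instead receives mask and masked-image latents as input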
+         latents_outputs = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+             image=init_image,
+             timestep=latent_timestep,
+             is_strength_max=is_strength_max,
+             return_noise=True,
+             return_image_latents=return_image_latents,
+         )
+ 
+         if return_image_latents:
+             latents, noise, image_latents = latents_outputs
+         else:
+             latents, noise = latents_outputs
+ 
+         # 7. Prepare mask latent variables
+         mask, masked_image_latents = self.prepare_mask_latents(
+             mask,
+             masked_image,
+             batch_size * num_images_per_prompt,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             do_classifier_free_guidance,
+         )
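+         # the mask is resized to the latent resolution and `masked_image` is encoded with the VAE;
+         # both are duplicated for the unconditional batch when classifier-free guidance is enabled
+         # (mirroring the standard diffusers inpainting pipelines)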
+ 
+         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ 
+         # 7.1 Create tensor stating which controlnets to keep
+         controlnet_keep = []
+         for i in range(len(timesteps)):
+             keeps = [
+                 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                 for s, e in zip(control_guidance_start, control_guidance_end)
+             ]
+             controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
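+         # e.g. control_guidance_start=0.0 and control_guidance_end=0.8 with 20 timesteps gives
+         # keep=1.0 for steps 0-15 and keep=0.0 for the last 4 steps, disabling the ControlNet there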
+ 
+         # 8. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ 
+                 # controlnet(s) inference
+                 if guess_mode and do_classifier_free_guidance:
+                     # Infer ControlNet only for the conditional batch.
+                     control_model_input = latents
+                     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+                     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                 else:
+                     control_model_input = latent_model_input
+                     controlnet_prompt_embeds = prompt_embeds
+ 
+                 if isinstance(controlnet_keep[i], list):
+                     cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                 else:
+                     cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
+ 
+                 # predict the noise residual
+                 if num_channels_unet == 9:
+                     latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+ 
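+                 # NOTE: the ControlNet below is fed `latent_model_input` (9 channels when an
+                 # inpainting UNet is used) rather than `control_model_input`; unlike the stock
+                 # diffusers ControlNet inpainting pipeline, the ControlNet paired with this
+                 # pipeline is presumably trained to accept the mask and masked-image latents too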
+                 down_block_res_samples, mid_block_res_sample = self.controlnet(
+                     latent_model_input,  # control_model_input,
+                     t,
+                     encoder_hidden_states=controlnet_prompt_embeds,
+                     controlnet_cond=control_image,
+                     conditioning_scale=cond_scale,
+                     guess_mode=guess_mode,
+                     return_dict=False,
+                 )
+ 
+                 if guess_mode and do_classifier_free_guidance:
+                     # ControlNet was inferred only for the conditional batch.
+                     # To apply the output of ControlNet to both the unconditional and conditional batches,
+                     # add 0 to the unconditional batch to keep it unchanged.
+                     down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+ 
+                 noise_pred = self.unet(
+                     latent_model_input,
+                     t,
+                     encoder_hidden_states=prompt_embeds,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                     down_block_additional_residuals=down_block_res_samples,
+                     mid_block_additional_residual=mid_block_res_sample,
+                     return_dict=False,
+                 )[0]
+ 
+                 # perform guidance
+                 if do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ 
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ 
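+                 # for a standard 4-channel UNet the known image content is enforced manually: the
+                 # original image latents are re-noised to the next timestep and pasted into the
+                 # unmasked region, so only the masked area is actually generated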
+                 if num_channels_unet == 4:
+                     init_latents_proper = image_latents[:1]
+                     init_mask = mask[:1]
+ 
+                     if i < len(timesteps) - 1:
+                         noise_timestep = timesteps[i + 1]
+                         init_latents_proper = self.scheduler.add_noise(
+                             init_latents_proper, noise, torch.tensor([noise_timestep])
+                         )
+ 
+                     latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+ 
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+                     if callback is not None and i % callback_steps == 0:
+                         callback(i, t, latents)
+ 
+         # If we do sequential model offloading, let's offload unet and controlnet
+         # manually for max memory savings
+         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+             self.unet.to("cpu")
+             self.controlnet.to("cpu")
+             torch.cuda.empty_cache()
+ 
+         if output_type != "latent":
+             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+         else:
+             image = latents
+             has_nsfw_concept = None
+ 
+         if has_nsfw_concept is None:
+             do_denormalize = [True] * image.shape[0]
+         else:
+             do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+ 
+         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+ 
+         # Offload last model to CPU
+         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+             self.final_offload_hook.offload()
+ 
+         if not return_dict:
+             return (image, has_nsfw_concept)
+ 
+         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)