nyanko7 committed
Commit dc5ed5b
1 Parent(s): 9f25d0a

Create pipeline.py

Files changed (1)
  1. pipeline.py +1613 -0
pipeline.py ADDED
@@ -0,0 +1,1613 @@
1
+ # Implementation of StableDiffusionXLSEGPipeline
2
+
3
+ import inspect
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from packaging import version
10
+
11
+ from transformers import (
12
+ CLIPImageProcessor,
13
+ CLIPTextModel,
14
+ CLIPTextModelWithProjection,
15
+ CLIPTokenizer,
16
+ CLIPVisionModelWithProjection,
17
+ )
18
+
19
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
20
+ from diffusers.loaders import (
21
+ FromSingleFileMixin,
22
+ IPAdapterMixin,
23
+ StableDiffusionXLLoraLoaderMixin,
24
+ TextualInversionLoaderMixin,
25
+ )
26
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
27
+ from diffusers.models.attention_processor import (
28
+ AttnProcessor2_0,
29
+ FusedAttnProcessor2_0,
30
+ LoRAAttnProcessor2_0,
31
+ LoRAXFormersAttnProcessor,
32
+ XFormersAttnProcessor,
33
+ )
34
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
35
+ from diffusers.schedulers import KarrasDiffusionSchedulers
36
+ from diffusers.utils import (
37
+ USE_PEFT_BACKEND,
38
+ deprecate,
39
+ is_invisible_watermark_available,
40
+ is_torch_xla_available,
41
+ logging,
42
+ replace_example_docstring,
43
+ scale_lora_layers,
44
+ unscale_lora_layers,
45
+ )
46
+ from diffusers.utils.torch_utils import randn_tensor
47
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
48
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
49
+
50
+ from diffusers.models.attention_processor import Attention
51
+
52
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
53
+
54
+ EXAMPLE_DOC_STRING = """
55
+ Examples:
56
+ ```py
57
+ >>> import torch
58
+ >>> from diffusers import StableDiffusionXLPipeline
59
+
60
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
61
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
62
+ ... )
63
+ >>> pipe = pipe.to("cuda")
64
+
65
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
66
+ >>> image = pipe(prompt).images[0]
67
+ ```
68
+ """
69
+
70
+ # Separable 2D Gaussian blur (reflect padding, depthwise conv); used to smooth self-attention queries
71
+ def gaussian_blur_2d(img, kernel_size, sigma):
72
+ height = img.shape[-1]
73
+ kernel_size = min(kernel_size, height - (height % 2 - 1))
74
+ ksize_half = (kernel_size - 1) * 0.5
75
+
76
+ x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
77
+
78
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
79
+
80
+ x_kernel = pdf / pdf.sum()
81
+ x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
82
+
83
+ kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
84
+ kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
85
+
86
+ padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
87
+
88
+ img = F.pad(img, padding, mode="reflect")
89
+ img = F.conv2d(img, kernel2d, groups=img.shape[-3])
90
+
91
+ return img
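A quick, self-contained sanity check of the helper above, using the same odd-kernel-size rule the attention processor applies later (tensor sizes are arbitrary):

import math
import torch

sigma = 3.0
kernel_size = math.ceil(6 * sigma) + 1 - math.ceil(6 * sigma) % 2  # -> 19, always odd
x = torch.randn(2, 8, 32, 32)        # (batch, channels, height, width)
y = gaussian_blur_2d(x, kernel_size, sigma)
assert y.shape == x.shape            # reflect padding keeps the spatial size
assert y.std() < x.std()             # blurring white noise reduces its variance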
92
+
93
+
94
+ class SEGCFGSelfAttnProcessor:
95
+ r"""
96
+ Self-attention processor for Smoothed Energy Guidance (SEG): blurs the queries of the perturbed batch with a Gaussian kernel before scaled dot-product attention (requires PyTorch 2.0).
97
+ """
98
+
99
+ def __init__(self, blur_sigma=1.0, do_cfg=True, inf_blur_threshold=9999.0):
100
+ if not hasattr(F, "scaled_dot_product_attention"):
101
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
102
+ self.blur_sigma = blur_sigma
103
+ self.do_cfg = do_cfg
104
+ if self.blur_sigma > inf_blur_threshold:
105
+ self.inf_blur = True
106
+ else:
107
+ self.inf_blur = False
108
+
109
+ def __call__(
110
+ self,
111
+ attn: Attention,
112
+ hidden_states: torch.FloatTensor,
113
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
114
+ attention_mask: Optional[torch.FloatTensor] = None,
115
+ temb: Optional[torch.FloatTensor] = None,
116
+ *args,
117
+ **kwargs,
118
+ ) -> torch.FloatTensor:
119
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
120
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
121
+ deprecate("scale", "1.0.0", deprecation_message)
122
+
123
+ residual = hidden_states
124
+ if attn.spatial_norm is not None:
125
+ hidden_states = attn.spatial_norm(hidden_states, temb)
126
+
127
+ input_ndim = hidden_states.ndim
128
+
129
+ if input_ndim == 4:
130
+ batch_size, channel, height, width = hidden_states.shape
131
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
132
+
133
+ batch_size, sequence_length, _ = (
134
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
135
+ )
136
+
137
+ if attention_mask is not None:
138
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
139
+ # scaled_dot_product_attention expects attention_mask shape to be
140
+ # (batch, heads, source_length, target_length)
141
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
142
+
143
+ if attn.group_norm is not None:
144
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
145
+
146
+ query = attn.to_q(hidden_states)
147
+
148
+ if encoder_hidden_states is None:
149
+ encoder_hidden_states = hidden_states
150
+ elif attn.norm_cross:
151
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
152
+
153
+ key = attn.to_k(encoder_hidden_states)
154
+ value = attn.to_v(encoder_hidden_states)
155
+
156
+ inner_dim = key.shape[-1]
157
+ head_dim = inner_dim // attn.heads
158
+
159
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
160
+
161
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
162
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
163
+
164
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
165
+ # TODO: add support for attn.scale when we move to Torch 2.1
166
+ height = width = math.isqrt(query.shape[2])  # assumes a square token grid (SDXL self-attention)
167
+ if self.do_cfg:
168
+ query_uncond, query_org, query_ptb = query.chunk(3)
169
+ query_ptb = query_ptb.permute(0, 1, 3, 2).view(batch_size//3, attn.heads * head_dim, height, width)
170
+
171
+ if not self.inf_blur:
172
+ kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
173
+ query_ptb = gaussian_blur_2d(query_ptb, kernel_size, self.blur_sigma)
174
+ else:
175
+ query_ptb[:] = query_ptb.mean(dim=(-2, -1), keepdim=True)
176
+
177
+ query_ptb = query_ptb.view(batch_size//3, attn.heads, head_dim, height * width).permute(0, 1, 3, 2)
178
+ query = torch.cat((query_uncond, query_org, query_ptb), dim=0)
179
+ else:
180
+ query_org, query_ptb = query.chunk(2)
181
+ query_ptb = query_ptb.permute(0, 1, 3, 2).view(batch_size//2, attn.heads * head_dim, height, width)
182
+
183
+ if not self.inf_blur:
184
+ kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
185
+ query_ptb = gaussian_blur_2d(query_ptb, kernel_size, self.blur_sigma)
186
+ else:
187
+ query_ptb[:] = query_ptb.mean(dim=(-2, -1), keepdim=True)
188
+
189
+ query_ptb = query_ptb.view(batch_size//2, attn.heads, head_dim, height * width).permute(0, 1, 3, 2)
190
+ query = torch.cat((query_org, query_ptb), dim=0)
191
+
192
+ hidden_states = F.scaled_dot_product_attention(
193
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
194
+ )
195
+
196
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
197
+ hidden_states = hidden_states.to(query.dtype)
198
+
199
+ # linear proj
200
+ hidden_states = attn.to_out[0](hidden_states)
201
+ # dropout
202
+ hidden_states = attn.to_out[1](hidden_states)
203
+
204
+ if input_ndim == 4:
205
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
206
+
207
+ if attn.residual_connection:
208
+ hidden_states = hidden_states + residual
209
+
210
+ hidden_states = hidden_states / attn.rescale_output_factor
211
+
212
+ return hidden_states
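How this processor is installed on the UNet is not visible in this excerpt; a rough sketch of one way to attach it to the self-attention (`attn1`) layers of the selected block groups via the standard `attn_processors` / `set_attn_processor` API (the selection logic here is illustrative, not the pipeline's exact code):

def apply_seg_processors(unet, blur_sigma=9999999.0, do_cfg=True, applied_layers=("mid",)):
    procs = dict(unet.attn_processors)  # processor-name -> processor mapping
    for name in procs:
        if not name.endswith("attn1.processor"):  # self-attention layers only
            continue
        # "mid" -> "mid_block", "down" -> "down_blocks", "up" -> "up_blocks"
        if any(f"{layer}_block" in name for layer in applied_layers):
            procs[name] = SEGCFGSelfAttnProcessor(blur_sigma=blur_sigma, do_cfg=do_cfg)
    unet.set_attn_processor(procs)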
213
+
214
+ if is_invisible_watermark_available():
215
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
216
+
217
+ if is_torch_xla_available():
218
+ import torch_xla.core.xla_model as xm
219
+
220
+ XLA_AVAILABLE = True
221
+ else:
222
+ XLA_AVAILABLE = False
223
+
224
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
225
+
226
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
227
+ """
228
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
229
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
230
+ """
231
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
232
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
233
+ # rescale the results from guidance (fixes overexposure)
234
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
235
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
236
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
237
+ return noise_cfg
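For context, a compressed sketch of where this helper sits in a classifier-free-guidance step (the tensors here are random placeholders, not real model outputs):

import torch

noise_pred_uncond = torch.randn(1, 4, 128, 128)
noise_pred_text = noise_pred_uncond + 0.1 * torch.randn_like(noise_pred_uncond)

guidance_scale, guidance_rescale = 5.0, 0.7
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# Pull the guided prediction's per-sample std back toward the text branch's std,
# then blend by `guidance_rescale` (Section 3.4 of arXiv:2305.08891).
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)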
238
+
239
+
240
+ def retrieve_timesteps(
241
+ scheduler,
242
+ num_inference_steps: Optional[int] = None,
243
+ device: Optional[Union[str, torch.device]] = None,
244
+ timesteps: Optional[List[int]] = None,
245
+ **kwargs,
246
+ ):
247
+ """
248
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
249
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
250
+
251
+ Args:
252
+ scheduler (`SchedulerMixin`):
253
+ The scheduler to get timesteps from.
254
+ num_inference_steps (`int`):
255
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
256
+ must be `None`.
257
+ device (`str` or `torch.device`, *optional*):
258
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
259
+ timesteps (`List[int]`, *optional*):
260
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
261
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
262
+ must be `None`.
263
+
264
+ Returns:
265
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
266
+ second element is the number of inference steps.
267
+ """
268
+ if timesteps is not None:
269
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
270
+ if not accepts_timesteps:
271
+ raise ValueError(
272
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
273
+ f" timestep schedules. Please check whether you are using the correct scheduler."
274
+ )
275
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
276
+ timesteps = scheduler.timesteps
277
+ num_inference_steps = len(timesteps)
278
+ else:
279
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
280
+ timesteps = scheduler.timesteps
281
+ return timesteps, num_inference_steps
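A minimal usage sketch of the helper (scheduler choice and step count are arbitrary):

from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
)
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
print(num_inference_steps)   # 30
print(timesteps[:3])         # highest (noisiest) timesteps first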
282
+
283
+ class StableDiffusionXLSEGPipeline(
284
+ DiffusionPipeline,
285
+ StableDiffusionMixin,
286
+ FromSingleFileMixin,
287
+ StableDiffusionXLLoraLoaderMixin,
288
+ TextualInversionLoaderMixin,
289
+ IPAdapterMixin,
290
+ ):
291
+ r"""
292
+ Pipeline for text-to-image generation using Stable Diffusion XL with Smoothed Energy Guidance (SEG).
293
+
294
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
295
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
296
+
297
+ The pipeline also inherits the following loading methods:
298
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
299
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
300
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
301
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
302
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
303
+
304
+ Args:
305
+ vae ([`AutoencoderKL`]):
306
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
307
+ text_encoder ([`CLIPTextModel`]):
308
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
309
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
310
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
311
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
312
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
313
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
314
+ specifically the
315
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
316
+ variant.
317
+ tokenizer (`CLIPTokenizer`):
318
+ Tokenizer of class
319
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
320
+ tokenizer_2 (`CLIPTokenizer`):
321
+ Second Tokenizer of class
322
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
323
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
324
+ scheduler ([`SchedulerMixin`]):
325
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
326
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
327
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
328
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
329
+ `stabilityai/stable-diffusion-xl-base-1-0`.
330
+ add_watermarker (`bool`, *optional*):
331
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
332
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
333
+ watermarker will be used.
334
+ """
335
+
336
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
337
+ _optional_components = [
338
+ "tokenizer",
339
+ "tokenizer_2",
340
+ "text_encoder",
341
+ "text_encoder_2",
342
+ "image_encoder",
343
+ "feature_extractor",
344
+ ]
345
+ _callback_tensor_inputs = [
346
+ "latents",
347
+ "prompt_embeds",
348
+ "negative_prompt_embeds",
349
+ "add_text_embeds",
350
+ "add_time_ids",
351
+ "negative_pooled_prompt_embeds",
352
+ "negative_add_time_ids",
353
+ ]
354
+
355
+ def __init__(
356
+ self,
357
+ vae: AutoencoderKL,
358
+ text_encoder: CLIPTextModel,
359
+ text_encoder_2: CLIPTextModelWithProjection,
360
+ tokenizer: CLIPTokenizer,
361
+ tokenizer_2: CLIPTokenizer,
362
+ unet: UNet2DConditionModel,
363
+ scheduler: KarrasDiffusionSchedulers,
364
+ image_encoder: CLIPVisionModelWithProjection = None,
365
+ feature_extractor: CLIPImageProcessor = None,
366
+ force_zeros_for_empty_prompt: bool = True,
367
+ add_watermarker: Optional[bool] = None,
368
+ ):
369
+ super().__init__()
370
+
371
+ self.register_modules(
372
+ vae=vae,
373
+ text_encoder=text_encoder,
374
+ text_encoder_2=text_encoder_2,
375
+ tokenizer=tokenizer,
376
+ tokenizer_2=tokenizer_2,
377
+ unet=unet,
378
+ scheduler=scheduler,
379
+ image_encoder=image_encoder,
380
+ feature_extractor=feature_extractor,
381
+ )
382
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
383
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
384
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
385
+
386
+ self.default_sample_size = self.unet.config.sample_size
387
+
388
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
389
+
390
+ if add_watermarker:
391
+ self.watermark = StableDiffusionXLWatermarker()
392
+ else:
393
+ self.watermark = None
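For the stock SDXL base checkpoint the two derived attributes above work out as follows (the numbers come from the loaded configs, so this is only a worked example):

vae_scale_factor = 2 ** (4 - 1)              # SDXL VAE has 4 block_out_channels -> 8
default_pixel_size = 128 * vae_scale_factor  # unet.config.sample_size == 128 -> 1024 px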
394
+
395
+ def encode_prompt(
396
+ self,
397
+ prompt: str,
398
+ prompt_2: Optional[str] = None,
399
+ device: Optional[torch.device] = None,
400
+ num_images_per_prompt: int = 1,
401
+ do_classifier_free_guidance: bool = True,
402
+ negative_prompt: Optional[str] = None,
403
+ negative_prompt_2: Optional[str] = None,
404
+ prompt_embeds: Optional[torch.FloatTensor] = None,
405
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
406
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
407
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
408
+ lora_scale: Optional[float] = None,
409
+ clip_skip: Optional[int] = None,
410
+ ):
411
+ r"""
412
+ Encodes the prompt into text encoder hidden states.
413
+
414
+ Args:
415
+ prompt (`str` or `List[str]`, *optional*):
416
+ prompt to be encoded
417
+ prompt_2 (`str` or `List[str]`, *optional*):
418
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
419
+ used in both text-encoders
420
+ device: (`torch.device`):
421
+ torch device
422
+ num_images_per_prompt (`int`):
423
+ number of images that should be generated per prompt
424
+ do_classifier_free_guidance (`bool`):
425
+ whether to use classifier free guidance or not
426
+ negative_prompt (`str` or `List[str]`, *optional*):
427
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
428
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
429
+ less than `1`).
430
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
431
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
432
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
433
+ prompt_embeds (`torch.FloatTensor`, *optional*):
434
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
435
+ provided, text embeddings will be generated from `prompt` input argument.
436
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
437
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
438
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
439
+ argument.
440
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
441
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
442
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
443
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
444
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
445
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
446
+ input argument.
447
+ lora_scale (`float`, *optional*):
448
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
449
+ clip_skip (`int`, *optional*):
450
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
451
+ the output of the pre-final layer will be used for computing the prompt embeddings.
452
+ """
453
+ device = device or self._execution_device
454
+
455
+ # set lora scale so that monkey patched LoRA
456
+ # function of text encoder can correctly access it
457
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
458
+ self._lora_scale = lora_scale
459
+
460
+ # dynamically adjust the LoRA scale
461
+ if self.text_encoder is not None:
462
+ if not USE_PEFT_BACKEND:
463
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
464
+ else:
465
+ scale_lora_layers(self.text_encoder, lora_scale)
466
+
467
+ if self.text_encoder_2 is not None:
468
+ if not USE_PEFT_BACKEND:
469
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
470
+ else:
471
+ scale_lora_layers(self.text_encoder_2, lora_scale)
472
+
473
+ prompt = [prompt] if isinstance(prompt, str) else prompt
474
+
475
+ if prompt is not None:
476
+ batch_size = len(prompt)
477
+ else:
478
+ batch_size = prompt_embeds.shape[0]
479
+
480
+ # Define tokenizers and text encoders
481
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
482
+ text_encoders = (
483
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
484
+ )
485
+
486
+ if prompt_embeds is None:
487
+ prompt_2 = prompt_2 or prompt
488
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
489
+
490
+ # textual inversion: process multi-vector tokens if necessary
491
+ prompt_embeds_list = []
492
+ prompts = [prompt, prompt_2]
493
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
494
+ if isinstance(self, TextualInversionLoaderMixin):
495
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
496
+
497
+ text_inputs = tokenizer(
498
+ prompt,
499
+ padding="max_length",
500
+ max_length=tokenizer.model_max_length,
501
+ truncation=True,
502
+ return_tensors="pt",
503
+ )
504
+
505
+ text_input_ids = text_inputs.input_ids
506
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
507
+
508
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
509
+ text_input_ids, untruncated_ids
510
+ ):
511
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
512
+ logger.warning(
513
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
514
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
515
+ )
516
+
517
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
518
+
519
+ # We are only ALWAYS interested in the pooled output of the final text encoder
520
+ pooled_prompt_embeds = prompt_embeds[0]
521
+ if clip_skip is None:
522
+ prompt_embeds = prompt_embeds.hidden_states[-2]
523
+ else:
524
+ # "2" because SDXL always indexes from the penultimate layer.
525
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
526
+
527
+ prompt_embeds_list.append(prompt_embeds)
528
+
529
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
530
+
531
+ # get unconditional embeddings for classifier free guidance
532
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
533
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
534
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
535
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
536
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
537
+ negative_prompt = negative_prompt or ""
538
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
539
+
540
+ # normalize str to list
541
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
542
+ negative_prompt_2 = (
543
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
544
+ )
545
+
546
+ uncond_tokens: List[str]
547
+ if prompt is not None and type(prompt) is not type(negative_prompt):
548
+ raise TypeError(
549
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
550
+ f" {type(prompt)}."
551
+ )
552
+ elif batch_size != len(negative_prompt):
553
+ raise ValueError(
554
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
555
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
556
+ " the batch size of `prompt`."
557
+ )
558
+ else:
559
+ uncond_tokens = [negative_prompt, negative_prompt_2]
560
+
561
+ negative_prompt_embeds_list = []
562
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
563
+ if isinstance(self, TextualInversionLoaderMixin):
564
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
565
+
566
+ max_length = prompt_embeds.shape[1]
567
+ uncond_input = tokenizer(
568
+ negative_prompt,
569
+ padding="max_length",
570
+ max_length=max_length,
571
+ truncation=True,
572
+ return_tensors="pt",
573
+ )
574
+
575
+ negative_prompt_embeds = text_encoder(
576
+ uncond_input.input_ids.to(device),
577
+ output_hidden_states=True,
578
+ )
579
+ # We are only ALWAYS interested in the pooled output of the final text encoder
580
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
581
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
582
+
583
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
584
+
585
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
586
+
587
+ if self.text_encoder_2 is not None:
588
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
589
+ else:
590
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
591
+
592
+ bs_embed, seq_len, _ = prompt_embeds.shape
593
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
594
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
595
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
596
+
597
+ if do_classifier_free_guidance:
598
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
599
+ seq_len = negative_prompt_embeds.shape[1]
600
+
601
+ if self.text_encoder_2 is not None:
602
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
603
+ else:
604
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
605
+
606
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
607
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
608
+
609
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
610
+ bs_embed * num_images_per_prompt, -1
611
+ )
612
+ if do_classifier_free_guidance:
613
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
614
+ bs_embed * num_images_per_prompt, -1
615
+ )
616
+
617
+ if self.text_encoder is not None:
618
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
619
+ # Retrieve the original scale by scaling back the LoRA layers
620
+ unscale_lora_layers(self.text_encoder, lora_scale)
621
+
622
+ if self.text_encoder_2 is not None:
623
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
624
+ # Retrieve the original scale by scaling back the LoRA layers
625
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
626
+
627
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
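The method can also be exercised on its own; a sketch assuming `pipe` is an already-loaded StableDiffusionXLSEGPipeline (see the example near the top of the file):

prompt_embeds, neg_embeds, pooled, neg_pooled = pipe.encode_prompt(
    prompt="a watercolor fox",
    device=pipe._execution_device,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="low quality",
)
# The two text encoders are concatenated on the feature dim: 768 + 1280 = 2048 for SDXL.
print(prompt_embeds.shape)   # (1, 77, 2048)
print(pooled.shape)          # (1, 1280)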
628
+
629
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
630
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
631
+ dtype = next(self.image_encoder.parameters()).dtype
632
+
633
+ if not isinstance(image, torch.Tensor):
634
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
635
+
636
+ image = image.to(device=device, dtype=dtype)
637
+ if output_hidden_states:
638
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
639
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
640
+ uncond_image_enc_hidden_states = self.image_encoder(
641
+ torch.zeros_like(image), output_hidden_states=True
642
+ ).hidden_states[-2]
643
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
644
+ num_images_per_prompt, dim=0
645
+ )
646
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
647
+ else:
648
+ image_embeds = self.image_encoder(image).image_embeds
649
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
650
+ uncond_image_embeds = torch.zeros_like(image_embeds)
651
+
652
+ return image_embeds, uncond_image_embeds
653
+
654
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
655
+ def prepare_ip_adapter_image_embeds(
656
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
657
+ ):
658
+ if ip_adapter_image_embeds is None:
659
+ if not isinstance(ip_adapter_image, list):
660
+ ip_adapter_image = [ip_adapter_image]
661
+
662
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
663
+ raise ValueError(
664
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
665
+ )
666
+
667
+ image_embeds = []
668
+ for single_ip_adapter_image, image_proj_layer in zip(
669
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
670
+ ):
671
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
672
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
673
+ single_ip_adapter_image, device, 1, output_hidden_state
674
+ )
675
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
676
+ single_negative_image_embeds = torch.stack(
677
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
678
+ )
679
+
680
+ if do_classifier_free_guidance:
681
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
682
+ single_image_embeds = single_image_embeds.to(device)
683
+
684
+ image_embeds.append(single_image_embeds)
685
+ else:
686
+ repeat_dims = [1]
687
+ image_embeds = []
688
+ for single_image_embeds in ip_adapter_image_embeds:
689
+ if do_classifier_free_guidance:
690
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
691
+ single_image_embeds = single_image_embeds.repeat(
692
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
693
+ )
694
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
695
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
696
+ )
697
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
698
+ else:
699
+ single_image_embeds = single_image_embeds.repeat(
700
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
701
+ )
702
+ image_embeds.append(single_image_embeds)
703
+
704
+ return image_embeds
705
+
706
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
707
+ def prepare_extra_step_kwargs(self, generator, eta):
708
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
709
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
710
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
711
+ # and should be between [0, 1]
712
+
713
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
714
+ extra_step_kwargs = {}
715
+ if accepts_eta:
716
+ extra_step_kwargs["eta"] = eta
717
+
718
+ # check if the scheduler accepts generator
719
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
720
+ if accepts_generator:
721
+ extra_step_kwargs["generator"] = generator
722
+ return extra_step_kwargs
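The same introspection trick works for any callable; a tiny standalone illustration with a hypothetical scheduler-like `step` signature:

import inspect

def step(sample, t, generator=None):  # hypothetical signature, not a diffusers API
    return sample

accepted = set(inspect.signature(step).parameters)
extra = {k: v for k, v in {"eta": 0.0, "generator": None}.items() if k in accepted}
print(extra)  # {'generator': None} -- 'eta' is dropped because `step` does not accept it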
723
+
724
+ def check_inputs(
725
+ self,
726
+ prompt,
727
+ prompt_2,
728
+ height,
729
+ width,
730
+ callback_steps,
731
+ negative_prompt=None,
732
+ negative_prompt_2=None,
733
+ prompt_embeds=None,
734
+ negative_prompt_embeds=None,
735
+ pooled_prompt_embeds=None,
736
+ negative_pooled_prompt_embeds=None,
737
+ ip_adapter_image=None,
738
+ ip_adapter_image_embeds=None,
739
+ callback_on_step_end_tensor_inputs=None,
740
+ ):
741
+ if height % 8 != 0 or width % 8 != 0:
742
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
743
+
744
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
745
+ raise ValueError(
746
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
747
+ f" {type(callback_steps)}."
748
+ )
749
+
750
+ if callback_on_step_end_tensor_inputs is not None and not all(
751
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
752
+ ):
753
+ raise ValueError(
754
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
755
+ )
756
+
757
+ if prompt is not None and prompt_embeds is not None:
758
+ raise ValueError(
759
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
760
+ " only forward one of the two."
761
+ )
762
+ elif prompt_2 is not None and prompt_embeds is not None:
763
+ raise ValueError(
764
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
765
+ " only forward one of the two."
766
+ )
767
+ elif prompt is None and prompt_embeds is None:
768
+ raise ValueError(
769
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
770
+ )
771
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
772
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
773
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
774
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
775
+
776
+ if negative_prompt is not None and negative_prompt_embeds is not None:
777
+ raise ValueError(
778
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
779
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
780
+ )
781
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
782
+ raise ValueError(
783
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
784
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
785
+ )
786
+
787
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
788
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
789
+ raise ValueError(
790
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
791
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
792
+ f" {negative_prompt_embeds.shape}."
793
+ )
794
+
795
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
796
+ raise ValueError(
797
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
798
+ )
799
+
800
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
801
+ raise ValueError(
802
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
803
+ )
804
+
805
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
806
+ raise ValueError(
807
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
808
+ )
809
+
810
+ if ip_adapter_image_embeds is not None:
811
+ if not isinstance(ip_adapter_image_embeds, list):
812
+ raise ValueError(
813
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
814
+ )
815
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
816
+ raise ValueError(
817
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
818
+ )
819
+
820
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
821
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
822
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
823
+ if isinstance(generator, list) and len(generator) != batch_size:
824
+ raise ValueError(
825
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
826
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
827
+ )
828
+
829
+ if latents is None:
830
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
831
+ else:
832
+ latents = latents.to(device)
833
+
834
+ # scale the initial noise by the standard deviation required by the scheduler
835
+ latents = latents * self.scheduler.init_noise_sigma
836
+ return latents
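A quick sketch of the shapes involved for a 1024x1024 SDXL generation (batch size and dtype are arbitrary):

import torch

batch_size, num_channels_latents = 1, 4   # the SDXL UNet works on 4 latent channels
height = width = 1024
vae_scale_factor = 8
shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
latents = torch.randn(shape)              # (1, 4, 128, 128)
# The scheduler then scales this by `init_noise_sigma` (1.0 for DDIM-like schedulers,
# the largest sigma for Karras-style schedulers) before the first denoising step.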
837
+
838
+ def _get_add_time_ids(
839
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
840
+ ):
841
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
842
+
843
+ passed_add_embed_dim = (
844
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
845
+ )
846
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
847
+
848
+ if expected_add_embed_dim != passed_add_embed_dim:
849
+ raise ValueError(
850
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
851
+ )
852
+
853
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
854
+ return add_time_ids
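Concretely, for the base SDXL configuration the consistency check above works out as follows (a worked example, not additional validation logic):

original_size, crops_coords_top_left, target_size = (1024, 1024), (0, 0), (1024, 1024)
add_time_ids = list(original_size + crops_coords_top_left + target_size)  # 6 integers
addition_time_embed_dim = 256        # unet.config.addition_time_embed_dim for SDXL base
text_encoder_projection_dim = 1280   # text_encoder_2.config.projection_dim for SDXL base
passed = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
print(passed)  # 2816 == unet.add_embedding.linear_1.in_features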
855
+
856
+ def upcast_vae(self):
857
+ dtype = self.vae.dtype
858
+ self.vae.to(dtype=torch.float32)
859
+ use_torch_2_0_or_xformers = isinstance(
860
+ self.vae.decoder.mid_block.attentions[0].processor,
861
+ (
862
+ AttnProcessor2_0,
863
+ XFormersAttnProcessor,
864
+ LoRAXFormersAttnProcessor,
865
+ LoRAAttnProcessor2_0,
866
+ FusedAttnProcessor2_0,
867
+ ),
868
+ )
869
+ # if xformers or torch_2_0 is used attention block does not need
870
+ # to be in float32 which can save lots of memory
871
+ if use_torch_2_0_or_xformers:
872
+ self.vae.post_quant_conv.to(dtype)
873
+ self.vae.decoder.conv_in.to(dtype)
874
+ self.vae.decoder.mid_block.to(dtype)
875
+
876
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
877
+ def get_guidance_scale_embedding(
878
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
879
+ ) -> torch.FloatTensor:
880
+ """
881
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
882
+
883
+ Args:
884
+ w (`torch.Tensor`):
885
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
886
+ embedding_dim (`int`, *optional*, defaults to 512):
887
+ Dimension of the embeddings to generate.
888
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
889
+ Data type of the generated embeddings.
890
+
891
+ Returns:
892
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
893
+ """
894
+ assert len(w.shape) == 1
895
+ w = w * 1000.0
896
+
897
+ half_dim = embedding_dim // 2
898
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
899
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
900
+ emb = w.to(dtype)[:, None] * emb[None, :]
901
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
902
+ if embedding_dim % 2 == 1: # zero pad
903
+ emb = torch.nn.functional.pad(emb, (0, 1))
904
+ assert emb.shape == (w.shape[0], embedding_dim)
905
+ return emb
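A standalone check of the sinusoidal embedding helper (it is only used when the UNet was distilled with `time_cond_proj_dim`; `pipe` is assumed to be a loaded pipeline and the sizes are illustrative):

import torch

w = torch.tensor([5.0, 7.5])                                   # one guidance scale per sample
emb = pipe.get_guidance_scale_embedding(w, embedding_dim=256)
print(emb.shape)  # (2, 256)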
906
+
907
+ def pred_z0(self, sample, model_output, timestep):
908
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device)
909
+
910
+ beta_prod_t = 1 - alpha_prod_t
911
+ if self.scheduler.config.prediction_type == "epsilon":
912
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
913
+ elif self.scheduler.config.prediction_type == "sample":
914
+ pred_original_sample = model_output
915
+ elif self.scheduler.config.prediction_type == "v_prediction":
916
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
917
+ # predict V
918
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
919
+ else:
920
+ raise ValueError(
921
+ f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
922
+ " or `v_prediction`"
923
+ )
924
+
925
+ return pred_original_sample
926
+
927
+ def pred_x0(self, latents, noise_pred, t, generator, device, prompt_embeds, output_type):
928
+ pred_z0 = self.pred_z0(latents, noise_pred, t)
929
+ pred_x0 = self.vae.decode(
930
+ pred_z0 / self.vae.config.scaling_factor,
931
+ return_dict=False,
932
+ generator=generator
933
+ )[0]
934
+ #pred_x0, ____ = self.run_safety_checker(pred_x0, device, prompt_embeds.dtype)
935
+ do_denormalize = [True] * pred_x0.shape[0]
936
+ pred_x0 = self.image_processor.postprocess(pred_x0, output_type=output_type, do_denormalize=do_denormalize)
937
+
938
+ return pred_x0
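A small round-trip check of the epsilon branch used by `pred_z0` (pure tensor math, no model involved; the `alphas_cumprod` value is made up):

import torch

alpha_prod_t = torch.tensor(0.5)
z0 = torch.randn(1, 4, 16, 16)                                       # "clean" latent
eps = torch.randn_like(z0)
sample = alpha_prod_t.sqrt() * z0 + (1 - alpha_prod_t).sqrt() * eps  # forward diffusion
recovered = (sample - (1 - alpha_prod_t).sqrt() * eps) / alpha_prod_t.sqrt()
assert torch.allclose(recovered, z0, atol=1e-5)
# `pred_x0` then decodes such an estimate through the VAE to preview an intermediate image.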
939
+
940
+ @property
941
+ def guidance_scale(self):
942
+ return self._guidance_scale
943
+
944
+ @property
945
+ def guidance_rescale(self):
946
+ return self._guidance_rescale
947
+
948
+ @property
949
+ def clip_skip(self):
950
+ return self._clip_skip
951
+
952
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
953
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
954
+ # corresponds to doing no classifier free guidance.
955
+ @property
956
+ def do_classifier_free_guidance(self):
957
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
958
+
959
+ @property
960
+ def cross_attention_kwargs(self):
961
+ return self._cross_attention_kwargs
962
+
963
+ @property
964
+ def denoising_end(self):
965
+ return self._denoising_end
966
+
967
+ @property
968
+ def num_timesteps(self):
969
+ return self._num_timesteps
970
+
971
+ @property
972
+ def interrupt(self):
973
+ return self._interrupt
974
+
975
+ @property
976
+ def seg_scale(self):
977
+ return self._seg_scale
978
+
979
+ @property
980
+ def do_seg(self):
981
+ return self._seg_scale > 0
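These flags imply a latent batch laid out as [uncond, original, perturbed] when both CFG and SEG are active, which is exactly the 3-way `chunk(3)` the attention processor performs. The denoising loop itself is outside this excerpt, but a PAG-style combination of the three UNet predictions would look roughly like this (a sketch, not the file's exact expression):

# noise_pred_uncond, noise_pred_text, noise_pred_perturbed: the three chunks of the
# UNet output, in the same batch order the attention processor assumes.
noise_pred = (
    noise_pred_uncond
    + guidance_scale * (noise_pred_text - noise_pred_uncond)
    + seg_scale * (noise_pred_text - noise_pred_perturbed)
)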
982
+
983
+ @property
984
+ def seg_applied_layers(self):
985
+ return self._seg_applied_layers
986
+
987
+ @property
988
+ def seg_applied_layers_index(self):
989
+ return self._seg_applied_layers_index
990
+
991
+ @torch.no_grad()
992
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
993
+ def __call__(
994
+ self,
995
+ prompt: Union[str, List[str]] = None,
996
+ prompt_2: Optional[Union[str, List[str]]] = None,
997
+ height: Optional[int] = None,
998
+ width: Optional[int] = None,
999
+ num_inference_steps: int = 50,
1000
+ timesteps: List[int] = None,
1001
+ denoising_end: Optional[float] = None,
1002
+ guidance_scale: float = 5.0,
1003
+ seg_scale: float = 3.0,
1004
+ seg_blur_sigma: float = 9999999.0,
1005
+ seg_applied_layers: List[str] = ['mid'], #['down', 'mid', 'up']
1006
+ seg_applied_layers_index: List[str] = None, #['d4', 'd5', 'm0']
1007
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1008
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
1009
+ num_images_per_prompt: Optional[int] = 1,
1010
+ eta: float = 0.0,
1011
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1012
+ latents: Optional[torch.FloatTensor] = None,
1013
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1014
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1015
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1016
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1017
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1018
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
1019
+ output_type: Optional[str] = "pil",
1020
+ return_dict: bool = True,
1021
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1022
+ guidance_rescale: float = 0.0,
1023
+ original_size: Optional[Tuple[int, int]] = None,
1024
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
1025
+ target_size: Optional[Tuple[int, int]] = None,
1026
+ negative_original_size: Optional[Tuple[int, int]] = None,
1027
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
1028
+ negative_target_size: Optional[Tuple[int, int]] = None,
1029
+ clip_skip: Optional[int] = None,
1030
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1031
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1032
+ **kwargs,
1033
+ ):
1034
+ r"""
1035
+ Function invoked when calling the pipeline for generation.
1036
+
1037
+ Args:
1038
+ prompt (`str` or `List[str]`, *optional*):
1039
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
1041
+ prompt_2 (`str` or `List[str]`, *optional*):
1042
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1043
+ used in both text-encoders
1044
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1045
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
1046
+ Anything below 512 pixels won't work well for
1047
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1048
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1049
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1050
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
1051
+ Anything below 512 pixels won't work well for
1052
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1053
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1054
+ num_inference_steps (`int`, *optional*, defaults to 50):
1055
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1056
+ expense of slower inference.
1057
+ timesteps (`List[int]`, *optional*):
1058
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1059
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1060
+ passed will be used. Must be in descending order.
1061
+ denoising_end (`float`, *optional*):
1062
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1063
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
1064
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
1065
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
1066
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1067
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
1068
+ guidance_scale (`float`, *optional*, defaults to 5.0):
1069
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1070
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1071
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1072
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1073
+ usually at the expense of lower image quality.
1074
+ seg_scale (`float`, *optional*, defaults to 3.0):
1075
+ The scale of SEG. Generally fixed to 3.0. Increase it if the result with infinite blur is still not
1076
+ satisfactory.
1077
+ seg_blur_sigma (`float`, *optional*, defaults to 9999999.0):
1078
+ The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
1079
+ infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
1080
+ seg_applied_layers (`List[str]`, *optional*):
1081
+ The layer(s) in which we blur the attention weights. ['mid'] by default.
1082
+ seg_applied_layers_index (`List[str]`, *optional*):
1083
+ The specific layer(s) in which we blur the attention weights. None by default.
1084
+ negative_prompt (`str` or `List[str]`, *optional*):
1085
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1086
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1087
+ less than `1`).
1088
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1089
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1090
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1091
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1092
+ The number of images to generate per prompt.
1093
+ eta (`float`, *optional*, defaults to 0.0):
1094
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1095
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1096
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1097
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1098
+ to make generation deterministic.
1099
+ latents (`torch.FloatTensor`, *optional*):
1100
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1101
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1102
+ tensor will be generated by sampling using the supplied random `generator`.
1103
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1104
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1105
+ provided, text embeddings will be generated from `prompt` input argument.
1106
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1107
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1108
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1109
+ argument.
1110
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1111
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1112
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1113
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1114
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1115
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1116
+ input argument.
1117
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1118
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
1119
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1120
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1121
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1122
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1123
+ output_type (`str`, *optional*, defaults to `"pil"`):
1124
+ The output format of the generated image. Choose between
1125
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1126
+ return_dict (`bool`, *optional*, defaults to `True`):
1127
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
1128
+ of a plain tuple.
1129
+ cross_attention_kwargs (`dict`, *optional*):
1130
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
1131
+ `self.processor` in
1132
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1133
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
1134
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf), where it is defined as `φ` in equation 16.
+ Guidance rescale should help fix overexposure when using zero terminal SNR.
1138
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1139
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1140
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1141
+ explained in section 2.2 of
1142
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1143
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1144
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1145
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1146
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1147
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1148
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1149
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1150
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1151
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1152
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1153
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1154
+ micro-conditioning as explained in section 2.2 of
1155
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1156
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1157
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1158
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
1159
+ micro-conditioning as explained in section 2.2 of
1160
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1161
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1162
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1163
+ To negatively condition the generation process based on a target image resolution. It should be the same
1164
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1165
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1166
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1167
+ callback_on_step_end (`Callable`, *optional*):
1168
+ A function that is called at the end of each denoising step during inference. The function is called
1169
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1170
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1171
+ `callback_on_step_end_tensor_inputs`.
1172
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1173
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1174
+ will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
1175
+ `._callback_tensor_inputs` attribute of your pipeline class.
1176
+
1177
+ Examples:
1178
+
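+ A minimal usage sketch; the checkpoint id and the SEG argument values below are illustrative
+ assumptions rather than values required by this pipeline:
+
+ ```py
+ >>> import torch
+ >>> from pipeline import StableDiffusionXLSEGPipeline  # assumes this file is importable as `pipeline`
+
+ >>> pipe = StableDiffusionXLSEGPipeline.from_pretrained(
+ ...     "stabilityai/stable-diffusion-xl-base-1.0",
+ ...     torch_dtype=torch.float16,
+ ... ).to("cuda")
+
+ >>> image = pipe(
+ ...     prompt="a photo of an astronaut riding a horse",
+ ...     guidance_scale=7.0,
+ ...     seg_scale=3.0,
+ ...     seg_applied_layers=["mid"],
+ ... ).images[0]
+ ```
+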
1179
+ Returns:
1180
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
1181
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1182
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1183
+ """
1184
+
1185
+ callback = kwargs.pop("callback", None)
1186
+ callback_steps = kwargs.pop("callback_steps", None)
1187
+
1188
+ if callback is not None:
1189
+ deprecate(
1190
+ "callback",
1191
+ "1.0.0",
1192
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1193
+ )
1194
+ if callback_steps is not None:
1195
+ deprecate(
1196
+ "callback_steps",
1197
+ "1.0.0",
1198
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1199
+ )
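+ # New-style callbacks: pass `callback_on_step_end` and list the tensors it needs in
+ # `callback_on_step_end_tensor_inputs` (e.g. ["latents"]); the deprecated `callback` /
+ # `callback_steps` pair is still honored at the end of the denoising loop.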
1200
+
1201
+ # 0. Default height and width to unet
1202
+ height = height or self.default_sample_size * self.vae_scale_factor
1203
+ width = width or self.default_sample_size * self.vae_scale_factor
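+ # For the standard SDXL UNet this typically resolves to 1024 x 1024
+ # (default_sample_size 128 * vae_scale_factor 8); other checkpoints may differ.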
1204
+
1205
+ original_size = original_size or (height, width)
1206
+ target_size = target_size or (height, width)
1207
+
1208
+ # 1. Check inputs. Raise error if not correct
1209
+ self.check_inputs(
1210
+ prompt,
1211
+ prompt_2,
1212
+ height,
1213
+ width,
1214
+ callback_steps,
1215
+ negative_prompt,
1216
+ negative_prompt_2,
1217
+ prompt_embeds,
1218
+ negative_prompt_embeds,
1219
+ pooled_prompt_embeds,
1220
+ negative_pooled_prompt_embeds,
1221
+ ip_adapter_image,
1222
+ ip_adapter_image_embeds,
1223
+ callback_on_step_end_tensor_inputs,
1224
+ )
1225
+
1226
+ self._guidance_scale = guidance_scale
1227
+ self._guidance_rescale = guidance_rescale
1228
+ self._clip_skip = clip_skip
1229
+ self._cross_attention_kwargs = cross_attention_kwargs
1230
+ self._denoising_end = denoising_end
1231
+ self._interrupt = False
1232
+
1233
+ self._seg_scale = seg_scale
1234
+ self._seg_applied_layers = seg_applied_layers
1235
+ self._seg_applied_layers_index = seg_applied_layers_index
1236
+
1237
+ # 2. Define call parameters
1238
+ if prompt is not None and isinstance(prompt, str):
1239
+ batch_size = 1
1240
+ elif prompt is not None and isinstance(prompt, list):
1241
+ batch_size = len(prompt)
1242
+ else:
1243
+ batch_size = prompt_embeds.shape[0]
1244
+
1245
+ device = self._execution_device
1246
+
1247
+ # 3. Encode input prompt
1248
+ lora_scale = (
1249
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1250
+ )
1251
+
1252
+ (
1253
+ prompt_embeds,
1254
+ negative_prompt_embeds,
1255
+ pooled_prompt_embeds,
1256
+ negative_pooled_prompt_embeds,
1257
+ ) = self.encode_prompt(
1258
+ prompt=prompt,
1259
+ prompt_2=prompt_2,
1260
+ device=device,
1261
+ num_images_per_prompt=num_images_per_prompt,
1262
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1263
+ negative_prompt=negative_prompt,
1264
+ negative_prompt_2=negative_prompt_2,
1265
+ prompt_embeds=prompt_embeds,
1266
+ negative_prompt_embeds=negative_prompt_embeds,
1267
+ pooled_prompt_embeds=pooled_prompt_embeds,
1268
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1269
+ lora_scale=lora_scale,
1270
+ clip_skip=self.clip_skip,
1271
+ )
1272
+
1273
+ # 4. Prepare timesteps
1274
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
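+ # `retrieve_timesteps` lets callers pass an explicit `timesteps` schedule; when it is None,
+ # the scheduler's own spacing for `num_inference_steps` is used.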
1275
+
1276
+ # 5. Prepare latent variables
1277
+ num_channels_latents = self.unet.config.in_channels
1278
+ latents = self.prepare_latents(
1279
+ batch_size * num_images_per_prompt,
1280
+ num_channels_latents,
1281
+ height,
1282
+ width,
1283
+ prompt_embeds.dtype,
1284
+ device,
1285
+ generator,
1286
+ latents,
1287
+ )
1288
+
1289
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1290
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1291
+
1292
+ # 7. Prepare added time ids & embeddings
1293
+ add_text_embeds = pooled_prompt_embeds
1294
+ if self.text_encoder_2 is None:
1295
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1296
+ else:
1297
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1298
+
1299
+ add_time_ids = self._get_add_time_ids(
1300
+ original_size,
1301
+ crops_coords_top_left,
1302
+ target_size,
1303
+ dtype=prompt_embeds.dtype,
1304
+ text_encoder_projection_dim=text_encoder_projection_dim,
1305
+ )
1306
+ if negative_original_size is not None and negative_target_size is not None:
1307
+ negative_add_time_ids = self._get_add_time_ids(
1308
+ negative_original_size,
1309
+ negative_crops_coords_top_left,
1310
+ negative_target_size,
1311
+ dtype=prompt_embeds.dtype,
1312
+ text_encoder_projection_dim=text_encoder_projection_dim,
1313
+ )
1314
+ else:
1315
+ negative_add_time_ids = add_time_ids
1316
+
1317
+ # classifier-free guidance (CFG) only
1318
+ if self.do_classifier_free_guidance and not self.do_seg:
1319
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1320
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1321
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1322
+ # SEG only
1323
+ elif not self.do_classifier_free_guidance and self.do_seg:
1324
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)
1325
+ add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0)
1326
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
1327
+ # both CFG and SEG
1328
+ elif self.do_classifier_free_guidance and self.do_seg:
1329
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
1330
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds, add_text_embeds], dim=0)
1331
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids, add_time_ids], dim=0)
1332
+
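+ # Resulting batch layout along dim 0 of the conditioning tensors:
+ #   CFG only:  [negative, positive]
+ #   SEG only:  [positive, positive]
+ #   CFG + SEG: [negative, positive, positive]
+ # The last copy is the one perturbed inside the SEG attention layers (see the
+ # `noise_pred.chunk(...)` calls in the denoising loop).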
1333
+ prompt_embeds = prompt_embeds.to(device)
1334
+ add_text_embeds = add_text_embeds.to(device)
1335
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1336
+
1337
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1338
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1339
+ ip_adapter_image,
1340
+ ip_adapter_image_embeds,
1341
+ device,
1342
+ batch_size * num_images_per_prompt,
1343
+ self.do_classifier_free_guidance,
1344
+ )
1345
+
1346
+ # 8. Denoising loop
1347
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1348
+
1349
+ # 8.1 Apply denoising_end
1350
+ if (
1351
+ self.denoising_end is not None
1352
+ and isinstance(self.denoising_end, float)
1353
+ and self.denoising_end > 0
1354
+ and self.denoising_end < 1
1355
+ ):
1356
+ discrete_timestep_cutoff = int(
1357
+ round(
1358
+ self.scheduler.config.num_train_timesteps
1359
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1360
+ )
1361
+ )
1362
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1363
+ timesteps = timesteps[:num_inference_steps]
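+ # Example: with denoising_end=0.8 and 1000 training timesteps, only timesteps >= 200 are kept,
+ # i.e. roughly the first 80% of the inference steps are run.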
1364
+
1365
+ # 9. Optionally get Guidance Scale Embedding
1366
+ timestep_cond = None
1367
+ if self.unet.config.time_cond_proj_dim is not None:
1368
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1369
+ timestep_cond = self.get_guidance_scale_embedding(
1370
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1371
+ ).to(device=device, dtype=latents.dtype)
1372
+
1373
+ # 10. Create down mid and up layer lists
1374
+ if self.do_seg:
1375
+ down_layers = []
1376
+ mid_layers = []
1377
+ up_layers = []
1378
+ for name, module in self.unet.named_modules():
1379
+ if 'attn1' in name and 'to' not in name:
1380
+ layer_type = name.split('.')[0].split('_')[0]
1381
+ if layer_type == 'down':
1382
+ down_layers.append(module)
1383
+ elif layer_type == 'mid':
1384
+ mid_layers.append(module)
1385
+ elif layer_type == 'up':
1386
+ up_layers.append(module)
1387
+ else:
1388
+ raise ValueError(f"Invalid layer type: {layer_type}")
1389
+
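+ # Entries of `seg_applied_layers_index` refer to these lists: e.g. "d4" selects down_layers[4],
+ # "m0" selects mid_layers[0] and "u1" selects up_layers[1].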
1390
+ self._num_timesteps = len(timesteps)
1391
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1392
+ for i, t in enumerate(timesteps):
1393
+ if self.interrupt:
1394
+ continue
1395
+
1396
+ # CFG only: duplicate latents for (negative, positive)
1397
+ if self.do_classifier_free_guidance and not self.do_seg:
1398
+ latent_model_input = torch.cat([latents] * 2)
1399
+ # SEG only: duplicate latents for (original, perturbed)
1400
+ elif not self.do_classifier_free_guidance and self.do_seg:
1401
+ latent_model_input = torch.cat([latents] * 2)
1402
+ # CFG + SEG: triplicate latents for (negative, positive, perturbed)
1403
+ elif self.do_classifier_free_guidance and self.do_seg:
1404
+ latent_model_input = torch.cat([latents] * 3)
1405
+ # no guidance
1406
+ else:
1407
+ latent_model_input = latents
1408
+
1409
+ # swap in the SEG self-attention processor for the selected UNet layers
1410
+ if self.do_seg:
1411
+
1412
+ replace_processor = SEGCFGSelfAttnProcessor(blur_sigma=seg_blur_sigma, do_cfg=self.do_classifier_free_guidance)
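+ # SEGCFGSelfAttnProcessor (defined earlier in this file) is assumed to blur the self-attention
+ # weights of the perturbed branch with `seg_blur_sigma`, leaving the other branch(es) untouched.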
1413
+
1414
+ if self.seg_applied_layers_index:
1415
+ drop_layers = self.seg_applied_layers_index
1416
+ for drop_layer in drop_layers:
1417
+ layer_number = int(drop_layer[1:])
1418
+ try:
1419
+ if drop_layer[0] == 'd':
1420
+ down_layers[layer_number].processor = replace_processor
1421
+ elif drop_layer[0] == 'm':
1422
+ mid_layers[layer_number].processor = replace_processor
1423
+ elif drop_layer[0] == 'u':
1424
+ up_layers[layer_number].processor = replace_processor
1425
+ else:
1426
+ raise ValueError(f"Invalid layer type: {drop_layer[0]}")
1427
+ except IndexError:
1428
+ raise ValueError(
1429
+ f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
1430
+ )
1431
+ elif self.seg_applied_layers:
1432
+ drop_full_layers = self.seg_applied_layers
1433
+ for drop_full_layer in drop_full_layers:
1434
+ try:
1435
+ if drop_full_layer == "down":
1436
+ for down_layer in down_layers:
1437
+ down_layer.processor = replace_processor
1438
+ elif drop_full_layer == "mid":
1439
+ for mid_layer in mid_layers:
1440
+ mid_layer.processor = replace_processor
1441
+ elif drop_full_layer == "up":
1442
+ for up_layer in up_layers:
1443
+ up_layer.processor = replace_processor
1444
+ else:
1445
+ raise ValueError(f"Invalid layer type: {drop_full_layer}")
1446
+ except IndexError:
1447
+ raise ValueError(
1448
+ f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `seg_applied_layers_index`"
1449
+ )
1450
+
1451
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1452
+
1453
+ # predict the noise residual
1454
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1455
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1456
+ added_cond_kwargs["image_embeds"] = image_embeds
1457
+
1458
+ noise_pred = self.unet(
1459
+ latent_model_input,
1460
+ t,
1461
+ encoder_hidden_states=prompt_embeds,
1462
+ timestep_cond=timestep_cond,
1463
+ cross_attention_kwargs=self.cross_attention_kwargs,
1464
+ added_cond_kwargs=added_cond_kwargs,
1465
+ return_dict=False,
1466
+ )[0]
1467
+
1468
+ # perform guidance
1469
+ if self.do_classifier_free_guidance and not self.do_seg:
1470
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1471
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1472
+ # seg
1473
+ elif not self.do_classifier_free_guidance and self.do_seg:
1474
+ noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)
1475
+
1476
+ signal_scale = self.seg_scale
1477
+
1478
+ noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)
1479
+
1480
+ # both
1481
+ elif self.do_classifier_free_guidance and self.do_seg:
1482
+
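+ # Combined update, matching the line below:
+ #   noise_pred = e_text + (guidance_scale - 1) * (e_text - e_uncond) + seg_scale * (e_text - e_perturbed)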
1483
+ noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)
1484
+
1485
+ signal_scale = self.seg_scale
1486
+
1487
+ noise_pred = noise_pred_text + (self.guidance_scale - 1.0) * (noise_pred_text - noise_pred_uncond) + signal_scale * (noise_pred_text - noise_pred_text_perturb)
1488
+
1489
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1490
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1491
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1492
+
1493
+ # compute the previous noisy sample x_t -> x_t-1
1494
+ latents_dtype = latents.dtype
1495
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1496
+ if latents.dtype != latents_dtype:
1497
+ if torch.backends.mps.is_available():
1498
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1499
+ latents = latents.to(latents_dtype)
1500
+
1501
+ if callback_on_step_end is not None:
1502
+ callback_kwargs = {}
1503
+ for k in callback_on_step_end_tensor_inputs:
1504
+ callback_kwargs[k] = locals()[k]
1505
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1506
+
1507
+ latents = callback_outputs.pop("latents", latents)
1508
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1509
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1510
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1511
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1512
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1513
+ )
1514
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1515
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1516
+
1517
+ # call the callback, if provided
1518
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1519
+ progress_bar.update()
1520
+ if callback is not None and i % callback_steps == 0:
1521
+ step_idx = i // getattr(self.scheduler, "order", 1)
1522
+ callback(step_idx, t, latents)
1523
+
1524
+ if XLA_AVAILABLE:
1525
+ xm.mark_step()
1526
+
1527
+ if not output_type == "latent":
1528
+ # make sure the VAE is in float32 mode, as it overflows in float16
1529
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1530
+
1531
+ if needs_upcasting:
1532
+ self.upcast_vae()
1533
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1534
+ elif latents.dtype != self.vae.dtype:
1535
+ if torch.backends.mps.is_available():
1536
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1537
+ self.vae = self.vae.to(latents.dtype)
1538
+
1539
+ # unscale/denormalize the latents
1540
+ # denormalize with the mean and std if available and not None
1541
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1542
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1543
+ if has_latents_mean and has_latents_std:
1544
+ latents_mean = (
1545
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1546
+ )
1547
+ latents_std = (
1548
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1549
+ )
1550
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1551
+ else:
1552
+ latents = latents / self.vae.config.scaling_factor
1553
+
1554
+ image = self.vae.decode(latents, return_dict=False)[0]
1555
+
1556
+ # cast back to fp16 if needed
1557
+ if needs_upcasting:
1558
+ self.vae.to(dtype=torch.float16)
1559
+ else:
1560
+ image = latents
1561
+
1562
+ if not output_type == "latent":
1563
+ # apply watermark if available
1564
+ if self.watermark is not None:
1565
+ image = self.watermark.apply_watermark(image)
1566
+
1567
+ image = self.image_processor.postprocess(image, output_type=output_type)
1568
+
1569
+ # Offload all models
1570
+ self.maybe_free_model_hooks()
1571
+
1572
1575
+ # Restore the original attention processors after SEG was applied, before any return
1576
+ if self.do_seg:
1577
+ if self.seg_applied_layers_index:
1578
+ drop_layers = self.seg_applied_layers_index
1579
+ for drop_layer in drop_layers:
1580
+ layer_number = int(drop_layer[1:])
1581
+ try:
1582
+ if drop_layer[0] == 'd':
1583
+ down_layers[layer_number].processor = AttnProcessor2_0()
1584
+ elif drop_layer[0] == 'm':
1585
+ mid_layers[layer_number].processor = AttnProcessor2_0()
1586
+ elif drop_layer[0] == 'u':
1587
+ up_layers[layer_number].processor = AttnProcessor2_0()
1588
+ else:
1589
+ raise ValueError(f"Invalid layer type: {drop_layer[0]}")
1590
+ except IndexError:
1591
+ raise ValueError(
1592
+ f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
1593
+ )
1594
+ elif self.seg_applied_layers:
1595
+ drop_full_layers = self.seg_applied_layers
1596
+ for drop_full_layer in drop_full_layers:
1597
+ try:
1598
+ if drop_full_layer == "down":
1599
+ for down_layer in down_layers:
1600
+ down_layer.processor = AttnProcessor2_0()
1601
+ elif drop_full_layer == "mid":
1602
+ for mid_layer in mid_layers:
1603
+ mid_layer.processor = AttnProcessor2_0()
1604
+ elif drop_full_layer == "up":
1605
+ for up_layer in up_layers:
1606
+ up_layer.processor = AttnProcessor2_0()
1607
+ else:
1608
+ raise ValueError(f"Invalid layer type: {drop_full_layer}")
1609
+ except IndexError:
1610
+ raise ValueError(
1611
+ f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `seg_applied_layers_index`"
1612
+ )
1613
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)