jyoung105 commited on
Commit
faeab19
1 Parent(s): 0dbfed7

Upload pipeline_stable_diffusion_img2img_pag.py

Browse files
pipeline_stable_diffusion_img2img_pag.py ADDED
@@ -0,0 +1,1557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Smoretalk, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from packaging import version
24
+
25
+ from transformers import (
26
+ CLIPImageProcessor,
27
+ CLIPTextModel,
28
+ CLIPTextModelWithProjection,
29
+ CLIPTokenizer,
30
+ CLIPVisionModelWithProjection
31
+ )
32
+
33
+ from diffusers.configuration_utils import FrozenDict
34
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
35
+ from diffusers.loaders import (
36
+ FromSingleFileMixin,
37
+ IPAdapterMixin,
38
+ LoraLoaderMixin,
39
+ TextualInversionLoaderMixin,
40
+ )
41
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
42
+ from diffusers.models.attention_processor import (
43
+ Attention,
44
+ AttnProcessor2_0,
45
+ FusedAttnProcessor2_0,
46
+ LoRAAttnProcessor2_0,
47
+ LoRAXFormersAttnProcessor,
48
+ XFormersAttnProcessor,
49
+ )
50
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
51
+ from diffusers.schedulers import KarrasDiffusionSchedulers
52
+ from diffusers.utils import (
53
+ PIL_INTERPOLATION,
54
+ USE_PEFT_BACKEND,
55
+ deprecate,
56
+ logging,
57
+ replace_example_docstring,
58
+ scale_lora_layers,
59
+ unscale_lora_layers,
60
+ )
61
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
62
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
63
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
64
+ from diffusers.utils.torch_utils import randn_tensor
65
+
66
+
67
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
68
+
69
+ EXAMPLE_DOC_STRING = """
70
+ Examples:
71
+ ```py
72
+ >>> import torch
73
+ >>> from diffusers import StableDiffusionImg2ImgPipeline
74
+ >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
75
+ >>> pipe = pipe.to("cuda")
76
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
77
+ >>> image = pipe(prompt).images[0]
78
+ ```
79
+ """
80
+
81
+
82
+ class PAGIdentitySelfAttnProcessor:
83
+ r"""
84
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
85
+ """
86
+
87
+ def __init__(self):
88
+ if not hasattr(F, "scaled_dot_product_attention"):
89
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
90
+
91
+ def __call__(
92
+ self,
93
+ attn: Attention,
94
+ hidden_states: torch.FloatTensor,
95
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
96
+ attention_mask: Optional[torch.FloatTensor] = None,
97
+ temb: Optional[torch.FloatTensor] = None,
98
+ *args,
99
+ **kwargs,
100
+ ) -> torch.FloatTensor:
101
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
102
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
103
+ deprecate("scale", "1.0.0", deprecation_message)
104
+
105
+ residual = hidden_states
106
+ if attn.spatial_norm is not None:
107
+ hidden_states = attn.spatial_norm(hidden_states, temb)
108
+
109
+ input_ndim = hidden_states.ndim
110
+ if input_ndim == 4:
111
+ batch_size, channel, height, width = hidden_states.shape
112
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
113
+
114
+ # chunk
115
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
116
+
117
+ # original path
118
+ batch_size, sequence_length, _ = hidden_states_org.shape
119
+
120
+ if attention_mask is not None:
121
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
122
+ # scaled_dot_product_attention expects attention_mask shape to be
123
+ # (batch, heads, source_length, target_length)
124
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
125
+
126
+ if attn.group_norm is not None:
127
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
128
+
129
+ query = attn.to_q(hidden_states_org)
130
+ key = attn.to_k(hidden_states_org)
131
+ value = attn.to_v(hidden_states_org)
132
+
133
+ inner_dim = key.shape[-1]
134
+ head_dim = inner_dim // attn.heads
135
+
136
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
137
+
138
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
139
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
140
+
141
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
142
+ # TODO: add support for attn.scale when we move to Torch 2.1
143
+ hidden_states_org = F.scaled_dot_product_attention(
144
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
145
+ )
146
+
147
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
148
+ hidden_states_org = hidden_states_org.to(query.dtype)
149
+
150
+ # linear proj
151
+ hidden_states_org = attn.to_out[0](hidden_states_org)
152
+ # dropout
153
+ hidden_states_org = attn.to_out[1](hidden_states_org)
154
+
155
+ if input_ndim == 4:
156
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
157
+
158
+ # perturbed path (identity attention)
159
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
160
+
161
+ if attention_mask is not None:
162
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
163
+ # scaled_dot_product_attention expects attention_mask shape to be
164
+ # (batch, heads, source_length, target_length)
165
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
166
+
167
+ if attn.group_norm is not None:
168
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
169
+
170
+ value = attn.to_v(hidden_states_ptb)
171
+
172
+ hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
173
+ # hidden_states_ptb = value
174
+
175
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
176
+
177
+ # linear proj
178
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
179
+ # dropout
180
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
181
+
182
+ if input_ndim == 4:
183
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
184
+
185
+ # cat
186
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
187
+
188
+ if attn.residual_connection:
189
+ hidden_states = hidden_states + residual
190
+
191
+ hidden_states = hidden_states / attn.rescale_output_factor
192
+
193
+ return hidden_states
194
+
195
+
196
+ class PAGCFGIdentitySelfAttnProcessor:
197
+ r"""
198
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
199
+ """
200
+
201
+ def __init__(self):
202
+ if not hasattr(F, "scaled_dot_product_attention"):
203
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
204
+
205
+ def __call__(
206
+ self,
207
+ attn: Attention,
208
+ hidden_states: torch.FloatTensor,
209
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
210
+ attention_mask: Optional[torch.FloatTensor] = None,
211
+ temb: Optional[torch.FloatTensor] = None,
212
+ *args,
213
+ **kwargs,
214
+ ) -> torch.FloatTensor:
215
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
216
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
217
+ deprecate("scale", "1.0.0", deprecation_message)
218
+
219
+ residual = hidden_states
220
+ if attn.spatial_norm is not None:
221
+ hidden_states = attn.spatial_norm(hidden_states, temb)
222
+
223
+ input_ndim = hidden_states.ndim
224
+ if input_ndim == 4:
225
+ batch_size, channel, height, width = hidden_states.shape
226
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
227
+
228
+ # chunk
229
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
230
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
231
+
232
+ # original path
233
+ batch_size, sequence_length, _ = hidden_states_org.shape
234
+
235
+ if attention_mask is not None:
236
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
237
+ # scaled_dot_product_attention expects attention_mask shape to be
238
+ # (batch, heads, source_length, target_length)
239
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
240
+
241
+ if attn.group_norm is not None:
242
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
243
+
244
+ query = attn.to_q(hidden_states_org)
245
+ key = attn.to_k(hidden_states_org)
246
+ value = attn.to_v(hidden_states_org)
247
+
248
+ inner_dim = key.shape[-1]
249
+ head_dim = inner_dim // attn.heads
250
+
251
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
252
+
253
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
254
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
255
+
256
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
257
+ # TODO: add support for attn.scale when we move to Torch 2.1
258
+ hidden_states_org = F.scaled_dot_product_attention(
259
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
260
+ )
261
+
262
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
263
+ hidden_states_org = hidden_states_org.to(query.dtype)
264
+
265
+ # linear proj
266
+ hidden_states_org = attn.to_out[0](hidden_states_org)
267
+ # dropout
268
+ hidden_states_org = attn.to_out[1](hidden_states_org)
269
+
270
+ if input_ndim == 4:
271
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
272
+
273
+ # perturbed path (identity attention)
274
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
275
+
276
+ if attention_mask is not None:
277
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
278
+ # scaled_dot_product_attention expects attention_mask shape to be
279
+ # (batch, heads, source_length, target_length)
280
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
281
+
282
+ if attn.group_norm is not None:
283
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
284
+
285
+ value = attn.to_v(hidden_states_ptb)
286
+ hidden_states_ptb = value
287
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
288
+
289
+ # linear proj
290
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
291
+ # dropout
292
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
293
+
294
+ if input_ndim == 4:
295
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
296
+
297
+ # cat
298
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
299
+
300
+ if attn.residual_connection:
301
+ hidden_states = hidden_states + residual
302
+
303
+ hidden_states = hidden_states / attn.rescale_output_factor
304
+
305
+ return hidden_states
306
+
307
+
308
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
309
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
310
+ """
311
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
312
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
313
+ """
314
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
315
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
316
+ # rescale the results from guidance (fixes overexposure)
317
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
318
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
319
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
320
+ return noise_cfg
321
+
322
+
323
+ def retrieve_latents(
324
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
325
+ ):
326
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
327
+ return encoder_output.latent_dist.sample(generator)
328
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
329
+ return encoder_output.latent_dist.mode()
330
+ elif hasattr(encoder_output, "latents"):
331
+ return encoder_output.latents
332
+ else:
333
+ raise AttributeError("Could not access latents of provided encoder_output")
334
+
335
+
336
+ def preprocess(image):
337
+ deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
338
+ deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
339
+ if isinstance(image, torch.Tensor):
340
+ return image
341
+ elif isinstance(image, PIL.Image.Image):
342
+ image = [image]
343
+
344
+ if isinstance(image[0], PIL.Image.Image):
345
+ w, h = image[0].size
346
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
347
+
348
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
349
+ image = np.concatenate(image, axis=0)
350
+ image = np.array(image).astype(np.float32) / 255.0
351
+ image = image.transpose(0, 3, 1, 2)
352
+ image = 2.0 * image - 1.0
353
+ image = torch.from_numpy(image)
354
+ elif isinstance(image[0], torch.Tensor):
355
+ image = torch.cat(image, dim=0)
356
+ return image
357
+
358
+
359
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
360
+ def retrieve_timesteps(
361
+ scheduler,
362
+ num_inference_steps: Optional[int] = None,
363
+ device: Optional[Union[str, torch.device]] = None,
364
+ timesteps: Optional[List[int]] = None,
365
+ **kwargs,
366
+ ):
367
+ """
368
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
369
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
370
+
371
+ Args:
372
+ scheduler (`SchedulerMixin`):
373
+ The scheduler to get timesteps from.
374
+ num_inference_steps (`int`):
375
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
376
+ must be `None`.
377
+ device (`str` or `torch.device`, *optional*):
378
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
379
+ timesteps (`List[int]`, *optional*):
380
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
381
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
382
+ must be `None`.
383
+
384
+ Returns:
385
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
386
+ second element is the number of inference steps.
387
+ """
388
+ if timesteps is not None:
389
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
390
+ if not accepts_timesteps:
391
+ raise ValueError(
392
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
393
+ f" timestep schedules. Please check whether you are using the correct scheduler."
394
+ )
395
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
396
+ timesteps = scheduler.timesteps
397
+ num_inference_steps = len(timesteps)
398
+ else:
399
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
400
+ timesteps = scheduler.timesteps
401
+ return timesteps, num_inference_steps
402
+
403
+
404
+ class StableDiffusionImg2ImgPipeline(
405
+ DiffusionPipeline,
406
+ StableDiffusionMixin,
407
+ FromSingleFileMixin,
408
+ TextualInversionLoaderMixin,
409
+ LoraLoaderMixin,
410
+ IPAdapterMixin,
411
+ ):
412
+ r"""
413
+ Pipeline for text-guided image-to-image generation using Stable Diffusion.
414
+
415
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
416
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
417
+
418
+ The pipeline also inherits the following loading methods:
419
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
420
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
421
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
422
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
423
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
424
+
425
+ Args:
426
+ vae ([`AutoencoderKL`]):
427
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
428
+ text_encoder ([`~transformers.CLIPTextModel`]):
429
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
430
+ tokenizer ([`~transformers.CLIPTokenizer`]):
431
+ A `CLIPTokenizer` to tokenize text.
432
+ unet ([`UNet2DConditionModel`]):
433
+ A `UNet2DConditionModel` to denoise the encoded image latents.
434
+ scheduler ([`SchedulerMixin`]):
435
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
436
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
437
+ safety_checker ([`StableDiffusionSafetyChecker`]):
438
+ Classification module that estimates whether generated images could be considered offensive or harmful.
439
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
440
+ about a model's potential harms.
441
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
442
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
443
+ """
444
+
445
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
446
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
447
+ _exclude_from_cpu_offload = ["safety_checker"]
448
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
449
+
450
+ def __init__(
451
+ self,
452
+ vae: AutoencoderKL,
453
+ text_encoder: CLIPTextModel,
454
+ tokenizer: CLIPTokenizer,
455
+ unet: UNet2DConditionModel,
456
+ scheduler: KarrasDiffusionSchedulers,
457
+ safety_checker: StableDiffusionSafetyChecker,
458
+ feature_extractor: CLIPImageProcessor,
459
+ image_encoder: CLIPVisionModelWithProjection = None,
460
+ requires_safety_checker: bool = True,
461
+ ):
462
+ super().__init__()
463
+
464
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
465
+ deprecation_message = (
466
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
467
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
468
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
469
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
470
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
471
+ " file"
472
+ )
473
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
474
+ new_config = dict(scheduler.config)
475
+ new_config["steps_offset"] = 1
476
+ scheduler._internal_dict = FrozenDict(new_config)
477
+
478
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
479
+ deprecation_message = (
480
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
481
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
482
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
483
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
484
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
485
+ )
486
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
487
+ new_config = dict(scheduler.config)
488
+ new_config["clip_sample"] = False
489
+ scheduler._internal_dict = FrozenDict(new_config)
490
+
491
+ if safety_checker is None and requires_safety_checker:
492
+ logger.warning(
493
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
494
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
495
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
496
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
497
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
498
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
499
+ )
500
+
501
+ if safety_checker is not None and feature_extractor is None:
502
+ raise ValueError(
503
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
504
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
505
+ )
506
+
507
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
508
+ version.parse(unet.config._diffusers_version).base_version
509
+ ) < version.parse("0.9.0.dev0")
510
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
511
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
512
+ deprecation_message = (
513
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
514
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
515
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
516
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
517
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
518
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
519
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
520
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
521
+ " the `unet/config.json` file"
522
+ )
523
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
524
+ new_config = dict(unet.config)
525
+ new_config["sample_size"] = 64
526
+ unet._internal_dict = FrozenDict(new_config)
527
+
528
+ self.register_modules(
529
+ vae=vae,
530
+ text_encoder=text_encoder,
531
+ tokenizer=tokenizer,
532
+ unet=unet,
533
+ scheduler=scheduler,
534
+ safety_checker=safety_checker,
535
+ feature_extractor=feature_extractor,
536
+ image_encoder=image_encoder,
537
+ )
538
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
539
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
540
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
541
+
542
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
543
+ def _encode_prompt(
544
+ self,
545
+ prompt,
546
+ device,
547
+ num_images_per_prompt,
548
+ do_classifier_free_guidance,
549
+ negative_prompt=None,
550
+ prompt_embeds: Optional[torch.FloatTensor] = None,
551
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
552
+ lora_scale: Optional[float] = None,
553
+ **kwargs,
554
+ ):
555
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
556
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
557
+
558
+ prompt_embeds_tuple = self.encode_prompt(
559
+ prompt=prompt,
560
+ device=device,
561
+ num_images_per_prompt=num_images_per_prompt,
562
+ do_classifier_free_guidance=do_classifier_free_guidance,
563
+ negative_prompt=negative_prompt,
564
+ prompt_embeds=prompt_embeds,
565
+ negative_prompt_embeds=negative_prompt_embeds,
566
+ lora_scale=lora_scale,
567
+ **kwargs,
568
+ )
569
+
570
+ # concatenate for backwards comp
571
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
572
+
573
+ return prompt_embeds
574
+
575
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
576
+ def encode_prompt(
577
+ self,
578
+ prompt,
579
+ device,
580
+ num_images_per_prompt,
581
+ do_classifier_free_guidance,
582
+ negative_prompt=None,
583
+ prompt_embeds: Optional[torch.FloatTensor] = None,
584
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
585
+ lora_scale: Optional[float] = None,
586
+ clip_skip: Optional[int] = None,
587
+ ):
588
+ r"""
589
+ Encodes the prompt into text encoder hidden states.
590
+
591
+ Args:
592
+ prompt (`str` or `List[str]`, *optional*):
593
+ prompt to be encoded
594
+ device: (`torch.device`):
595
+ torch device
596
+ num_images_per_prompt (`int`):
597
+ number of images that should be generated per prompt
598
+ do_classifier_free_guidance (`bool`):
599
+ whether to use classifier free guidance or not
600
+ negative_prompt (`str` or `List[str]`, *optional*):
601
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
602
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
603
+ less than `1`).
604
+ prompt_embeds (`torch.FloatTensor`, *optional*):
605
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
606
+ provided, text embeddings will be generated from `prompt` input argument.
607
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
608
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
609
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
610
+ argument.
611
+ lora_scale (`float`, *optional*):
612
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
613
+ clip_skip (`int`, *optional*):
614
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
615
+ the output of the pre-final layer will be used for computing the prompt embeddings.
616
+ """
617
+ # set lora scale so that monkey patched LoRA
618
+ # function of text encoder can correctly access it
619
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
620
+ self._lora_scale = lora_scale
621
+
622
+ # dynamically adjust the LoRA scale
623
+ if not USE_PEFT_BACKEND:
624
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
625
+ else:
626
+ scale_lora_layers(self.text_encoder, lora_scale)
627
+
628
+ if prompt is not None and isinstance(prompt, str):
629
+ batch_size = 1
630
+ elif prompt is not None and isinstance(prompt, list):
631
+ batch_size = len(prompt)
632
+ else:
633
+ batch_size = prompt_embeds.shape[0]
634
+
635
+ if prompt_embeds is None:
636
+ # textual inversion: process multi-vector tokens if necessary
637
+ if isinstance(self, TextualInversionLoaderMixin):
638
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
639
+
640
+ text_inputs = self.tokenizer(
641
+ prompt,
642
+ padding="max_length",
643
+ max_length=self.tokenizer.model_max_length,
644
+ truncation=True,
645
+ return_tensors="pt",
646
+ )
647
+ text_input_ids = text_inputs.input_ids
648
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
649
+
650
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
651
+ text_input_ids, untruncated_ids
652
+ ):
653
+ removed_text = self.tokenizer.batch_decode(
654
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
655
+ )
656
+ logger.warning(
657
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
658
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
659
+ )
660
+
661
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
662
+ attention_mask = text_inputs.attention_mask.to(device)
663
+ else:
664
+ attention_mask = None
665
+
666
+ if clip_skip is None:
667
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
668
+ prompt_embeds = prompt_embeds[0]
669
+ else:
670
+ prompt_embeds = self.text_encoder(
671
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
672
+ )
673
+ # Access the `hidden_states` first, that contains a tuple of
674
+ # all the hidden states from the encoder layers. Then index into
675
+ # the tuple to access the hidden states from the desired layer.
676
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
677
+ # We also need to apply the final LayerNorm here to not mess with the
678
+ # representations. The `last_hidden_states` that we typically use for
679
+ # obtaining the final prompt representations passes through the LayerNorm
680
+ # layer.
681
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
682
+
683
+ if self.text_encoder is not None:
684
+ prompt_embeds_dtype = self.text_encoder.dtype
685
+ elif self.unet is not None:
686
+ prompt_embeds_dtype = self.unet.dtype
687
+ else:
688
+ prompt_embeds_dtype = prompt_embeds.dtype
689
+
690
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
691
+
692
+ bs_embed, seq_len, _ = prompt_embeds.shape
693
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
694
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
695
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
696
+
697
+ # get unconditional embeddings for classifier free guidance
698
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
699
+ uncond_tokens: List[str]
700
+ if negative_prompt is None:
701
+ uncond_tokens = [""] * batch_size
702
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
703
+ raise TypeError(
704
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
705
+ f" {type(prompt)}."
706
+ )
707
+ elif isinstance(negative_prompt, str):
708
+ uncond_tokens = [negative_prompt]
709
+ elif batch_size != len(negative_prompt):
710
+ raise ValueError(
711
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
712
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
713
+ " the batch size of `prompt`."
714
+ )
715
+ else:
716
+ uncond_tokens = negative_prompt
717
+
718
+ # textual inversion: process multi-vector tokens if necessary
719
+ if isinstance(self, TextualInversionLoaderMixin):
720
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
721
+
722
+ max_length = prompt_embeds.shape[1]
723
+ uncond_input = self.tokenizer(
724
+ uncond_tokens,
725
+ padding="max_length",
726
+ max_length=max_length,
727
+ truncation=True,
728
+ return_tensors="pt",
729
+ )
730
+
731
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
732
+ attention_mask = uncond_input.attention_mask.to(device)
733
+ else:
734
+ attention_mask = None
735
+
736
+ negative_prompt_embeds = self.text_encoder(
737
+ uncond_input.input_ids.to(device),
738
+ attention_mask=attention_mask,
739
+ )
740
+ negative_prompt_embeds = negative_prompt_embeds[0]
741
+
742
+ if do_classifier_free_guidance:
743
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
744
+ seq_len = negative_prompt_embeds.shape[1]
745
+
746
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
747
+
748
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
749
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
750
+
751
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
752
+ # Retrieve the original scale by scaling back the LoRA layers
753
+ unscale_lora_layers(self.text_encoder, lora_scale)
754
+
755
+ return prompt_embeds, negative_prompt_embeds
756
+
757
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
758
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
759
+ dtype = next(self.image_encoder.parameters()).dtype
760
+
761
+ if not isinstance(image, torch.Tensor):
762
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
763
+
764
+ image = image.to(device=device, dtype=dtype)
765
+ if output_hidden_states:
766
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
767
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
768
+ uncond_image_enc_hidden_states = self.image_encoder(
769
+ torch.zeros_like(image), output_hidden_states=True
770
+ ).hidden_states[-2]
771
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
772
+ num_images_per_prompt, dim=0
773
+ )
774
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
775
+ else:
776
+ image_embeds = self.image_encoder(image).image_embeds
777
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
778
+ uncond_image_embeds = torch.zeros_like(image_embeds)
779
+
780
+ return image_embeds, uncond_image_embeds
781
+
782
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
783
+ def prepare_ip_adapter_image_embeds(
784
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
785
+ ):
786
+ if ip_adapter_image_embeds is None:
787
+ if not isinstance(ip_adapter_image, list):
788
+ ip_adapter_image = [ip_adapter_image]
789
+
790
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
791
+ raise ValueError(
792
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
793
+ )
794
+
795
+ image_embeds = []
796
+ for single_ip_adapter_image, image_proj_layer in zip(
797
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
798
+ ):
799
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
800
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
801
+ single_ip_adapter_image, device, 1, output_hidden_state
802
+ )
803
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
804
+ single_negative_image_embeds = torch.stack(
805
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
806
+ )
807
+
808
+ if do_classifier_free_guidance:
809
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
810
+ single_image_embeds = single_image_embeds.to(device)
811
+
812
+ image_embeds.append(single_image_embeds)
813
+ else:
814
+ repeat_dims = [1]
815
+ image_embeds = []
816
+ for single_image_embeds in ip_adapter_image_embeds:
817
+ if do_classifier_free_guidance:
818
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
819
+ single_image_embeds = single_image_embeds.repeat(
820
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
821
+ )
822
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
823
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
824
+ )
825
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
826
+ else:
827
+ single_image_embeds = single_image_embeds.repeat(
828
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
829
+ )
830
+ image_embeds.append(single_image_embeds)
831
+
832
+ return image_embeds
833
+
834
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
835
+ def run_safety_checker(self, image, device, dtype):
836
+ if self.safety_checker is None:
837
+ has_nsfw_concept = None
838
+ else:
839
+ if torch.is_tensor(image):
840
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
841
+ else:
842
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
843
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
844
+ image, has_nsfw_concept = self.safety_checker(
845
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
846
+ )
847
+ return image, has_nsfw_concept
848
+
849
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
850
+ def decode_latents(self, latents):
851
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
852
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
853
+
854
+ latents = 1 / self.vae.config.scaling_factor * latents
855
+ image = self.vae.decode(latents, return_dict=False)[0]
856
+ image = (image / 2 + 0.5).clamp(0, 1)
857
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
858
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
859
+ return image
860
+
861
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
862
+ def prepare_extra_step_kwargs(self, generator, eta):
863
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
864
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
865
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
866
+ # and should be between [0, 1]
867
+
868
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
869
+ extra_step_kwargs = {}
870
+ if accepts_eta:
871
+ extra_step_kwargs["eta"] = eta
872
+
873
+ # check if the scheduler accepts generator
874
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
875
+ if accepts_generator:
876
+ extra_step_kwargs["generator"] = generator
877
+ return extra_step_kwargs
878
+
879
+ def check_inputs(
880
+ self,
881
+ prompt,
882
+ strength,
883
+ callback_steps,
884
+ negative_prompt=None,
885
+ prompt_embeds=None,
886
+ negative_prompt_embeds=None,
887
+ ip_adapter_image=None,
888
+ ip_adapter_image_embeds=None,
889
+ callback_on_step_end_tensor_inputs=None,
890
+ ):
891
+ if strength < 0 or strength > 1:
892
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
893
+
894
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
895
+ raise ValueError(
896
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
897
+ f" {type(callback_steps)}."
898
+ )
899
+
900
+ if callback_on_step_end_tensor_inputs is not None and not all(
901
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
902
+ ):
903
+ raise ValueError(
904
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
905
+ )
906
+
907
+ if prompt is not None and prompt_embeds is not None:
908
+ raise ValueError(
909
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
910
+ " only forward one of the two."
911
+ )
912
+ elif prompt is None and prompt_embeds is None:
913
+ raise ValueError(
914
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
915
+ )
916
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
917
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
918
+
919
+ if negative_prompt is not None and negative_prompt_embeds is not None:
920
+ raise ValueError(
921
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
922
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
923
+ )
924
+
925
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
926
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
927
+ raise ValueError(
928
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
929
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
930
+ f" {negative_prompt_embeds.shape}."
931
+ )
932
+
933
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
934
+ raise ValueError(
935
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
936
+ )
937
+
938
+ if ip_adapter_image_embeds is not None:
939
+ if not isinstance(ip_adapter_image_embeds, list):
940
+ raise ValueError(
941
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
942
+ )
943
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
944
+ raise ValueError(
945
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
946
+ )
947
+
948
+ def get_timesteps(self, num_inference_steps, strength, device):
949
+ # get the original timestep using init_timestep
950
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
951
+
952
+ t_start = max(num_inference_steps - init_timestep, 0)
953
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
954
+ if hasattr(self.scheduler, "set_begin_index"):
955
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
956
+
957
+ return timesteps, num_inference_steps - t_start
958
+
959
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
960
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
961
+ raise ValueError(
962
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
963
+ )
964
+
965
+ image = image.to(device=device, dtype=dtype)
966
+
967
+ batch_size = batch_size * num_images_per_prompt
968
+
969
+ if image.shape[1] == 4:
970
+ init_latents = image
971
+
972
+ else:
973
+ if isinstance(generator, list) and len(generator) != batch_size:
974
+ raise ValueError(
975
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
976
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
977
+ )
978
+
979
+ elif isinstance(generator, list):
980
+ init_latents = [
981
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
982
+ for i in range(batch_size)
983
+ ]
984
+ init_latents = torch.cat(init_latents, dim=0)
985
+ else:
986
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
987
+
988
+ init_latents = self.vae.config.scaling_factor * init_latents
989
+
990
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
991
+ # expand init_latents for batch_size
992
+ deprecation_message = (
993
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
994
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
995
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
996
+ " your script to pass as many initial images as text prompts to suppress this warning."
997
+ )
998
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
999
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
1000
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
1001
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
1002
+ raise ValueError(
1003
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
1004
+ )
1005
+ else:
1006
+ init_latents = torch.cat([init_latents], dim=0)
1007
+
1008
+ shape = init_latents.shape
1009
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
1010
+
1011
+ # get latents
1012
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
1013
+ latents = init_latents
1014
+
1015
+ return latents
1016
+
1017
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
1018
+ def get_guidance_scale_embedding(
1019
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
1020
+ ) -> torch.FloatTensor:
1021
+ """
1022
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
1023
+
1024
+ Args:
1025
+ timesteps (`torch.Tensor`):
1026
+ generate embedding vectors at these timesteps
1027
+ embedding_dim (`int`, *optional*, defaults to 512):
1028
+ dimension of the embeddings to generate
1029
+ dtype:
1030
+ data type of the generated embeddings
1031
+
1032
+ Returns:
1033
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
1034
+ """
1035
+ assert len(w.shape) == 1
1036
+ w = w * 1000.0
1037
+
1038
+ half_dim = embedding_dim // 2
1039
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
1040
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
1041
+ emb = w.to(dtype)[:, None] * emb[None, :]
1042
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
1043
+ if embedding_dim % 2 == 1: # zero pad
1044
+ emb = torch.nn.functional.pad(emb, (0, 1))
1045
+ assert emb.shape == (w.shape[0], embedding_dim)
1046
+ return emb
1047
+
1048
+ def pred_z0(self, sample, model_output, timestep):
1049
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device)
1050
+
1051
+ beta_prod_t = 1 - alpha_prod_t
1052
+ if self.scheduler.config.prediction_type == "epsilon":
1053
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
1054
+ elif self.scheduler.config.prediction_type == "sample":
1055
+ pred_original_sample = model_output
1056
+ elif self.scheduler.config.prediction_type == "v_prediction":
1057
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
1058
+ # predict V
1059
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
1060
+ else:
1061
+ raise ValueError(
1062
+ f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
1063
+ " or `v_prediction`"
1064
+ )
1065
+
1066
+ return pred_original_sample
1067
+
1068
+ def pred_x0(self, latents, noise_pred, t, generator, device, prompt_embeds, output_type):
1069
+ pred_z0 = self.pred_z0(latents, noise_pred, t)
1070
+ pred_x0 = self.vae.decode(pred_z0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
1071
+ pred_x0, ____ = self.run_safety_checker(pred_x0, device, prompt_embeds.dtype)
1072
+ do_denormalize = [True] * pred_x0.shape[0]
1073
+ pred_x0 = self.image_processor.postprocess(pred_x0, output_type=output_type, do_denormalize=do_denormalize)
1074
+
1075
+ return pred_x0
1076
+
1077
+ @property
1078
+ def guidance_scale(self):
1079
+ return self._guidance_scale
1080
+
1081
+ @property
1082
+ def guidance_rescale(self):
1083
+ return self._guidance_rescale
1084
+
1085
+ @property
1086
+ def clip_skip(self):
1087
+ return self._clip_skip
1088
+
1089
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1090
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1091
+ # corresponds to doing no classifier free guidance.
1092
+ @property
1093
+ def do_classifier_free_guidance(self):
1094
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1095
+
1096
+ @property
1097
+ def cross_attention_kwargs(self):
1098
+ return self._cross_attention_kwargs
1099
+
1100
+ @property
1101
+ def num_timesteps(self):
1102
+ return self._num_timesteps
1103
+
1104
+ @property
1105
+ def interrupt(self):
1106
+ return self._interrupt
1107
+
1108
+ @property
1109
+ def pag_scale(self):
1110
+ return self._pag_scale
1111
+
1112
+ @property
1113
+ def do_adversarial_guidance(self):
1114
+ return self._pag_scale > 0
1115
+
1116
+ @property
1117
+ def pag_adaptive_scaling(self):
1118
+ return self._pag_adaptive_scaling
1119
+
1120
+ @property
1121
+ def do_pag_adaptive_scaling(self):
1122
+ return self._pag_adaptive_scaling > 0
1123
+
1124
+ @property
1125
+ def pag_drop_rate(self):
1126
+ return self._pag_drop_rate
1127
+
1128
+ @property
1129
+ def pag_applied_layers(self):
1130
+ return self._pag_applied_layers
1131
+
1132
+ @property
1133
+ def pag_applied_layers_index(self):
1134
+ return self._pag_applied_layers_index
1135
+
1136
+ @torch.no_grad()
1137
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1138
+ def __call__(
1139
+ self,
1140
+ prompt: Union[str, List[str]] = None,
1141
+ image: PipelineImageInput = None,
1142
+ strength: float = 0.8,
1143
+ num_inference_steps: int = 50,
1144
+ timesteps: List[int] = None,
1145
+ guidance_scale: float = 7.5,
1146
+ pag_scale: float = 0.0,
1147
+ pag_adaptive_scaling: float = 0.0,
1148
+ pag_drop_rate: float = 0.5,
1149
+ pag_applied_layers: List[str] = ["down"], # ['down', 'mid', 'up']
1150
+ pag_applied_layers_index: List[str] = ["d4"], # ['d4', 'd5', 'm0']
1151
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1152
+ num_images_per_prompt: Optional[int] = 1,
1153
+ eta: float = 0.0,
1154
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1155
+ # latents: Optional[torch.FloatTensor] = None,
1156
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1157
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1158
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1159
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
1160
+ output_type: Optional[str] = "pil",
1161
+ return_dict: bool = True,
1162
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1163
+ guidance_rescale: float = 0.0,
1164
+ clip_skip: Optional[int] = None,
1165
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1166
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1167
+ **kwargs,
1168
+ ):
1169
+ r"""
1170
+ The call function to the pipeline for generation.
1171
+
1172
+ Args:
1173
+ prompt (`str` or `List[str]`, *optional*):
1174
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
1175
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
1176
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
1177
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
1178
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
1179
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
1180
+ latents as `image`, but if passing latents directly it is not encoded again.
1181
+ strength (`float`, *optional*, defaults to 0.8):
1182
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
1183
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
1184
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
1185
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
1186
+ essentially ignores `image`.
1187
+ num_inference_steps (`int`, *optional*, defaults to 50):
1188
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1189
+ expense of slower inference.
1190
+ timesteps (`List[int]`, *optional*):
1191
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1192
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1193
+ passed will be used. Must be in descending order.
1194
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1195
+ A higher guidance scale value encourages the model to generate images closely linked to the text
1196
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
1197
+ negative_prompt (`str` or `List[str]`, *optional*):
1198
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
1199
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
1200
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1201
+ The number of images to generate per prompt.
1202
+ eta (`float`, *optional*, defaults to 0.0):
1203
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
1204
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1205
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1206
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1207
+ generation deterministic.
1208
+ latents (`torch.FloatTensor`, *optional*):
1209
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
1210
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1211
+ tensor is generated by sampling using the supplied random `generator`.
1212
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1213
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1214
+ provided, text embeddings are generated from the `prompt` input argument.
1215
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1216
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1217
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1218
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1219
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
1220
+ Pre-generated image embeddings for IP-Adapter. If not
1221
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1222
+ output_type (`str`, *optional*, defaults to `"pil"`):
1223
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1224
+ return_dict (`bool`, *optional*, defaults to `True`):
1225
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1226
+ plain tuple.
1227
+ cross_attention_kwargs (`dict`, *optional*):
1228
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1229
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1230
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
1231
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
1232
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
1233
+ using zero terminal SNR.
1234
+ clip_skip (`int`, *optional*):
1235
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1236
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1237
+ callback_on_step_end (`Callable`, *optional*):
1238
+ A function that calls at the end of each denoising steps during the inference. The function is called
1239
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1240
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1241
+ `callback_on_step_end_tensor_inputs`.
1242
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1243
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1244
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1245
+ `._callback_tensor_inputs` attribute of your pipeline class.
1246
+ Examples:
1247
+ Returns:
1248
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1249
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1250
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
1251
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
1252
+ "not-safe-for-work" (nsfw) content.
1253
+ """
1254
+
1255
+ callback = kwargs.pop("callback", None)
1256
+ callback_steps = kwargs.pop("callback_steps", None)
1257
+
1258
+ if callback is not None:
1259
+ deprecate(
1260
+ "callback",
1261
+ "1.0.0",
1262
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1263
+ )
1264
+ if callback_steps is not None:
1265
+ deprecate(
1266
+ "callback_steps",
1267
+ "1.0.0",
1268
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1269
+ )
1270
+
1271
+ # 0. Default height and width to unet
1272
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
1273
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
1274
+ # to deal with lora scaling and other possible forward hooks
1275
+
1276
+ # 1. Check inputs. Raise error if not correct
1277
+ self.check_inputs(
1278
+ prompt,
1279
+ strength,
1280
+ callback_steps,
1281
+ negative_prompt,
1282
+ prompt_embeds,
1283
+ negative_prompt_embeds,
1284
+ ip_adapter_image,
1285
+ ip_adapter_image_embeds,
1286
+ callback_on_step_end_tensor_inputs,
1287
+ )
1288
+
1289
+ self._guidance_scale = guidance_scale
1290
+ self._guidance_rescale = guidance_rescale
1291
+ self._clip_skip = clip_skip
1292
+ self._cross_attention_kwargs = cross_attention_kwargs
1293
+ self._interrupt = False
1294
+
1295
+ self._pag_scale = pag_scale
1296
+ self._pag_adaptive_scaling = pag_adaptive_scaling
1297
+ self._pag_drop_rate = pag_drop_rate
1298
+ self._pag_applied_layers = pag_applied_layers
1299
+ self._pag_applied_layers_index = pag_applied_layers_index
1300
+
1301
+ # 2. Define call parameters
1302
+ if prompt is not None and isinstance(prompt, str):
1303
+ batch_size = 1
1304
+ elif prompt is not None and isinstance(prompt, list):
1305
+ batch_size = len(prompt)
1306
+ else:
1307
+ batch_size = prompt_embeds.shape[0]
1308
+
1309
+ device = self._execution_device
1310
+
1311
+ # 3. Encode input prompt
1312
+ lora_scale = (
1313
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1314
+ )
1315
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
1316
+ prompt,
1317
+ device,
1318
+ num_images_per_prompt,
1319
+ self.do_classifier_free_guidance,
1320
+ negative_prompt,
1321
+ prompt_embeds=prompt_embeds,
1322
+ negative_prompt_embeds=negative_prompt_embeds,
1323
+ lora_scale=lora_scale,
1324
+ clip_skip=self.clip_skip,
1325
+ )
1326
+
1327
+ # For classifier free guidance, we need to do two forward passes.
1328
+ # Here we concatenate the unconditional and text embeddings into a single batch
1329
+ # to avoid doing two forward passes
1330
+
1331
+ # cfg
1332
+ if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
1333
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1334
+ # pag
1335
+ elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
1336
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
1337
+ # both
1338
+ elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
1339
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])
1340
+
1341
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1342
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1343
+ ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
1344
+ )
1345
+
1346
+ # 4. Preprocess image
1347
+ image = self.image_processor.preprocess(image)
1348
+
1349
+ # 5.set timesteps
1350
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1351
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
1352
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1353
+
1354
+ # 6. Prepare latent variables
1355
+ latents = self.prepare_latents(
1356
+ image,
1357
+ latent_timestep,
1358
+ batch_size,
1359
+ num_images_per_prompt,
1360
+ prompt_embeds.dtype,
1361
+ device,
1362
+ generator,
1363
+ )
1364
+
1365
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1366
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1367
+
1368
+ # 7.1 Add image embeds for IP-Adapter
1369
+ added_cond_kwargs = (
1370
+ {"image_embeds": image_embeds}
1371
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
1372
+ else None
1373
+ )
1374
+
1375
+ # 7.2 Optionally get Guidance Scale Embedding
1376
+ timestep_cond = None
1377
+ if self.unet.config.time_cond_proj_dim is not None:
1378
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1379
+ timestep_cond = self.get_guidance_scale_embedding(
1380
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1381
+ ).to(device=device, dtype=latents.dtype)
1382
+
1383
+ # 8. Denoising loop
1384
+ if self.do_adversarial_guidance:
1385
+ down_layers = []
1386
+ mid_layers = []
1387
+ up_layers = []
1388
+ for name, module in self.unet.named_modules():
1389
+ if "attn1" in name and "to" not in name:
1390
+ layer_type = name.split(".")[0].split("_")[0]
1391
+ if layer_type == "down":
1392
+ down_layers.append(module)
1393
+ elif layer_type == "mid":
1394
+ mid_layers.append(module)
1395
+ elif layer_type == "up":
1396
+ up_layers.append(module)
1397
+ else:
1398
+ raise ValueError(f"Invalid layer type: {layer_type}")
1399
+
1400
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1401
+ self._num_timesteps = len(timesteps)
1402
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1403
+ for i, t in enumerate(timesteps):
1404
+ if self.interrupt:
1405
+ continue
1406
+
1407
+ # cfg
1408
+ if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
1409
+ latent_model_input = torch.cat([latents] * 2)
1410
+ # pag
1411
+ elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
1412
+ latent_model_input = torch.cat([latents] * 2)
1413
+ # both
1414
+ elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
1415
+ latent_model_input = torch.cat([latents] * 3)
1416
+ # no
1417
+ else:
1418
+ latent_model_input = latents
1419
+
1420
+ # change attention layer in UNet if use PAG
1421
+ if self.do_adversarial_guidance:
1422
+ if self.do_classifier_free_guidance:
1423
+ replace_processor = PAGCFGIdentitySelfAttnProcessor()
1424
+ else:
1425
+ replace_processor = PAGIdentitySelfAttnProcessor()
1426
+
1427
+ drop_layers = self.pag_applied_layers_index
1428
+ for drop_layer in drop_layers:
1429
+ try:
1430
+ if drop_layer[0] == "d":
1431
+ down_layers[int(drop_layer[1])].processor = replace_processor
1432
+ elif drop_layer[0] == "m":
1433
+ mid_layers[int(drop_layer[1])].processor = replace_processor
1434
+ elif drop_layer[0] == "u":
1435
+ up_layers[int(drop_layer[1])].processor = replace_processor
1436
+ else:
1437
+ raise ValueError(f"Invalid layer type: {drop_layer[0]}")
1438
+ except IndexError:
1439
+ raise ValueError(
1440
+ f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
1441
+ )
1442
+
1443
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1444
+
1445
+ # predict the noise residual
1446
+ noise_pred = self.unet(
1447
+ latent_model_input,
1448
+ t,
1449
+ encoder_hidden_states=prompt_embeds,
1450
+ timestep_cond=timestep_cond,
1451
+ cross_attention_kwargs=self.cross_attention_kwargs,
1452
+ added_cond_kwargs=added_cond_kwargs,
1453
+ return_dict=False,
1454
+ )[0]
1455
+
1456
+ # perform guidance
1457
+
1458
+ # cfg
1459
+ if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
1460
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1461
+
1462
+ delta = noise_pred_text - noise_pred_uncond
1463
+ noise_pred = noise_pred_uncond + self.guidance_scale * delta
1464
+
1465
+ # pag
1466
+ elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
1467
+ noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)
1468
+
1469
+ signal_scale = self.pag_scale
1470
+ if self.do_pag_adaptive_scaling:
1471
+ signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t)
1472
+ if signal_scale < 0:
1473
+ signal_scale = 0
1474
+
1475
+ noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)
1476
+
1477
+ # both
1478
+ elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
1479
+ noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)
1480
+
1481
+ signal_scale = self.pag_scale
1482
+ if self.do_pag_adaptive_scaling:
1483
+ signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t)
1484
+ if signal_scale < 0:
1485
+ signal_scale = 0
1486
+
1487
+ noise_pred = (
1488
+ noise_pred_text
1489
+ + (self.guidance_scale - 1.0) * (noise_pred_text - noise_pred_uncond)
1490
+ + signal_scale * (noise_pred_text - noise_pred_text_perturb)
1491
+ )
1492
+
1493
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1494
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1495
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1496
+
1497
+ # compute the previous noisy sample x_t -> x_t-1
1498
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1499
+
1500
+ if callback_on_step_end is not None:
1501
+ callback_kwargs = {}
1502
+ for k in callback_on_step_end_tensor_inputs:
1503
+ callback_kwargs[k] = locals()[k]
1504
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1505
+
1506
+ latents = callback_outputs.pop("latents", latents)
1507
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1508
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1509
+
1510
+ # call the callback, if provided
1511
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1512
+ progress_bar.update()
1513
+ if callback is not None and i % callback_steps == 0:
1514
+ step_idx = i // getattr(self.scheduler, "order", 1)
1515
+ callback(step_idx, t, latents)
1516
+
1517
+ if not output_type == "latent":
1518
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1519
+ 0
1520
+ ]
1521
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1522
+ else:
1523
+ image = latents
1524
+ has_nsfw_concept = None
1525
+
1526
+ if has_nsfw_concept is None:
1527
+ do_denormalize = [True] * image.shape[0]
1528
+ else:
1529
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1530
+
1531
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1532
+
1533
+ # Offload all models
1534
+ self.maybe_free_model_hooks()
1535
+
1536
+ if not return_dict:
1537
+ return (image, has_nsfw_concept)
1538
+
1539
+ # change attention layer in UNet if use PAG
1540
+ if self.do_adversarial_guidance:
1541
+ drop_layers = self.pag_applied_layers_index
1542
+ for drop_layer in drop_layers:
1543
+ try:
1544
+ if drop_layer[0] == "d":
1545
+ down_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
1546
+ elif drop_layer[0] == "m":
1547
+ mid_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
1548
+ elif drop_layer[0] == "u":
1549
+ up_layers[int(drop_layer[1])].processor = AttnProcessor2_0()
1550
+ else:
1551
+ raise ValueError(f"Invalid layer type: {drop_layer[0]}")
1552
+ except IndexError:
1553
+ raise ValueError(
1554
+ f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
1555
+ )
1556
+
1557
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)