AlekseyCalvin committed on
Commit 2e6e5d5
1 Parent(s): 1d6281e

Create pipeline.py

Files changed (1): pipeline.py +674 -0
pipeline.py ADDED
@@ -0,0 +1,674 @@
+ import torch
+ import numpy as np
+ import html
+ import inspect
+ import re
+ import urllib.parse as ul
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextModelWithProjection
+ from diffusers import FlowMatchEulerDiscreteScheduler, AutoPipelineForImage2Image, FluxPipeline, FluxTransformer2DModel
+ from diffusers import StableDiffusion3Pipeline, AutoencoderKL, DiffusionPipeline, ImagePipelineOutput
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, SD3LoraLoaderMixin
+ from diffusers.utils import (
+     USE_PEFT_BACKEND,
+     is_torch_xla_available,
+     logging,
+     BACKENDS_MAPPING,
+     deprecate,
+     replace_example_docstring,
+     scale_lora_layers,
+     unscale_lora_layers,
+ )
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+ from typing import Any, Callable, Dict, List, Optional, Union
+ from PIL import Image
+ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxTransformer2DModel
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ # Constants for shift calculation
+ BASE_SEQ_LEN = 256
+ MAX_SEQ_LEN = 4096
+ BASE_SHIFT = 0.5
+ MAX_SHIFT = 1.2
+
+ # Helper functions
+ def calculate_timestep_shift(image_seq_len: int) -> float:
+     """Calculates the timestep shift (mu) based on the image sequence length."""
+     m = (MAX_SHIFT - BASE_SHIFT) / (MAX_SEQ_LEN - BASE_SEQ_LEN)
+     b = BASE_SHIFT - m * BASE_SEQ_LEN
+     mu = image_seq_len * m + b
+     return mu
+
+ def prepare_timesteps(
+     scheduler: FlowMatchEulerDiscreteScheduler,
+     num_inference_steps: Optional[int] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     timesteps: Optional[List[int]] = None,
+     sigmas: Optional[List[float]] = None,
+     mu: Optional[float] = None,
+ ) -> (torch.Tensor, int):
+     """Prepares the timesteps for the diffusion process."""
+     if timesteps is not None and sigmas is not None:
+         raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
+
+     if timesteps is not None:
+         scheduler.set_timesteps(timesteps=timesteps, device=device)
+     elif sigmas is not None:
+         # forward `mu` here as well so the dynamic shift is applied when custom sigmas are supplied
+         scheduler.set_timesteps(sigmas=sigmas, device=device, mu=mu)
+     else:
+         scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
+
+     timesteps = scheduler.timesteps
+     num_inference_steps = len(timesteps)
+     return timesteps, num_inference_steps
+
+ # FLUX pipeline class
+ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+     def __init__(
+         self,
+         scheduler: FlowMatchEulerDiscreteScheduler,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         text_encoder_2: T5EncoderModel,
+         tokenizer_2: T5TokenizerFast,
+         transformer: FluxTransformer2DModel,
+     ):
+         super().__init__()
+
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             text_encoder_2=text_encoder_2,
+             tokenizer=tokenizer,
+             tokenizer_2=tokenizer_2,
+             transformer=transformer,
+             scheduler=scheduler,
+         )
+         self.vae_scale_factor = (
+             2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+         )
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+         self.tokenizer_max_length = (
+             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+         )
+         self.default_sample_size = 64
+
+     def _get_t5_prompt_embeds(
+         self,
+         prompt: Union[str, List[str]] = None,
+         num_images_per_prompt: int = 1,
+         max_sequence_length: int = 512,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ):
+         device = device or self._execution_device
+         dtype = dtype or self.text_encoder.dtype
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         batch_size = len(prompt)
+
+         text_inputs = self.tokenizer_2(
+             prompt,
+             padding="max_length",
+             max_length=max_sequence_length,
+             truncation=True,
+             return_length=True,
+             return_overflowing_tokens=True,
+             return_tensors="pt",
+         )
+         text_input_ids = text_inputs.input_ids
+         untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+             removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+             logger.warning(
+                 "The following part of your input was truncated because `max_sequence_length` is set to "
+                 f" {max_sequence_length} tokens: {removed_text}"
+             )
+
+         prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+         dtype = self.text_encoder_2.dtype
+         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+         _, seq_len, _ = prompt_embeds.shape
+
+         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+         return prompt_embeds
+
+     def _get_clip_prompt_embeds(
+         self,
+         prompt: Union[str, List[str]],
+         num_images_per_prompt: int = 1,
+         max_sequence_length: int = 512,
+         device: Optional[torch.device] = None,
+     ):
+         device = device or self._execution_device
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         batch_size = len(prompt)
+
+         text_inputs = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=self.tokenizer_max_length,
+             truncation=True,
+             return_overflowing_tokens=False,
+             return_length=False,
+             return_tensors="pt",
+         )
+
+         text_input_ids = text_inputs.input_ids
+         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+             logger.warning(
+                 "The following part of your input was truncated because CLIP can only handle sequences up to"
+                 f" {self.tokenizer_max_length} tokens: {removed_text}"
+             )
+         prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+         # Use pooled output of CLIPTextModel
+         prompt_embeds = prompt_embeds.pooler_output
+         prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+         # duplicate text embeddings for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+         return prompt_embeds
+
+     def encode_prompt(
+         self,
+         prompt: Union[str, List[str]],
+         prompt_2: Union[str, List[str]],
+         num_images_per_prompt: int = 1,
+         max_sequence_length: int = 512,
+         do_classifier_free_guidance: bool = True,
+         device: Optional[torch.device] = None,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         negative_prompt_2: Optional[Union[str, List[str]]] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         lora_scale: Optional[float] = None,
+     ):
+         device = device or self._execution_device
+
+         # set lora scale so that monkey patched LoRA
+         # function of text encoder can correctly access it
+         if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+             self._lora_scale = lora_scale
+
+             # dynamically adjust the LoRA scale
+             if self.text_encoder is not None and USE_PEFT_BACKEND:
+                 scale_lora_layers(self.text_encoder, lora_scale)
+             if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+                 scale_lora_layers(self.text_encoder_2, lora_scale)
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         if prompt is not None:
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         if prompt_embeds is None:
+             prompt_2 = prompt_2 or prompt
+             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+             pooled_prompt_embeds = self._get_clip_prompt_embeds(
+                 prompt=prompt,
+                 device=device,
+                 num_images_per_prompt=num_images_per_prompt,
+             )
+             prompt_embeds = self._get_t5_prompt_embeds(
+                 prompt=prompt_2,
+                 num_images_per_prompt=num_images_per_prompt,
+                 max_sequence_length=max_sequence_length,
+                 device=device,
+             )
+
+         if do_classifier_free_guidance and negative_prompt_embeds is None:
+             negative_prompt = negative_prompt or ""
+             negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+             # normalize str to list
+             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+             negative_prompt_2 = (
+                 batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+             )
+
+             if prompt is not None and type(prompt) is not type(negative_prompt):
+                 raise TypeError(
+                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                     f" {type(prompt)}."
+                 )
+             elif batch_size != len(negative_prompt):
+                 raise ValueError(
+                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                     " the batch size of `prompt`."
+                 )
+
+             # Mirror the positive path: CLIP supplies the pooled embedding, T5 the sequence embedding.
+             negative_pooled_prompt_embeds = self._get_clip_prompt_embeds(
+                 prompt=negative_prompt,
+                 device=device,
+                 num_images_per_prompt=num_images_per_prompt,
+                 max_sequence_length=max_sequence_length,
+             )
+
+             negative_prompt_embeds = self._get_t5_prompt_embeds(
+                 prompt=negative_prompt_2,
+                 device=device,
+                 num_images_per_prompt=num_images_per_prompt,
+                 max_sequence_length=max_sequence_length,
+             )
+
+
+         if self.text_encoder is not None:
+             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+                 # Retrieve the original scale by scaling back the LoRA layers
+                 unscale_lora_layers(self.text_encoder, lora_scale)
+
+         if self.text_encoder_2 is not None:
+             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+                 # Retrieve the original scale by scaling back the LoRA layers
+                 unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+         # The CLIP/T5 helpers above already duplicate the embeddings per `num_images_per_prompt`,
+         # so no further repeat/reshape of the pooled embeddings is needed here.
+
+         return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
+
+     def check_inputs(
+         self,
+         prompt,
+         prompt_2,
+         height,
+         width,
+         negative_prompt=None,
+         negative_prompt_2=None,
+         prompt_embeds=None,
+         negative_prompt_embeds=None,
+         pooled_prompt_embeds=None,
+         negative_pooled_prompt_embeds=None,
+         max_sequence_length=None,
+     ):
+         if height % 8 != 0 or width % 8 != 0:
+             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+         if prompt is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt_2 is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt is None and prompt_embeds is None:
+             raise ValueError(
+                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+             )
+         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+         if prompt_embeds is not None and pooled_prompt_embeds is None:
+             raise ValueError(
+                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+             )
+         if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+             raise ValueError("Must provide `negative_pooled_prompt_embeds` when specifying `negative_prompt_embeds`.")
+
+         if max_sequence_length is not None and max_sequence_length > 512:
+             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+     @staticmethod
+     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+         latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+         latent_image_ids = latent_image_ids.reshape(
+             latent_image_id_height * latent_image_id_width, latent_image_id_channels
+         )
+
+         return latent_image_ids.to(device=device, dtype=dtype)
+
+     @staticmethod
+     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+         latents = latents.permute(0, 2, 4, 1, 3, 5)
+         latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+         return latents
+
+     @staticmethod
+     def _unpack_latents(latents, height, width, vae_scale_factor):
+         batch_size, num_patches, channels = latents.shape
+
+         height = height // vae_scale_factor
+         width = width // vae_scale_factor
+
+         latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+         latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+         latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+         return latents
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+     def prepare_extra_step_kwargs(self, generator, eta):
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+
+         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         # check if the scheduler accepts generator
+         accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         if accepts_generator:
+             extra_step_kwargs["generator"] = generator
+         return extra_step_kwargs
+
+     def enable_vae_slicing(self):
+         r"""
+         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+         """
+         self.vae.enable_slicing()
+
+     def disable_vae_slicing(self):
+         r"""
+         Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+         computing decoding in one step.
+         """
+         self.vae.disable_slicing()
+
+     def enable_vae_tiling(self):
+         r"""
+         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+         compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+         processing larger images.
+         """
+         self.vae.enable_tiling()
+
+     def disable_vae_tiling(self):
+         r"""
+         Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+         computing decoding in one step.
+         """
+         self.vae.disable_tiling()
+
+     def prepare_latents(
+         self,
+         batch_size,
+         num_channels_latents,
+         height,
+         width,
+         dtype,
+         device,
+         generator,
+         latents=None,
+     ):
+         height = 2 * (int(height) // self.vae_scale_factor)
+         width = 2 * (int(width) // self.vae_scale_factor)
+
+         shape = (batch_size, num_channels_latents, height, width)
+
+         if latents is not None:
+             latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+             return latents.to(device=device, dtype=dtype), latent_image_ids
+
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+         latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+         return latents, latent_image_ids
+
+     @property
+     def guidance_scale(self):
+         return self._guidance_scale
+
+     @property
+     def do_classifier_free_guidance(self):
+         return self._guidance_scale > 1
+
+     @property
+     def joint_attention_kwargs(self):
+         return self._joint_attention_kwargs
+
+     @property
+     def num_timesteps(self):
+         return self._num_timesteps
+
+     @property
+     def interrupt(self):
+         return self._interrupt
+
+     @torch.no_grad()
+     @torch.inference_mode()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]] = None,
+         prompt_2: Optional[Union[str, List[str]]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         negative_prompt_2: Optional[Union[str, List[str]]] = None,
+         num_inference_steps: int = 8,
+         timesteps: List[int] = None,
+         eta: Optional[float] = 0.0,
+         guidance_scale: float = 3.5,
+         device: Optional[int] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+         max_sequence_length: int = 300,
+         **kwargs,
+     ):
+         height = height or self.default_sample_size * self.vae_scale_factor
+         width = width or self.default_sample_size * self.vae_scale_factor
+
+         # 1. Check inputs
+         self.check_inputs(
+             prompt,
+             prompt_2,
+             height,
+             width,
+             negative_prompt=negative_prompt,
+             negative_prompt_2=negative_prompt_2,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             max_sequence_length=max_sequence_length,
+         )
+
+         self._guidance_scale = guidance_scale
+         self._joint_attention_kwargs = joint_attention_kwargs
+         self._interrupt = False
+
+         do_classifier_free_guidance = guidance_scale > 1.0
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = device if device is not None else self._execution_device
+
+         # 3. Encode prompts
+         lora_scale = (
+             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+         )
+         (
+             prompt_embeds,
+             pooled_prompt_embeds,
+             text_ids,
+             negative_prompt_embeds,
+             negative_pooled_prompt_embeds,
+         ) = self.encode_prompt(
+             prompt=prompt,
+             prompt_2=prompt_2,
+             negative_prompt=negative_prompt,
+             negative_prompt_2=negative_prompt_2,
+             do_classifier_free_guidance=self.do_classifier_free_guidance,
+             prompt_embeds=prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             device=device,
+             num_images_per_prompt=num_images_per_prompt,
+             max_sequence_length=max_sequence_length,
+             lora_scale=lora_scale,
+         )
+
+         if self.do_classifier_free_guidance:
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+         # 4. Prepare latent variables
+         num_channels_latents = self.transformer.config.in_channels // 4
+         latents, latent_image_ids = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 5. Prepare timesteps
+         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+         image_seq_len = latents.shape[1]
+         mu = calculate_timestep_shift(image_seq_len)
+         timesteps, num_inference_steps = prepare_timesteps(
+             self.scheduler,
+             num_inference_steps,
+             device,
+             timesteps,
+             sigmas,
+             mu=mu,
+         )
+         self._num_timesteps = len(timesteps)
+         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+         # 6. Denoising loop
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 if self.interrupt:
+                     continue
+
+                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                 if self.transformer.config.guidance_embeds:
+                     guidance = torch.tensor([guidance_scale], device=device)
+                     guidance = guidance.expand(latent_model_input.shape[0])
+                 else:
+                     guidance = None
+
+                 noise_pred = self.transformer(
+                     hidden_states=latent_model_input,
+                     timestep=timestep / 1000,
+                     guidance=guidance,
+                     pooled_projections=pooled_prompt_embeds,
+                     encoder_hidden_states=prompt_embeds,
+                     txt_ids=text_ids,
+                     img_ids=latent_image_ids,
+                     joint_attention_kwargs=self.joint_attention_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 if self.do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents_dtype = latents.dtype
+                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                 if latents.dtype != latents_dtype:
+                     if torch.backends.mps.is_available():
+                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                         latents = latents.to(latents_dtype)
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+
+         # Final image
+         image = self._decode_latents_to_image(latents, height, width, output_type)
+         self.maybe_free_model_hooks()
+         torch.cuda.empty_cache()
+         return image
+
+     def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
+         """Decodes the given latents into an image."""
+         vae = vae or self.vae
+         latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+         latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
+         image = vae.decode(latents, return_dict=False)[0]
+         return self.image_processor.postprocess(image, output_type=output_type)[0]
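
For context, a minimal usage sketch (not part of the commit): it assumes the file above is saved as pipeline.py next to the calling script, that a FLUX checkpoint such as black-forest-labs/FLUX.1-dev is available locally or on the Hub, and that a CUDA GPU is present; the repo id, dtype, prompts, seed, and output filename are illustrative assumptions rather than values taken from this repository.

# Hypothetical usage sketch for the FluxWithCFGPipeline defined above.
# Checkpoint id, dtype, prompts, and filename are assumptions for illustration.
import torch

from pipeline import FluxWithCFGPipeline

pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed FLUX checkpoint; substitute your own
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# guidance_scale > 1.0 activates the true classifier-free guidance branch of this
# pipeline, so the negative prompt is actually encoded and applied.
image = pipe(
    prompt="a lighthouse on a rocky coast at dawn, film photograph",
    negative_prompt="blurry, low quality, watermark",
    guidance_scale=3.5,
    num_inference_steps=8,
    height=1024,
    width=1024,
    generator=torch.Generator("cuda").manual_seed(0),
)
image.save("flux_cfg_sample.png")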