AlanB committed on
Commit dfdee9a
1 Parent(s): d77c8ff

Added Callback steps for my progressbar in StableDiffusionDeluxe

Files changed (1)
  1. pipeline.py +595 -0
pipeline.py ADDED
@@ -0,0 +1,595 @@
+import inspect
+from typing import List, Optional, Tuple, Union, Callable
+
+import torch
+from torch.nn import functional as F
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+
+from diffusers import (
+    DiffusionPipeline,
+    ImagePipelineOutput,
+    PriorTransformer,
+    UnCLIPScheduler,
+    UNet2DConditionModel,
+    UNet2DModel,
+)
+from diffusers.pipelines.unclip import UnCLIPTextProjModel
+from diffusers.utils import is_accelerate_available, logging, randn_tensor
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def slerp(val, low, high):
+    """
+    Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic.
+    """
+    low_norm = low / torch.norm(low)
+    high_norm = high / torch.norm(high)
+    omega = torch.acos((low_norm * high_norm))
+    so = torch.sin(omega)
+    res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
+    return res
+
+
+class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
+
+    """
+    Pipeline for prompt-to-prompt interpolation on CLIP text embeddings, using UnCLIP / DALL-E 2 to decode them to images.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        text_encoder ([`CLIPTextModelWithProjection`]):
+            Frozen text-encoder.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        prior ([`PriorTransformer`]):
+            The canonical unCLIP prior to approximate the image embedding from the text embedding.
+        text_proj ([`UnCLIPTextProjModel`]):
+            Utility class to prepare and combine the embeddings before they are passed to the decoder.
+        decoder ([`UNet2DConditionModel`]):
+            The decoder to invert the image embedding into an image.
+        super_res_first ([`UNet2DModel`]):
+            Super resolution unet. Used in all but the last step of the super resolution diffusion process.
+        super_res_last ([`UNet2DModel`]):
+            Super resolution unet. Used in the last step of the super resolution diffusion process.
+        prior_scheduler ([`UnCLIPScheduler`]):
+            Scheduler used in the prior denoising process. Just a modified DDPMScheduler.
+        decoder_scheduler ([`UnCLIPScheduler`]):
+            Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
+        super_res_scheduler ([`UnCLIPScheduler`]):
+            Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
+
+    """
+
+    prior: PriorTransformer
+    decoder: UNet2DConditionModel
+    text_proj: UnCLIPTextProjModel
+    text_encoder: CLIPTextModelWithProjection
+    tokenizer: CLIPTokenizer
+    super_res_first: UNet2DModel
+    super_res_last: UNet2DModel
+
+    prior_scheduler: UnCLIPScheduler
+    decoder_scheduler: UnCLIPScheduler
+    super_res_scheduler: UnCLIPScheduler
+
+    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__init__
+    def __init__(
+        self,
+        prior: PriorTransformer,
+        decoder: UNet2DConditionModel,
+        text_encoder: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        text_proj: UnCLIPTextProjModel,
+        super_res_first: UNet2DModel,
+        super_res_last: UNet2DModel,
+        prior_scheduler: UnCLIPScheduler,
+        decoder_scheduler: UnCLIPScheduler,
+        super_res_scheduler: UnCLIPScheduler,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            prior=prior,
+            decoder=decoder,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            text_proj=text_proj,
+            super_res_first=super_res_first,
+            super_res_last=super_res_last,
+            prior_scheduler=prior_scheduler,
+            decoder_scheduler=decoder_scheduler,
+            super_res_scheduler=super_res_scheduler,
+        )
+
+    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            if latents.shape != shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+            latents = latents.to(device)
+
+        latents = latents * scheduler.init_noise_sigma
+        return latents
+
+    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+        text_attention_mask: Optional[torch.Tensor] = None,
+    ):
+        if text_model_output is None:
+            batch_size = len(prompt) if isinstance(prompt, list) else 1
+            # get prompt text embeddings
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            text_mask = text_inputs.attention_mask.bool().to(device)
+
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+            text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+            prompt_embeds = text_encoder_output.text_embeds
+            text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+        else:
+            batch_size = text_model_output[0].shape[0]
+            prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+            text_mask = text_attention_mask
+
+        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            uncond_tokens = [""] * batch_size
+
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+            seq_len = uncond_text_encoder_hidden_states.shape[1]
+            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+                batch_size * num_images_per_prompt, seq_len, -1
+            )
+            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+            # done duplicates
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+            text_mask = torch.cat([uncond_text_mask, text_mask])
+
+        return prompt_embeds, text_encoder_hidden_states, text_mask
+
+    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
+        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
+        only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
+        models = [
+            self.decoder,
+            self.text_proj,
+            self.text_encoder,
+            self.super_res_first,
+            self.super_res_last,
+        ]
+        for cpu_offloaded_model in models:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
+    @property
+    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
+            return self.device
+        for module in self.decoder.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        start_prompt: str,
+        end_prompt: str,
+        steps: int = 5,
+        prior_num_inference_steps: int = 25,
+        decoder_num_inference_steps: int = 25,
+        super_res_num_inference_steps: int = 7,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        prior_guidance_scale: float = 4.0,
+        decoder_guidance_scale: float = 8.0,
+        enable_sequential_cpu_offload=True,
+        gpu_id=0,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: int = 1,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+    ):
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            start_prompt (`str`):
+                The prompt to start the image generation interpolation from.
+            end_prompt (`str`):
+                The prompt to end the image generation interpolation at.
+            steps (`int`, *optional*, defaults to 5):
+                The number of steps over which to interpolate from start_prompt to end_prompt. The pipeline returns
+                the same number of images as this value.
+            prior_num_inference_steps (`int`, *optional*, defaults to 25):
+                The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
+                image at the expense of slower inference.
+            decoder_num_inference_steps (`int`, *optional*, defaults to 25):
+                The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
+                image at the expense of slower inference.
+            super_res_num_inference_steps (`int`, *optional*, defaults to 7):
+                The number of denoising steps for super resolution. More denoising steps usually lead to a higher
+                quality image at the expense of slower inference.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+            decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            enable_sequential_cpu_offload (`bool`, *optional*, defaults to `True`):
+                If `True`, offloads all models to CPU using accelerate, significantly reducing memory usage. When
+                called, the pipeline's models have their state dicts saved to CPU and then are moved to a
+                `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method
+                called.
+            gpu_id (`int`, *optional*, defaults to `0`):
+                The gpu_id to be passed to `enable_sequential_cpu_offload`. Only works when
+                `enable_sequential_cpu_offload` is set to `True`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+        """
+
+        if not isinstance(start_prompt, str) or not isinstance(end_prompt, str):
+            raise ValueError(
+                f"`start_prompt` and `end_prompt` should be of type `str` but got {type(start_prompt)} and"
+                f" {type(end_prompt)} instead"
+            )
+
+        if enable_sequential_cpu_offload:
+            self.enable_sequential_cpu_offload(gpu_id=gpu_id)
+
+        device = self._execution_device
+
+        # Turn the prompts into embeddings.
+        inputs = self.tokenizer(
+            [start_prompt, end_prompt],
+            padding="max_length",
+            truncation=True,
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        inputs.to(device)
+        text_model_output = self.text_encoder(**inputs)
+
+        text_attention_mask = torch.max(inputs.attention_mask[0], inputs.attention_mask[1])
+        text_attention_mask = torch.cat([text_attention_mask.unsqueeze(0)] * steps).to(device)
+
+        # Interpolate from the start to end prompt using slerp and add the generated images to an image output pipeline
+        batch_text_embeds = []
+        batch_last_hidden_state = []
+
+        for interp_val in torch.linspace(0, 1, steps):
+            text_embeds = slerp(interp_val, text_model_output.text_embeds[0], text_model_output.text_embeds[1])
+            last_hidden_state = slerp(
+                interp_val, text_model_output.last_hidden_state[0], text_model_output.last_hidden_state[1]
+            )
+            batch_text_embeds.append(text_embeds.unsqueeze(0))
+            batch_last_hidden_state.append(last_hidden_state.unsqueeze(0))
+
+        batch_text_embeds = torch.cat(batch_text_embeds)
+        batch_last_hidden_state = torch.cat(batch_last_hidden_state)
+
+        text_model_output = CLIPTextModelOutput(
+            text_embeds=batch_text_embeds, last_hidden_state=batch_last_hidden_state
+        )
+
+        batch_size = text_model_output[0].shape[0]
+
+        do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
+
+        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+            prompt=None,
+            device=device,
+            num_images_per_prompt=1,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            text_model_output=text_model_output,
+            text_attention_mask=text_attention_mask,
+        )
+
+        # prior
+        current_step = 0
+        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
+        prior_timesteps_tensor = self.prior_scheduler.timesteps
+
+        embedding_dim = self.prior.config.embedding_dim
+
+        prior_latents = self.prepare_latents(
+            (batch_size, embedding_dim),
+            prompt_embeds.dtype,
+            device,
+            generator,
+            None,
+            self.prior_scheduler,
+        )
+
+        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents
+
+            predicted_image_embedding = self.prior(
+                latent_model_input,
+                timestep=t,
+                proj_embedding=prompt_embeds,
+                encoder_hidden_states=text_encoder_hidden_states,
+                attention_mask=text_mask,
+            ).predicted_image_embedding
+
+            if do_classifier_free_guidance:
+                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+                predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
+                    predicted_image_embedding_text - predicted_image_embedding_uncond
+                )
+
+            if i + 1 == prior_timesteps_tensor.shape[0]:
+                prev_timestep = None
+            else:
+                prev_timestep = prior_timesteps_tensor[i + 1]
+
+            prior_latents = self.prior_scheduler.step(
+                predicted_image_embedding,
+                timestep=t,
+                sample=prior_latents,
+                generator=generator,
+                prev_timestep=prev_timestep,
+            ).prev_sample
+            # call the callback, if provided
+            current_step += 1
+            if callback is not None and current_step % callback_steps == 0:
+                callback(current_step, t, prior_latents)
+
+        prior_latents = self.prior.post_process_latents(prior_latents)
+
+        image_embeddings = prior_latents
+
+        # done prior
+
+        # decoder
+
+        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+            image_embeddings=image_embeddings,
+            prompt_embeds=prompt_embeds,
+            text_encoder_hidden_states=text_encoder_hidden_states,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+        )
+
+        if device.type == "mps":
+            # HACK: MPS: There is a panic when padding bool tensors,
+            # so cast to int tensor for the pad and back to bool afterwards
+            text_mask = text_mask.type(torch.int)
+            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
+            decoder_text_mask = decoder_text_mask.type(torch.bool)
+        else:
+            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
+
+        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
+        decoder_timesteps_tensor = self.decoder_scheduler.timesteps
+
+        num_channels_latents = self.decoder.in_channels
+        height = self.decoder.sample_size
+        width = self.decoder.sample_size
+
+        decoder_latents = self.prepare_latents(
+            (batch_size, num_channels_latents, height, width),
+            text_encoder_hidden_states.dtype,
+            device,
+            generator,
+            None,
+            self.decoder_scheduler,
+        )
+
+        for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
+
+            noise_pred = self.decoder(
+                sample=latent_model_input,
+                timestep=t,
+                encoder_hidden_states=text_encoder_hidden_states,
+                class_labels=additive_clip_time_embeddings,
+                attention_mask=decoder_text_mask,
+            ).sample
+
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
+                noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
+                noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
+                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+            if i + 1 == decoder_timesteps_tensor.shape[0]:
+                prev_timestep = None
+            else:
+                prev_timestep = decoder_timesteps_tensor[i + 1]
+
+            # compute the previous noisy sample x_t -> x_t-1
+            decoder_latents = self.decoder_scheduler.step(
+                noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
+            ).prev_sample
+
+            # call the callback, if provided
+            current_step += 1
+            if callback is not None and current_step % callback_steps == 0:
+                callback(current_step, t, decoder_latents)
+
+        decoder_latents = decoder_latents.clamp(-1, 1)
+
+        image_small = decoder_latents
+
+        # done decoder
+
+        # super res
+
+        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
+        super_res_timesteps_tensor = self.super_res_scheduler.timesteps
+
+        channels = self.super_res_first.in_channels // 2
+        height = self.super_res_first.sample_size
+        width = self.super_res_first.sample_size
+
+        super_res_latents = self.prepare_latents(
+            (batch_size, channels, height, width),
+            image_small.dtype,
+            device,
+            generator,
+            None,
+            self.super_res_scheduler,
+        )
+
+        if device.type == "mps":
+            # MPS does not support many interpolations
+            image_upscaled = F.interpolate(image_small, size=[height, width])
+        else:
+            interpolate_antialias = {}
+            if "antialias" in inspect.signature(F.interpolate).parameters:
+                interpolate_antialias["antialias"] = True
+
+            image_upscaled = F.interpolate(
+                image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
+            )
+
+        for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
+            # no classifier free guidance
+
+            if i == super_res_timesteps_tensor.shape[0] - 1:
+                unet = self.super_res_last
+            else:
+                unet = self.super_res_first
+
+            latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
+
+            noise_pred = unet(
+                sample=latent_model_input,
+                timestep=t,
+            ).sample
+
+            if i + 1 == super_res_timesteps_tensor.shape[0]:
+                prev_timestep = None
+            else:
+                prev_timestep = super_res_timesteps_tensor[i + 1]
+
+            # compute the previous noisy sample x_t -> x_t-1
+            super_res_latents = self.super_res_scheduler.step(
+                noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
+            ).prev_sample
+
+            # call the callback, if provided
+            current_step += 1
+            if callback is not None and current_step % callback_steps == 0:
+                callback(current_step, t, super_res_latents)
+
+        image = super_res_latents
+        # done super res
+
+        # post processing
+
+        image = image * 0.5 + 0.5
+        image = image.clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)
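
For context, here is a minimal usage sketch of the callback hook added in this commit. The checkpoint name, the custom_pipeline id, and the tqdm progress bar are illustrative assumptions, not part of the diff; the grounded detail is that `current_step` counts cumulatively across the prior, decoder, and super-resolution loops, so a progress bar should be sized to the sum of the three step counts.

import torch
from tqdm.auto import tqdm
from diffusers import DiffusionPipeline

# Assumption: "kakaobrain/karlo-v1-alpha" is a compatible unCLIP checkpoint and this file is
# loadable as a custom pipeline; any compatible checkpoint / loading path would work the same way.
pipe = DiffusionPipeline.from_pretrained(
    "kakaobrain/karlo-v1-alpha",
    custom_pipeline="unclip_text_interpolation",
    torch_dtype=torch.float16,
)

prior_steps, decoder_steps, super_res_steps = 25, 25, 7
# The callback receives the cumulative current_step, so one bar covers all three denoising loops.
progress = tqdm(total=prior_steps + decoder_steps + super_res_steps)

def on_step(step: int, timestep: int, latents: torch.FloatTensor):
    # Advance the bar to the pipeline's running step counter.
    progress.update(step - progress.n)

output = pipe(
    start_prompt="A photograph of an adult lion",
    end_prompt="A photograph of a lion cub",
    steps=6,
    prior_num_inference_steps=prior_steps,
    decoder_num_inference_steps=decoder_steps,
    super_res_num_inference_steps=super_res_steps,
    callback=on_step,
    callback_steps=1,
)
progress.close()
images = output.images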