Linoy Tsaban committed on
Commit 162c70e
1 Parent(s): 11ce2aa

Create modified_pipeline_semantic_stable_diffusion.py

modified_pipeline_semantic_stable_diffusion.py ADDED
@@ -0,0 +1,757 @@
import inspect
import warnings
from itertools import repeat
from typing import Callable, List, Optional, Union

import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging, randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
# from . import SemanticStableDiffusionPipelineOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class SemanticStableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation with latent editing.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    This model builds on the implementation of [`StableDiffusionPipeline`].

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `safety_checker=None` instead."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        warnings.warn(
            "The decode_latents method is deprecated and will be removed in a future version. Please"
            " use VaeImageProcessor instead",
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        editing_prompt: Optional[Union[str, List[str]]] = None,
        editing_prompt_embeddings: Optional[torch.Tensor] = None,
        reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
        edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
        edit_warmup_steps: Optional[Union[int, List[int]]] = 10,
        edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
        edit_threshold: Optional[Union[float, List[float]]] = 0.9,
        edit_momentum_scale: Optional[float] = 0.1,
        edit_mom_beta: Optional[float] = 0.4,
        edit_weights: Optional[List[float]] = None,
        sem_guidance: Optional[List[torch.Tensor]] = None,
        # DDPM additions
        use_ddpm: bool = False,
        wts: Optional[List[torch.Tensor]] = None,
        zs: Optional[List[torch.Tensor]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            editing_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to use for semantic guidance. Semantic guidance is disabled by setting
                `editing_prompt = None`. The guidance direction of each prompt should be specified via
                `reverse_editing_direction`.
            editing_prompt_embeddings (`torch.Tensor`, *optional*):
                Pre-computed embeddings to use for semantic guidance. The guidance direction of the embeddings should
                be specified via `reverse_editing_direction`.
            reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`):
                Whether the corresponding prompt in `editing_prompt` should be increased or decreased.
            edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5):
                Guidance scale for semantic guidance. If provided as a list, values should correspond to
                `editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 6 of the [SEGA
                paper](https://arxiv.org/pdf/2301.12247.pdf).
            edit_warmup_steps (`int` or `List[int]`, *optional*, defaults to 10):
                Number of diffusion steps (for each prompt) for which semantic guidance will not be applied. Momentum
                will still be calculated for those steps and applied once all warmup periods are over.
                `edit_warmup_steps` is defined as `delta` (δ) of the [SEGA paper](https://arxiv.org/pdf/2301.12247.pdf).
            edit_cooldown_steps (`int` or `List[int]`, *optional*, defaults to `None`):
                Number of diffusion steps (for each prompt) after which semantic guidance will no longer be applied.
            edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9):
                Threshold of semantic guidance.
            edit_momentum_scale (`float`, *optional*, defaults to 0.1):
                Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0,
                momentum is disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller
                than `edit_warmup_steps`. Momentum is only added to the latent guidance once all warmup periods are
                finished. `edit_momentum_scale` is defined as `s_m` of equation 7 of the [SEGA
                paper](https://arxiv.org/pdf/2301.12247.pdf).
            edit_mom_beta (`float`, *optional*, defaults to 0.4):
                Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous
                momentum is kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller than
                `edit_warmup_steps`. `edit_mom_beta` is defined as `beta_m` (β) of equation 8 of the [SEGA
                paper](https://arxiv.org/pdf/2301.12247.pdf).
            edit_weights (`List[float]`, *optional*, defaults to `None`):
                Indicates how much each individual concept should influence the overall guidance. If no weights are
                provided, all concepts are applied equally. `edit_weights` is defined as `g_i` of equation 9 of the
                [SEGA paper](https://arxiv.org/pdf/2301.12247.pdf).
            sem_guidance (`List[torch.Tensor]`, *optional*):
                List of pre-generated guidance vectors to be applied at generation. The length of the list has to
                correspond to `num_inference_steps`.
            use_ddpm (`bool`, *optional*, defaults to `False`):
                Whether to replace the scheduler step with the DDPM-inversion update implemented below, which consumes
                the pre-computed noise vectors `zs`.
            wts (`List[torch.Tensor]`, *optional*):
                Intermediate latents from a DDPM inversion. Not used directly by this method.
            zs (`List[torch.Tensor]`, *optional*):
                Pre-computed noise vectors from a DDPM inversion, consumed one per denoising step when `use_ddpm` is
                enabled.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)

        if editing_prompt:
            enable_edit_guidance = True
            if isinstance(editing_prompt, str):
                editing_prompt = [editing_prompt]
            enabled_editing_prompts = len(editing_prompt)
        elif editing_prompt_embeddings is not None:
            enable_edit_guidance = True
            enabled_editing_prompts = editing_prompt_embeddings.shape[0]
        else:
            enabled_editing_prompts = 0
            enable_edit_guidance = False

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if enable_edit_guidance:
            # get editing-concept text embeddings
            if editing_prompt_embeddings is None:
                edit_concepts_input = self.tokenizer(
                    [x for item in editing_prompt for x in repeat(item, batch_size)],
                    padding="max_length",
                    max_length=self.tokenizer.model_max_length,
                    return_tensors="pt",
                )

                edit_concepts_input_ids = edit_concepts_input.input_ids

                if edit_concepts_input_ids.shape[-1] > self.tokenizer.model_max_length:
                    removed_text = self.tokenizer.batch_decode(
                        edit_concepts_input_ids[:, self.tokenizer.model_max_length :]
                    )
                    logger.warning(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                    )
                    edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length]
                edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0]
            else:
                edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1)

            # duplicate text embeddings for each generation per prompt, using mps friendly method
            bs_embed_edit, seq_len_edit, _ = edit_concepts.shape
            edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1)
            edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""]
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            if enable_edit_guidance:
                text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
            else:
                text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        # get the initial random noise unless the user supplied it

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps
        if use_ddpm:
            t_to_idx = {int(ts): idx for idx, ts in enumerate(timesteps[-zs.shape[0] :])}
            timesteps = timesteps[-zs.shape[0] :]

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            self.device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Initialize edit_momentum to None
        edit_momentum = None

        self.uncond_estimates = None
        self.text_estimates = None
        self.edit_estimates = None
        self.sem_guidance = None

        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts)  # [b, 4, 64, 64]
                noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
                noise_pred_edit_concepts = noise_pred_out[2:]

                # default text guidance
                noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond)
                # noise_guidance = (noise_pred_text - noise_pred_edit_concepts[0])

                if self.uncond_estimates is None:
                    self.uncond_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_uncond.shape))
                self.uncond_estimates[i] = noise_pred_uncond.detach().cpu()

                if self.text_estimates is None:
                    self.text_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape))
                self.text_estimates[i] = noise_pred_text.detach().cpu()

                if self.edit_estimates is None and enable_edit_guidance:
                    self.edit_estimates = torch.zeros(
                        (num_inference_steps + 1, len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)
                    )

                if self.sem_guidance is None:
                    self.sem_guidance = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape))

                if edit_momentum is None:
                    edit_momentum = torch.zeros_like(noise_guidance)

                if enable_edit_guidance:
                    concept_weights = torch.zeros(
                        (len(noise_pred_edit_concepts), noise_guidance.shape[0]),
                        device=self.device,
                        dtype=noise_guidance.dtype,
                    )
                    noise_guidance_edit = torch.zeros(
                        (len(noise_pred_edit_concepts), *noise_guidance.shape),
                        device=self.device,
                        dtype=noise_guidance.dtype,
                    )
                    # noise_guidance_edit = torch.zeros_like(noise_guidance)
                    warmup_inds = []
                    for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
                        self.edit_estimates[i, c] = noise_pred_edit_concept
                        if isinstance(edit_guidance_scale, list):
                            edit_guidance_scale_c = edit_guidance_scale[c]
                        else:
                            edit_guidance_scale_c = edit_guidance_scale

                        if isinstance(edit_threshold, list):
                            edit_threshold_c = edit_threshold[c]
                        else:
                            edit_threshold_c = edit_threshold
                        if isinstance(reverse_editing_direction, list):
                            reverse_editing_direction_c = reverse_editing_direction[c]
                        else:
                            reverse_editing_direction_c = reverse_editing_direction
                        if edit_weights:
                            edit_weight_c = edit_weights[c]
                        else:
                            edit_weight_c = 1.0
                        if isinstance(edit_warmup_steps, list):
                            edit_warmup_steps_c = edit_warmup_steps[c]
                        else:
                            edit_warmup_steps_c = edit_warmup_steps

                        if isinstance(edit_cooldown_steps, list):
                            edit_cooldown_steps_c = edit_cooldown_steps[c]
                        elif edit_cooldown_steps is None:
                            edit_cooldown_steps_c = i + 1
                        else:
                            edit_cooldown_steps_c = edit_cooldown_steps
                        if i >= edit_warmup_steps_c:
                            warmup_inds.append(c)
                        if i >= edit_cooldown_steps_c:
                            noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept)
                            continue

                        noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond
                        # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3))
                        tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3))

                        tmp_weights = torch.full_like(tmp_weights, edit_weight_c)  # * (1 / enabled_editing_prompts)
                        if reverse_editing_direction_c:
                            noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1
                        concept_weights[c, :] = tmp_weights

                        noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c

                        # torch.quantile function expects float32
                        if noise_guidance_edit_tmp.dtype == torch.float32:
                            tmp = torch.quantile(
                                torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2),
                                edit_threshold_c,
                                dim=2,
                                keepdim=False,
                            )
                        else:
                            tmp = torch.quantile(
                                torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32),
                                edit_threshold_c,
                                dim=2,
                                keepdim=False,
                            ).to(noise_guidance_edit_tmp.dtype)

                        noise_guidance_edit_tmp = torch.where(
                            torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None],
                            noise_guidance_edit_tmp,
                            torch.zeros_like(noise_guidance_edit_tmp),
                        )
                        noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp

                        # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp

                    warmup_inds = torch.tensor(warmup_inds).to(self.device)
                    if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
                        concept_weights = concept_weights.to("cpu")  # Offload to cpu
                        noise_guidance_edit = noise_guidance_edit.to("cpu")

                        concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds)
                        concept_weights_tmp = torch.where(
                            concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
                        )
                        concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
                        # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)

                        noise_guidance_edit_tmp = torch.index_select(
                            noise_guidance_edit.to(self.device), 0, warmup_inds
                        )
                        noise_guidance_edit_tmp = torch.einsum(
                            "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
                        )
                        noise_guidance = noise_guidance + noise_guidance_edit_tmp

                        self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu()

                        del noise_guidance_edit_tmp
                        del concept_weights_tmp
                        concept_weights = concept_weights.to(self.device)
                        noise_guidance_edit = noise_guidance_edit.to(self.device)

                    concept_weights = torch.where(
                        concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
                    )

                    concept_weights = torch.nan_to_num(concept_weights)

                    noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)

                    noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum

                    edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit

                    if warmup_inds.shape[0] == len(noise_pred_edit_concepts):
                        noise_guidance = noise_guidance + noise_guidance_edit
                        self.sem_guidance[i] = noise_guidance_edit.detach().cpu()

                if sem_guidance is not None:
                    edit_guidance = sem_guidance[i].to(self.device)
                    noise_guidance = noise_guidance + edit_guidance

                noise_pred = noise_pred_uncond + noise_guidance
            ## ddpm ###########################################################
            if use_ddpm:
                idx = t_to_idx[int(t)]
                z = zs[idx] if zs is not None else None

                # 1. get previous step value (=t-1)
                prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps

                # 2. compute alphas, betas
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                alpha_prod_t_prev = (
                    self.scheduler.alphas_cumprod[prev_timestep]
                    if prev_timestep >= 0
                    else self.scheduler.final_alpha_cumprod
                )
                beta_prod_t = 1 - alpha_prod_t
                beta_prod_t_prev = 1 - alpha_prod_t_prev

                # 3. compute predicted original sample from predicted noise, also called
                # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
                pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)

                # 4. compute variance: "sigma_t(η)" -> see formula (16)
                # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
                # variance = self.scheduler._get_variance(timestep, prev_timestep)
                variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
                std_dev_t = eta * variance ** (0.5)

                # Take care of asymmetric reverse process (asyrp)
                noise_pred_direction = noise_pred

                # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
                # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
                pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * noise_pred_direction

                # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
                prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
                latents = prev_sample

                # 7. add the pre-computed (or freshly sampled) noise if eta > 0
                if eta > 0:
                    if z is None:
                        z = torch.randn(noise_pred.shape, device=self.device)
                    sigma_z = eta * variance ** (0.5) * z
                    latents = prev_sample + sigma_z
            ## ddpm ##########################################################

            # compute the previous noisy sample x_t -> x_t-1
            if not use_ddpm:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 8. Post-processing
        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
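
For context, a minimal usage sketch of the added pipeline follows. It is not part of the commit: the checkpoint id, the scheduler swap, and the names `x_T`, `wts`, and `zs` are assumptions, and the DDPM inversion that produces them (an edit-friendly inversion of an input image) is not shown here.

# Hypothetical usage sketch; assumes this file is importable from the working directory
# and that `x_T` (starting latent), `wts`, and `zs` come from a prior DDPM inversion.
import torch
from diffusers import DDIMScheduler

from modified_pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Any Stable Diffusion v1.x checkpoint with matching components should load here;
# the model id is illustrative. The DDPM branch assumes a DDIM-style scheduler.
pipe = SemanticStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# `zs` is expected as a stacked tensor with one noise vector per denoising step;
# `eta > 0` is required for those vectors to be injected by the DDPM branch.
out = pipe(
    prompt="a photo of a cat",
    latents=x_T,                        # inverted starting latent (assumed)
    num_inference_steps=50,
    eta=1.0,
    use_ddpm=True,
    wts=wts,
    zs=zs,
    editing_prompt=["smiling"],         # SEGA edit concept
    reverse_editing_direction=[False],
    edit_guidance_scale=[7.0],
    edit_warmup_steps=[2],
    edit_threshold=[0.95],
)
edited_image = out.images[0]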