AlanB committed
Commit
7e55ab8
1 Parent(s): af04e1a

Fixed mistake

Files changed (1)
  1. pipeline.py +228 -232
pipeline.py CHANGED
@@ -1,78 +1,94 @@
- """
- modified based on diffusion library from Huggingface: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
- """
  import inspect
- import warnings
- from typing import Callable, List, Optional, Union

  import torch

- from diffusers.models import AutoencoderKL, UNet2DConditionModel
- from diffusers.pipeline_utils import DiffusionPipeline
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
- from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
- from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
-
-
- class ComposableStableDiffusionPipeline(DiffusionPipeline):
- r"""
- Pipeline for text-to-image generation using Stable Diffusion.
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
- Args:
- vae ([`AutoencoderKL`]):
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
- text_encoder ([`CLIPTextModel`]):
- Frozen text-encoder. Stable Diffusion uses the text portion of
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
- tokenizer (`CLIPTokenizer`):
- Tokenizer of class
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
- scheduler ([`SchedulerMixin`]):
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
- safety_checker ([`StableDiffusionSafetyChecker`]):
- Classification module that estimates whether generated images could be considered offsensive or harmful.
- Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
- feature_extractor ([`CLIPFeatureExtractor`]):
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
  """

  def __init__(
  self,
  vae: AutoencoderKL,
  text_encoder: CLIPTextModel,
  tokenizer: CLIPTokenizer,
  unet: UNet2DConditionModel,
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler],
- safety_checker: StableDiffusionSafetyChecker,
  feature_extractor: CLIPFeatureExtractor,
  ):
  super().__init__()
  self.register_modules(
  vae=vae,
  text_encoder=text_encoder,
  tokenizer=tokenizer,
  unet=unet,
  scheduler=scheduler,
- safety_checker=safety_checker,
  feature_extractor=feature_extractor,
  )

  def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
- r"""
- Enable sliced attention computation.
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
- Args:
- slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
- a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
- `attention_head_dim` must be a multiple of `slice_size`.
- """
  if slice_size == "auto":
  if isinstance(self.unet.config.attention_head_dim, int):
  # half the attention head size is usually a good trade-off between
@@ -81,30 +97,92 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
  else:
  # if `attention_head_dim` is a list, take the smallest head size
  slice_size = min(self.unet.config.attention_head_dim)
  self.unet.set_attention_slice(slice_size)

  def disable_attention_slicing(self):
- r"""
- Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
- back to computing attention in one step.
- """
- # set slice_size = `None` to disable `attention slicing`
  self.enable_attention_slicing(None)

- def enable_vae_slicing(self):
- r"""
- Enable sliced VAE decoding.
- When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
- steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.vae.enable_slicing()
-
- def disable_vae_slicing(self):
- r"""
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
- computing decoding in one step.
- """
- self.vae.disable_slicing()

  @torch.no_grad()
  def __call__(
@@ -114,76 +192,17 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
  width: Optional[int] = 512,
  num_inference_steps: Optional[int] = 50,
  guidance_scale: Optional[float] = 7.5,
- eta: Optional[float] = 0.0,
  generator: Optional[torch.Generator] = None,
  latents: Optional[torch.FloatTensor] = None,
  output_type: Optional[str] = "pil",
  return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
- callback_steps: Optional[int] = 1,
- weights: Optional[str] = "",
- **kwargs,
  ):
- r"""
- Function invoked when calling the pipeline for generation.
- Args:
- prompt (`str` or `List[str]`):
- The prompt or prompts to guide the image generation.
- height (`int`, *optional*, defaults to 512):
- The height in pixels of the generated image.
- width (`int`, *optional*, defaults to 512):
- The width in pixels of the generated image.
- num_inference_steps (`int`, *optional*, defaults to 50):
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
- expense of slower inference.
- guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
- eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
- generator (`torch.Generator`, *optional*):
- A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
- deterministic.
- latents (`torch.FloatTensor`, *optional*):
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
- output_type (`str`, *optional*, defaults to `"pil"`):
- The output format of the generate image. Choose between
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
- plain tuple.
- callback (`Callable`, *optional*):
- A function that will be called every `callback_steps` steps during inference. The function will be
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
- callback_steps (`int`, *optional*, defaults to 1):
- The frequency at which the `callback` function will be called. If not specified, the callback will be
- called at every step.
- Returns:
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
- When returning a tuple, the first element is a list with the generated images, and the second element is a
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
- (nsfw) content, according to the `safety_checker`.
- """
-
- if "torch_device" in kwargs:
- device = kwargs.pop("torch_device")
- warnings.warn(
- "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
- " Consider using `pipe.to(torch_device)` instead."
- )
-
- # Set device as before (to be removed in 0.3.0)
- if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
- self.to(device)
-
  if isinstance(prompt, str):
  batch_size = 1
  elif isinstance(prompt, list):
@@ -194,10 +213,6 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
  if height % 8 != 0 or width % 8 != 0:
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

- if "|" in prompt:
- prompt = [x.strip() for x in prompt.split("|")]
- print(f"composing {prompt}...")
-
  # get prompt text embeddings
  text_input = self.tokenizer(
  prompt,
@@ -207,39 +222,24 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
  return_tensors="pt",
  )
  text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
-
- if not weights:
- # specify weights for prompts (excluding the unconditional score)
- print("using equal weights for all prompts...")
- pos_weights = torch.tensor(
- [1 / (text_embeddings.shape[0] - 1)] * (text_embeddings.shape[0] - 1), device=self.device
- ).reshape(-1, 1, 1, 1)
- neg_weights = torch.tensor([1.0], device=self.device).reshape(-1, 1, 1, 1)
- mask = torch.tensor([False] + [True] * pos_weights.shape[0], dtype=torch.bool)
- else:
- # set prompt weight for each
- num_prompts = len(prompt) if isinstance(prompt, list) else 1
- weights = [float(w.strip()) for w in weights.split("|")]
- if len(weights) < num_prompts:
- weights.append(1.0)
- weights = torch.tensor(weights, device=self.device)
- assert len(weights) == text_embeddings.shape[0], "weights specified are not equal to the number of prompts"
- pos_weights = []
- neg_weights = []
- mask = [] # first one is unconditional score
- for w in weights:
- if w > 0:
- pos_weights.append(w)
- mask.append(True)
- else:
- neg_weights.append(abs(w))
- mask.append(False)
- # normalize the weights
- pos_weights = torch.tensor(pos_weights, device=self.device).reshape(-1, 1, 1, 1)
- pos_weights = pos_weights / pos_weights.sum()
- neg_weights = torch.tensor(neg_weights, device=self.device).reshape(-1, 1, 1, 1)
- neg_weights = neg_weights / neg_weights.sum()
- mask = torch.tensor(mask, device=self.device, dtype=torch.bool)

  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -248,40 +248,35 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
  # get unconditional embeddings for classifier free guidance
  if do_classifier_free_guidance:
  max_length = text_input.input_ids.shape[-1]

- if torch.all(mask):
- # no negative prompts, so we use empty string as the negative prompt
- uncond_input = self.tokenizer(
- [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
- )
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
-
- # For classifier free guidance, we need to do two forward passes.
- # Here we concatenate the unconditional and text embeddings into a single batch
- # to avoid doing two forward passes
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
- # update negative weights
- neg_weights = torch.tensor([1.0], device=self.device)
- mask = torch.tensor([False] + mask.detach().tolist(), device=self.device, dtype=torch.bool)

  # get the initial random noise unless the user supplied it

  # Unlike in other pipelines, latents need to be generated in the target device
  # for 1-to-1 results reproducibility with the CompVis implementation.
  # However this currently doesn't work in `mps`.
- latents_device = "cpu" if self.device.type == "mps" else self.device
- latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
  if latents is None:
- latents = torch.randn(
- latents_shape,
- generator=generator,
- device=latents_device,
- )
  else:
  if latents.shape != latents_shape:
  raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
- latents = latents.to(self.device)

  # set timesteps
  accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
@@ -291,9 +286,12 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):

  self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

- # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- latents = latents * self.scheduler.sigmas[0]

  # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
  # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -304,41 +302,43 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
  if accepts_eta:
  extra_step_kwargs["eta"] = eta

- for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
  # expand the latents if we are doing classifier free guidance
- latent_model_input = (
- torch.cat([latents] * text_embeddings.shape[0]) if do_classifier_free_guidance else latents
- )
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- sigma = self.scheduler.sigmas[i]
- # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
- latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
-
- # reduce memory by predicting each score sequentially
- noise_preds = []
  # predict the noise residual
- for latent_in, text_embedding_in in zip(
- torch.chunk(latent_model_input, chunks=latent_model_input.shape[0], dim=0),
- torch.chunk(text_embeddings, chunks=text_embeddings.shape[0], dim=0),
- ):
- noise_preds.append(self.unet(latent_in, t, encoder_hidden_states=text_embedding_in).sample)
- noise_preds = torch.cat(noise_preds, dim=0)
-
- # perform guidance
  if do_classifier_free_guidance:
- noise_pred_uncond = (noise_preds[~mask] * neg_weights).sum(dim=0, keepdims=True)
- noise_pred_text = (noise_preds[mask] * pos_weights).sum(dim=0, keepdims=True)
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

  # compute the previous noisy sample x_t -> x_t-1
- if isinstance(self.scheduler, LMSDiscreteScheduler):
- latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
- else:
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
- # call the callback, if provided
- if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)

  # scale and decode the image latents with vae
  latents = 1 / 0.18215 * latents
@@ -347,14 +347,10 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
  image = (image / 2 + 0.5).clamp(0, 1)
  image = image.cpu().permute(0, 2, 3, 1).numpy()

- # run safety checker
- safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
- image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
-
  if output_type == "pil":
  image = self.numpy_to_pil(image)

  if not return_dict:
- return (image, has_nsfw_concept)

- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
  import inspect
+ from typing import List, Optional, Union

  import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UNet2DConditionModel,
+ )
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+ from torchvision import transforms
+ from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
+
+
+ class MakeCutouts(nn.Module):
+ def __init__(self, cut_size, cut_power=1.0):
+ super().__init__()
+
+ self.cut_size = cut_size
+ self.cut_power = cut_power
+
+ def forward(self, pixel_values, num_cutouts):
+ sideY, sideX = pixel_values.shape[2:4]
+ max_size = min(sideX, sideY)
+ min_size = min(sideX, sideY, self.cut_size)
+ cutouts = []
+ for _ in range(num_cutouts):
+ size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
+ offsetx = torch.randint(0, sideX - size + 1, ())
+ offsety = torch.randint(0, sideY - size + 1, ())
+ cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]
+ cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+ return torch.cat(cutouts)
+
+
+ def spherical_dist_loss(x, y):
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+ def set_requires_grad(model, value):
+ for param in model.parameters():
+ param.requires_grad = value

+
+ class CLIPGuidedStableDiffusion(DiffusionPipeline):
+ """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
+ - https://github.com/Jack000/glid-3-xl
+ - https://github.dev/crowsonkb/k-diffusion
  """

  def __init__(
  self,
  vae: AutoencoderKL,
  text_encoder: CLIPTextModel,
+ clip_model: CLIPModel,
  tokenizer: CLIPTokenizer,
  unet: UNet2DConditionModel,
+ scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
  feature_extractor: CLIPFeatureExtractor,
  ):
  super().__init__()
  self.register_modules(
  vae=vae,
  text_encoder=text_encoder,
+ clip_model=clip_model,
  tokenizer=tokenizer,
  unet=unet,
  scheduler=scheduler,
  feature_extractor=feature_extractor,
  )

+ self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
+ cut_out_size = (
+ feature_extractor.size
+ if isinstance(feature_extractor.size, int)
+ else feature_extractor.size["shortest_edge"]
+ )
+ self.make_cutouts = MakeCutouts(cut_out_size)
+
+ set_requires_grad(self.text_encoder, False)
+ set_requires_grad(self.clip_model, False)
+
  def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
  if slice_size == "auto":
  if isinstance(self.unet.config.attention_head_dim, int):
  # half the attention head size is usually a good trade-off between

  else:
  # if `attention_head_dim` is a list, take the smallest head size
  slice_size = min(self.unet.config.attention_head_dim)
+
  self.unet.set_attention_slice(slice_size)

  def disable_attention_slicing(self):
  self.enable_attention_slicing(None)

+ def freeze_vae(self):
+ set_requires_grad(self.vae, False)
+
+ def unfreeze_vae(self):
+ set_requires_grad(self.vae, True)
+
+ def freeze_unet(self):
+ set_requires_grad(self.unet, False)
+
+ def unfreeze_unet(self):
+ set_requires_grad(self.unet, True)
+
+ @torch.enable_grad()
+ def cond_fn(
+ self,
+ latents,
+ timestep,
+ index,
+ text_embeddings,
+ noise_pred_original,
+ text_embeddings_clip,
+ clip_guidance_scale,
+ num_cutouts,
+ use_cutouts=True,
+ ):
+ latents = latents.detach().requires_grad_()
+
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[index]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
+ else:
+ latent_model_input = latents
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
+
+ if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)):
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ # compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
+
+ fac = torch.sqrt(beta_prod_t)
+ sample = pred_original_sample * (fac) + latents * (1 - fac)
+ elif isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[index]
+ sample = latents - sigma * noise_pred
+ else:
+ raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
+
+ sample = 1 / 0.18215 * sample
+ image = self.vae.decode(sample).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ if use_cutouts:
+ image = self.make_cutouts(image, num_cutouts)
+ else:
+ image = transforms.Resize(self.feature_extractor.size)(image)
+ image = self.normalize(image).to(latents.dtype)
+
+ image_embeddings_clip = self.clip_model.get_image_features(image)
+ image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+
+ if use_cutouts:
+ dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
+ dists = dists.view([num_cutouts, sample.shape[0], -1])
+ loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
+ else:
+ loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
+
+ grads = -torch.autograd.grad(loss, latents)[0]
+
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents.detach() + grads * (sigma**2)
+ noise_pred = noise_pred_original
+ else:
+ noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
+ return noise_pred, latents

  @torch.no_grad()
  def __call__(
  width: Optional[int] = 512,
  num_inference_steps: Optional[int] = 50,
  guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ clip_guidance_scale: Optional[float] = 100,
+ clip_prompt: Optional[Union[str, List[str]]] = None,
+ num_cutouts: Optional[int] = 4,
+ use_cutouts: Optional[bool] = True,
  generator: Optional[torch.Generator] = None,
  latents: Optional[torch.FloatTensor] = None,
  output_type: Optional[str] = "pil",
  return_dict: bool = True,
  ):
  if isinstance(prompt, str):
  batch_size = 1
  elif isinstance(prompt, list):

  if height % 8 != 0 or width % 8 != 0:
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

  # get prompt text embeddings
  text_input = self.tokenizer(
  prompt,

  return_tensors="pt",
  )
  text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if clip_guidance_scale > 0:
+ if clip_prompt is not None:
+ clip_text_input = self.tokenizer(
+ clip_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids.to(self.device)
+ else:
+ clip_text_input = text_input.input_ids.to(self.device)
+ text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
+ text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+ # duplicate text embeddings clip for each generation per prompt
+ text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)

  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
  # get unconditional embeddings for classifier free guidance
  if do_classifier_free_guidance:
  max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ # duplicate unconditional embeddings for each generation per prompt
+ uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

  # get the initial random noise unless the user supplied it

  # Unlike in other pipelines, latents need to be generated in the target device
  # for 1-to-1 results reproducibility with the CompVis implementation.
  # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
  if latents is None:
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
  else:
  if latents.shape != latents_shape:
  raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)

  # set timesteps
  accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())


  self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma

  # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
  # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.

  if accepts_eta:
  extra_step_kwargs["eta"] = eta

+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
  # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
  # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform classifier free guidance
  if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

+ # perform clip guidance
+ if clip_guidance_scale > 0:
+ text_embeddings_for_guidance = (
+ text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
+ )
+ noise_pred, latents = self.cond_fn(
+ latents,
+ t,
+ i,
+ text_embeddings_for_guidance,
+ noise_pred,
+ text_embeddings_clip,
+ clip_guidance_scale,
+ num_cutouts,
+ use_cutouts,
+ )
+
  # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

  # scale and decode the image latents with vae
  latents = 1 / 0.18215 * latents

  image = (image / 2 + 0.5).clamp(0, 1)
  image = image.cpu().permute(0, 2, 3, 1).numpy()

  if output_type == "pil":
  image = self.numpy_to_pil(image)

  if not return_dict:
+ return (image, None)

+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
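
For reference, a minimal usage sketch of the CLIPGuidedStableDiffusion class added in this commit. It is an assumption-laden illustration, not part of the change itself: the model IDs come from the checkpoints referenced in the file's docstrings (CompVis/stable-diffusion-v1-4 and openai/clip-vit-large-patch14), and the `from pipeline import ...` line assumes the script runs next to this repo's pipeline.py.

# Hypothetical usage sketch for the pipeline defined in pipeline.py (not part of the commit).
import torch
from diffusers import StableDiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPModel

from pipeline import CLIPGuidedStableDiffusion  # assumes pipeline.py is on the import path

# Assumed checkpoints, taken from the URLs mentioned in the docstrings above.
clip_id = "openai/clip-vit-large-patch14"
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# Reuse the Stable Diffusion components and add the CLIP model used for guidance.
guided_pipe = CLIPGuidedStableDiffusion(
    vae=sd_pipe.vae,
    text_encoder=sd_pipe.text_encoder,
    clip_model=CLIPModel.from_pretrained(clip_id),
    tokenizer=sd_pipe.tokenizer,
    unet=sd_pipe.unet,
    scheduler=sd_pipe.scheduler,
    feature_extractor=CLIPFeatureExtractor.from_pretrained(clip_id),
)
guided_pipe = guided_pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# Generate one image with both classifier-free guidance and CLIP guidance enabled.
output = guided_pipe(
    prompt="a photograph of an astronaut riding a horse",
    num_inference_steps=50,
    guidance_scale=7.5,
    clip_guidance_scale=100,
    num_cutouts=4,
    use_cutouts=True,
    generator=torch.Generator(device=guided_pipe.device).manual_seed(0),
)
output.images[0].save("astronaut.png")

Setting clip_guidance_scale=0 skips cond_fn entirely, which reduces the call to the plain classifier-free-guidance loop; note that this pipeline returns nsfw_content_detected=None because the safety checker was removed from the module list in this version of the file.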