AlanB committed
Commit 78167cb
1 Parent(s): 4afe117

Add TextualInversion to processed prompts. Other updates.

Files changed (1):
  1. pipeline.py +18 -9
pipeline.py CHANGED
@@ -6,13 +6,14 @@ import numpy as np
 import PIL
 import torch
 from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 import random
 import sys
 from tqdm.auto import tqdm
 
 import diffusers
 from diffusers import SchedulerMixin, StableDiffusionPipeline
+from diffusers.loaders import TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.utils import logging
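
Note: `CLIPFeatureExtractor` was renamed to `CLIPImageProcessor` in newer transformers releases (the old name survives only as a deprecated alias), which is what the import change above tracks. For code that must also run on older transformers, a minimal compatibility shim (not part of this commit, shown only as a sketch) could look like:

    try:
        from transformers import CLIPImageProcessor
    except ImportError:
        # older transformers releases only ship the pre-rename class
        from transformers import CLIPFeatureExtractor as CLIPImageProcessor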
@@ -182,14 +183,14 @@ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], m
     return tokens, weights
 
 
-def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
     r"""
     Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
     """
     max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
     weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
     for i in range(len(tokens)):
-        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
+        tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
         if no_boseos_middle:
             weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
         else:
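
Note: the rewritten padding line fills with the tokenizer's pad token instead of repeating `eos`, and keeps a single `eos` at the very end of the chunk. A worked sketch of the resulting layout (token ids are invented for illustration):

    bos, eos, pad = 49406, 49407, 0  # hypothetical ids
    tokens = [271, 272, 273]         # a 3-token prompt
    max_length = 10
    padded = [bos] + tokens + [pad] * (max_length - 1 - len(tokens) - 1) + [eos]
    assert padded == [49406, 271, 272, 273, 0, 0, 0, 0, 0, 49407]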
@@ -320,12 +321,14 @@ def get_weighted_text_embeddings(
     # pad the length of tokens and weights
     bos = pipe.tokenizer.bos_token_id
     eos = pipe.tokenizer.eos_token_id
+    pad = getattr(pipe.tokenizer, "pad_token_id", eos)
     prompt_tokens, prompt_weights = pad_tokens_and_weights(
         prompt_tokens,
         prompt_weights,
         max_length,
         bos,
         eos,
+        pad,
         no_boseos_middle=no_boseos_middle,
         chunk_length=pipe.tokenizer.model_max_length,
     )
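
Note: `getattr` falls back to `eos` only when the tokenizer lacks a `pad_token_id` attribute entirely. For SD 1.x-style CLIP tokenizers the pad token is the end-of-text token anyway, so behavior is unchanged there; tokenizers with a distinct pad id now pad correctly. A quick check, assuming hub access:

    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    eos = tokenizer.eos_token_id
    pad = getattr(tokenizer, "pad_token_id", eos)
    print(pad == eos)  # True for this SD 1.x-style tokenizer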
@@ -337,6 +340,7 @@ def get_weighted_text_embeddings(
         max_length,
         bos,
         eos,
+        pad,
         no_boseos_middle=no_boseos_middle,
         chunk_length=pipe.tokenizer.model_max_length,
     )
@@ -379,7 +383,7 @@ def get_weighted_text_embeddings(
 
 def preprocess_image(image):
     w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
@@ -390,7 +394,7 @@ def preprocess_image(image):
 def preprocess_mask(mask, scale_factor=8):
     mask = mask.convert("L")
     w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))
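
Note: the generator expression is a behavior-preserving rewrite of the `map(lambda ...)` form; both floor each dimension to the nearest multiple of 32, e.g.:

    w, h = 513, 768
    w, h = (x - x % 32 for x in (w, h))
    assert (w, h) == (512, 768)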
@@ -425,7 +429,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
 
@@ -439,7 +443,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: SchedulerMixin,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__(
@@ -464,7 +468,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: SchedulerMixin,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
     ):
         super().__init__(
             vae=vae,
@@ -538,6 +542,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
                 f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                 " the batch size of `prompt`."
             )
+
+        # textual inversion: process multi-vector tokens if necessary
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+            negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
 
         text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
             pipe=self,
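
Note: this hunk is the "TextualInversion" part of the commit message. `maybe_convert_prompt` expands a multi-vector placeholder such as `<concept>` into `<concept> <concept>_1 ...` before tokenization, so weighted prompts pick up every learned vector. A usage sketch (the concept repo is only an example):

    pipe.load_textual_inversion("sd-concepts-library/cat-toy")
    prompt = pipe.maybe_convert_prompt("a photo of <cat-toy>", pipe.tokenizer)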
@@ -627,7 +636,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         if image is None:
             shape = (
                 batch_size,
-                self.unet.in_channels,
+                self.unet.config.in_channels,
                 height // self.vae_scale_factor,
                 width // self.vae_scale_factor,
             )
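
Note: newer diffusers releases warn when architecture fields are read directly off the module (`unet.in_channels` is deprecated); the config is the supported access path:

    in_channels = pipe.unet.config.in_channels  # 4 for standard SD latents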
 