skytnt committed
Commit 8896f9f
1 Parent(s): 0d2a17a

fix negative_prompt=None

Files changed (1):
  pipeline.py  +13 −18
pipeline.py CHANGED

@@ -550,6 +550,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 
         if isinstance(prompt, str):
             batch_size = 1
+            prompt = [prompt]
         elif isinstance(prompt, list):
             batch_size = len(prompt)
         else:
@@ -576,28 +577,21 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
-        uncond_tokens = [""]
-        if do_classifier_free_guidance:
-            if type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = negative_prompt
+        if negative_prompt is None:
+            negative_prompt = [""] * batch_size
+        elif isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+        if batch_size != len(negative_prompt):
+            raise ValueError(
+                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                " the batch size of `prompt`."
+            )
 
         text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
             pipe=self,
             prompt=prompt,
-            uncond_prompt=uncond_tokens if do_classifier_free_guidance else None,
+            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
             max_embeddings_multiples=max_embeddings_multiples,
             **kwargs
         )
@@ -682,6 +676,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
                 )
             else:
                 noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+            print(timesteps.shape)
             latents = self.scheduler.add_noise(init_latents, noise, timesteps)
 
             t_start = max(num_inference_steps - init_timestep + offset, 0)
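
For illustration, the normalization this commit introduces can be read as a standalone function. This is a hypothetical helper (`normalize_negative_prompt` is not part of the commit); it mirrors the diff's logic: default `None` to empty strings, broadcast a single string across the batch, and validate batch sizes.

def normalize_negative_prompt(prompt, negative_prompt):
    # Mirrors the commit's logic; `normalize_negative_prompt` is an
    # illustrative name, not a function in pipeline.py.
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    if negative_prompt is None:
        # The old code raised TypeError here because type(None) != type(prompt).
        negative_prompt = [""] * batch_size
    elif isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * batch_size
    if batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt` has batch size {len(negative_prompt)}, "
            f"but `prompt` has batch size {batch_size}."
        )
    return negative_prompt

assert normalize_negative_prompt("a cat", None) == [""]
assert normalize_negative_prompt(["a", "b"], "lowres") == ["lowres", "lowres"]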
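A minimal sketch of how the fixed code path is exercised, assuming this pipeline.py is loaded as a diffusers custom pipeline (the `custom_pipeline` value and model id below are illustrative assumptions, not part of this commit):

import torch
from diffusers import DiffusionPipeline

# Assumptions: "lpw_stable_diffusion" is the custom pipeline pointing at this
# file, and "runwayml/stable-diffusion-v1-5" is an illustrative base model.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")  # assumes a CUDA device is available

# Before this commit, omitting `negative_prompt` hit the old
# `type(prompt) is not type(negative_prompt)` check and raised a TypeError;
# after the fix, None is normalized to [""] * batch_size.
image = pipe(prompt="an astronaut riding a horse", negative_prompt=None).images[0]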