skytnt commited on
Commit
28497e2
1 Parent(s): 69d52d4

fix broken num_images_per_prompt

Browse files
Files changed (1) hide show
  1. pipeline.py +12 -7
pipeline.py CHANGED
@@ -324,13 +324,9 @@ def get_weighted_text_embeddings(
324
  uncond_embeddings *= uncond_weights.unsqueeze(-1)
325
  uncond_embeddings *= previous_mean / uncond_embeddings.mean(axis=[-2, -1])
326
 
327
- # For classifier free guidance, we need to do two forward passes.
328
- # Here we concatenate the unconditional and text embeddings into a single batch
329
- # to avoid doing two forward passes
330
  if uncond_prompt is not None:
331
- text_embeddings = torch.concat([uncond_embeddings, text_embeddings])
332
-
333
- return text_embeddings
334
 
335
 
336
  def preprocess_image(image):
@@ -598,13 +594,22 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
598
  else:
599
  uncond_tokens = negative_prompt
600
 
601
- text_embeddings = get_weighted_text_embeddings(
602
  pipe=self,
603
  prompt=prompt,
604
  uncond_prompt=uncond_tokens if do_classifier_free_guidance else None,
605
  max_embeddings_multiples=max_embeddings_multiples,
606
  **kwargs
607
  )
 
 
 
 
 
 
 
 
 
608
 
609
  # set timesteps
610
  self.scheduler.set_timesteps(num_inference_steps)
 
324
  uncond_embeddings *= uncond_weights.unsqueeze(-1)
325
  uncond_embeddings *= previous_mean / uncond_embeddings.mean(axis=[-2, -1])
326
 
 
 
 
327
  if uncond_prompt is not None:
328
+ return text_embeddings, uncond_embeddings
329
+ return text_embeddings, None
 
330
 
331
 
332
  def preprocess_image(image):
 
594
  else:
595
  uncond_tokens = negative_prompt
596
 
597
+ text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
598
  pipe=self,
599
  prompt=prompt,
600
  uncond_prompt=uncond_tokens if do_classifier_free_guidance else None,
601
  max_embeddings_multiples=max_embeddings_multiples,
602
  **kwargs
603
  )
604
+ bs_embed, seq_len, _ = text_embeddings.shape
605
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
606
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
607
+
608
+ if do_classifier_free_guidance:
609
+ bs_embed, seq_len, _ = uncond_embeddings.shape
610
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
611
+ uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
612
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
613
 
614
  # set timesteps
615
  self.scheduler.set_timesteps(num_inference_steps)