skytnt commited on
Commit
0d2a17a
1 Parent(s): 28497e2
Files changed (1) hide show
  1. pipeline.py +2 -2
pipeline.py CHANGED
@@ -318,11 +318,11 @@ def get_weighted_text_embeddings(
318
  if (not skip_parsing) and (not skip_weighting):
319
  previous_mean = text_embeddings.mean(axis=[-2, -1])
320
  text_embeddings *= prompt_weights.unsqueeze(-1)
321
- text_embeddings *= previous_mean / text_embeddings.mean(axis=[-2, -1])
322
  if uncond_prompt is not None:
323
  previous_mean = uncond_embeddings.mean(axis=[-2, -1])
324
  uncond_embeddings *= uncond_weights.unsqueeze(-1)
325
- uncond_embeddings *= previous_mean / uncond_embeddings.mean(axis=[-2, -1])
326
 
327
  if uncond_prompt is not None:
328
  return text_embeddings, uncond_embeddings
 
318
  if (not skip_parsing) and (not skip_weighting):
319
  previous_mean = text_embeddings.mean(axis=[-2, -1])
320
  text_embeddings *= prompt_weights.unsqueeze(-1)
321
+ text_embeddings *= (previous_mean / text_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
322
  if uncond_prompt is not None:
323
  previous_mean = uncond_embeddings.mean(axis=[-2, -1])
324
  uncond_embeddings *= uncond_weights.unsqueeze(-1)
325
+ uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
326
 
327
  if uncond_prompt is not None:
328
  return text_embeddings, uncond_embeddings