fix
Browse files- pipeline.py +2 -2
pipeline.py
CHANGED
@@ -318,11 +318,11 @@ def get_weighted_text_embeddings(
|
|
318 |
if (not skip_parsing) and (not skip_weighting):
|
319 |
previous_mean = text_embeddings.mean(axis=[-2, -1])
|
320 |
text_embeddings *= prompt_weights.unsqueeze(-1)
|
321 |
-
text_embeddings *= previous_mean / text_embeddings.mean(axis=[-2, -1])
|
322 |
if uncond_prompt is not None:
|
323 |
previous_mean = uncond_embeddings.mean(axis=[-2, -1])
|
324 |
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
325 |
-
uncond_embeddings *= previous_mean / uncond_embeddings.mean(axis=[-2, -1])
|
326 |
|
327 |
if uncond_prompt is not None:
|
328 |
return text_embeddings, uncond_embeddings
|
|
|
318 |
if (not skip_parsing) and (not skip_weighting):
|
319 |
previous_mean = text_embeddings.mean(axis=[-2, -1])
|
320 |
text_embeddings *= prompt_weights.unsqueeze(-1)
|
321 |
+
text_embeddings *= (previous_mean / text_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
|
322 |
if uncond_prompt is not None:
|
323 |
previous_mean = uncond_embeddings.mean(axis=[-2, -1])
|
324 |
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
325 |
+
uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
|
326 |
|
327 |
if uncond_prompt is not None:
|
328 |
return text_embeddings, uncond_embeddings
|