AlekseyCalvin committed
Commit: 3b316a5
Parent: f83b4fd

Update pipeline.py

Files changed (1): pipeline.py (+6, -1)
pipeline.py CHANGED
@@ -42,6 +42,7 @@ if is_torch_xla_available():
 else:
     XLA_AVAILABLE = False
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 # Constants for shift calculation
 BASE_SEQ_LEN = 256
@@ -188,6 +189,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                     f" {self.tokenizer_max_length} tokens: {removed_text}"
                 )
             prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+            pooled_prompt_embeds = prompt_embeds[0]
 
         # Use pooled output of CLIPTextModel
         prompt_embeds = prompt_embeds.pooler_output
@@ -196,8 +198,11 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
 
-        return prompt_embeds
+        return prompt_embeds, pooled_prompt_embeds
 
     def encode_prompt(
         self,
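
For context, below is a minimal, self-contained sketch of the pattern the added lines follow: run a CLIP text encoder once, keep both the per-token hidden states (outputs[0] in the diff) and the pooled output, then duplicate each for num_images_per_prompt with the repeat/view idiom used above. The openai/clip-vit-base-patch32 checkpoint, the variable names, and the printed shapes are illustrative assumptions, not taken from pipeline.py; unlike the diff, the final view for the per-token tensor here keeps the sequence dimension explicit.

# Illustrative sketch only; the checkpoint and names below are assumptions, not from pipeline.py.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

prompt = ["a photo of an astronaut riding a horse"]
batch_size = len(prompt)
num_images_per_prompt = 2

text_inputs = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    outputs = text_encoder(text_inputs.input_ids, output_hidden_states=False)

token_embeds = outputs.last_hidden_state  # (batch, seq_len, dim); this is outputs[0] in the diff
pooled_embeds = outputs.pooler_output     # (batch, dim); the "pooled output of CLIPTextModel"

# Duplicate per generation using the repeat/view ("mps friendly") idiom from the diff.
pooled_embeds = pooled_embeds.repeat(1, num_images_per_prompt)
pooled_embeds = pooled_embeds.view(batch_size * num_images_per_prompt, -1)

seq_len = token_embeds.shape[1]
token_embeds = token_embeds.repeat(1, num_images_per_prompt, 1)
token_embeds = token_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

print(pooled_embeds.shape)  # torch.Size([2, 512])
print(token_embeds.shape)   # torch.Size([2, 77, 512])

The repeat-then-view duplication is preferred over repeat_interleave because the latter has reportedly been unreliable on Apple's MPS backend, which is what the "mps friendly method" comment in the file refers to.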