AlekseyCalvin committed
Commit 71527d9
1 Parent(s): ff92498

Update pipeline.py

Files changed (1)
  1. pipeline.py +218 -67
pipeline.py CHANGED
@@ -223,27 +223,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         clip_skip: Optional[int] = None,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
-    ):
-        r"""
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-            prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
-                used in all text-encoders
-            device: (`torch.device`):
-                torch device
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
-                If not provided, pooled text embeddings will be generated from `prompt` input argument.
-            lora_scale (`float`, *optional*):
-                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
-        """
+    ):
         device = device or self._execution_device
 
         if device is None:
@@ -297,7 +277,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                 batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
             )
 
-            if prompt is not None and type(prompt) is not type(negative_prompt):
+            if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
@@ -309,29 +289,29 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                     " the batch size of `prompt`."
                 )
 
-            negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
-                negative_prompt,
-                device=device,
-                num_images_per_prompt=num_images_per_prompt,
-                clip_skip=None,
+            negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
+                negative_prompt,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                clip_skip=None,
             )
-            negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
+            negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
 
-            t5_negative_prompt_embed = self._get_t5_prompt_embeds(
-                prompt=negative_prompt_2,
-                num_images_per_prompt=num_images_per_prompt,
-                max_sequence_length=max_sequence_length,
-                device=device,
+            t5_negative_prompt_embed = self._get_t5_prompt_embeds(
+                prompt=negative_prompt_2,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
             )
 
-            negative_clip_prompt_embeds = torch.nn.functional.pad(
+            negative_clip_prompt_embeds = torch.nn.functional.pad(
                 negative_clip_prompt_embeds,
                 (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
             )
 
-            negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
-            negative_pooled_prompt_embeds = torch.cat(
-                [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
+            negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
+            negative_pooled_prompt_embeds = torch.cat(
+                [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
             )
 
             if self.text_encoder is not None:
@@ -343,26 +323,8 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
-    def prepare_extra_step_kwargs(self, generator, eta):
-        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
-        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
-        # and should be between [0, 1]
-
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_step_kwargs = {}
-        if accepts_eta:
-            extra_step_kwargs["eta"] = eta
-
-        # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        if accepts_generator:
-            extra_step_kwargs["generator"] = generator
-        return extra_step_kwargs
-
-    def check_inputs(
+
+    def check_inputs(
         self,
         prompt,
         prompt_2,
@@ -464,6 +426,23 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
 
         return latents
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
 
     def enable_vae_slicing(self):
         r"""
@@ -546,7 +525,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
     @property
     def interrupt(self):
         return self._interrupt
-
+
     @torch.no_grad()
    @torch.inference_mode()
     def generate_image(
@@ -652,6 +631,178 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         # Handle guidance
         guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
 
+        # 6. Denoising loop
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                noise_pred = self.transformer(
+                    hidden_states=latent_model_input,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                noise_pred_uncond = self.transformer(
+                    hidden_states=latents,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=negative_pooled_prompt_embeds,
+                    encoder_hidden_states=negative_prompt_embeds,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                latents_dtype = latents.dtype
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                # Yield intermediate result
+                torch.cuda.empty_cache()
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        # Final image
+        return self._decode_latents_to_image(latents, height, width, output_type)
+        self.maybe_free_model_hooks()
+        torch.cuda.empty_cache()
+
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        num_inference_steps: int = 8,
+        timesteps: List[int] = None,
+        eta: float = 0.0,
+        guidance_scale: float = 3.5,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        clip_skip: Optional[int] = None,
+        max_sequence_length: int = 300,
+    ):
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor
+
+        # 1. Check inputs
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            height,
+            width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale
+        )
+
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._joint_attention_kwargs = joint_attention_kwargs
+        self._interrupt = False
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        lora_scale = (
+            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+        )
+
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+        # 4. Prepare latent variables
+        num_channels_latents = self.transformer.config.in_channels // 4
+        latents, latent_image_ids = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            negative_prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 5. Prepare timesteps
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+        image_seq_len = latents.shape[1]
+        mu = calculate_timestep_shift(image_seq_len)
+        timesteps, num_inference_steps = prepare_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            timesteps,
+            sigmas,
+            mu=mu,
+        )
+        self._num_timesteps = len(timesteps)
+
+        # Handle guidance
+        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -694,18 +845,18 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                 # Yield intermediate result
                 torch.cuda.empty_cache()
 
-                if latents.dtype != latents_dtype:
-                    if torch.backends.mps.is_available():
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
-                        latents = latents.to(latents_dtype)
+                        latents = latents.to(latents_dtype)
 
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
                     for k in callback_on_step_end_tensor_inputs:
                         callback_kwargs[k] = locals()[k]
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
 
-                    latents = callback_outputs.pop("latents", latents)
+                    latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                     negative_pooled_prompt_embeds = callback_outputs.pop(
@@ -713,10 +864,10 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                     )
 
                 # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-
-        # Final image
+        # Final image
+
         return self._decode_latents_to_image(latents, height, width, output_type)
         self.maybe_free_model_hooks()
         torch.cuda.empty_cache()
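For reference, a minimal sketch of how the `__call__` method added in this commit might be invoked. The import path (`pipeline`), the checkpoint id, and the device are assumptions for illustration only and are not part of the commit:

    # Hedged usage sketch: assumes pipeline.py from this repo is importable and that a
    # FLUX-style checkpoint with the components this pipeline expects is available.
    import torch
    from pipeline import FluxWithCFGPipeline  # hypothetical import path

    pipe = FluxWithCFGPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",  # assumed checkpoint id
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")

    # guidance_scale > 1.0 takes the classifier-free guidance path, so negative_prompt is used.
    result = pipe(
        prompt="a watercolor painting of a lighthouse at dusk",
        negative_prompt="blurry, low contrast",
        num_inference_steps=8,
        guidance_scale=3.5,
        height=1024,
        width=1024,
    )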