AlekseyCalvin commited on
Commit
08a702c
1 Parent(s): 18bd703

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +12 -27
pipeline.py CHANGED
@@ -625,35 +625,20 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
625
  if torch.backends.mps.is_available():
626
  # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
627
  latents = latents.to(latents_dtype)
628
-
629
- if callback_on_step_end is not None:
630
- callback_kwargs = {}
631
- for k in callback_on_step_end_tensor_inputs:
632
- callback_kwargs[k] = locals()[k]
633
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
634
-
635
- latents = callback_outputs.pop("latents", latents)
636
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
637
 
 
638
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
639
  progress_bar.update()
640
-
641
- if XLA_AVAILABLE:
642
- xm.mark_step()
643
-
644
- if output_type == "latent":
645
- image = latents
646
-
647
- else:
648
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
649
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
650
- image = self.vae.decode(latents, return_dict=False)[0]
651
- image = self.image_processor.postprocess(image, output_type=output_type)
652
 
653
- # Offload all models
 
654
  self.maybe_free_model_hooks()
655
-
656
- if not return_dict:
657
- return (image,)
658
-
659
- return FluxPipelineOutput(images=image)
 
 
 
 
 
625
  if torch.backends.mps.is_available():
626
  # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
627
  latents = latents.to(latents_dtype)
 
 
 
 
 
 
 
 
 
628
 
629
+ # call the callback, if provided
630
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
631
  progress_bar.update()
 
 
 
 
 
 
 
 
 
 
 
 
632
 
633
+ # Final image
634
+ return self._decode_latents_to_image(latents, height, width, output_type)
635
  self.maybe_free_model_hooks()
636
+ torch.cuda.empty_cache()
637
+
638
+ def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
639
+ """Decodes the given latents into an image."""
640
+ vae = vae or self.vae
641
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
642
+ latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
643
+ image = vae.decode(latents, return_dict=False)[0]
644
+ return self.image_processor.postprocess(image, output_type=output_type)[0]