Spaces: Running on Zero
AlekseyCalvin committed
Commit • 08a702c
1 Parent(s): 18bd703
Update pipeline.py
pipeline.py +12 −27
pipeline.py CHANGED
@@ -625,35 +625,20 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
                 if torch.backends.mps.is_available():
                     # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                     latents = latents.to(latents_dtype)
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
 
+                # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-
-                if XLA_AVAILABLE:
-                    xm.mark_step()
-
-        if output_type == "latent":
-            image = latents
-
-        else:
-            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents, return_dict=False)[0]
-            image = self.image_processor.postprocess(image, output_type=output_type)
 
-        #
+        # Final image
+        return self._decode_latents_to_image(latents, height, width, output_type)
         self.maybe_free_model_hooks()
-
-
-
-
-
+        torch.cuda.empty_cache()
+
+    def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
+        """Decodes the given latents into an image."""
+        vae = vae or self.vae
+        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
+        image = vae.decode(latents, return_dict=False)[0]
+        return self.image_processor.postprocess(image, output_type=output_type)[0]
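For context, the extracted helper can be exercised on its own. The sketch below mirrors the logic of the new `_decode_latents_to_image` method as a standalone function; it is not part of the commit. The names `pipe` (a loaded `FluxWithCFGPipeline`) and `decode_latents` are assumptions introduced for illustration, while `_unpack_latents`, `vae`, and `image_processor` all appear in the diff above.

```python
# Minimal sketch, assuming `pipe` is a loaded FluxWithCFGPipeline and
# `latents` holds the packed latents produced by its denoising loop.
# This replicates the new _decode_latents_to_image helper outside the class.
import torch

@torch.no_grad()
def decode_latents(pipe, latents, height, width, output_type="pil"):
    # Unpack Flux's patch-packed latents back into (B, C, H/8, W/8) form
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    # Undo the VAE's latent normalization (scale, then shift) before decoding
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    # postprocess() returns a list of images; [0] selects the single result,
    # matching the trailing [0] the commit adds to the helper's return value
    return pipe.image_processor.postprocess(image, output_type=output_type)[0]
```

Note the design change the diff encodes: the `output_type == "latent"` branch and the inline VAE decode are replaced by a single `return` through the helper, and `torch.cuda.empty_cache()` is added after `maybe_free_model_hooks()` to release cached GPU memory between calls.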