Re-enable TQDM (reverts #11) (#17)
Browse files- Re-enable TQDM (reverts #11) (804e7faecd216ad98fdaa5a16ac0112dc1f8b79f)
Co-authored-by: Charles Bensimon <[email protected]>
- pipeline.py +44 -43
pipeline.py
CHANGED
@@ -398,50 +398,51 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
|
|
398 |
|
399 |
# 11. Denoising loop
|
400 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
)
|
406 |
-
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
407 |
-
|
408 |
-
if i <= start_merge_step:
|
409 |
-
current_prompt_embeds = torch.cat(
|
410 |
-
[negative_prompt_embeds, prompt_embeds_text_only], dim=0
|
411 |
-
)
|
412 |
-
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0)
|
413 |
-
else:
|
414 |
-
current_prompt_embeds = torch.cat(
|
415 |
-
[negative_prompt_embeds, prompt_embeds], dim=0
|
416 |
)
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
445 |
|
446 |
# make sure the VAE is in float32 mode, as it overflows in float16
|
447 |
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
|
|
398 |
|
399 |
# 11. Denoising loop
|
400 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
401 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
402 |
+
for i, t in enumerate(timesteps):
|
403 |
+
latent_model_input = (
|
404 |
+
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
405 |
)
|
406 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
407 |
+
|
408 |
+
if i <= start_merge_step:
|
409 |
+
current_prompt_embeds = torch.cat(
|
410 |
+
[negative_prompt_embeds, prompt_embeds_text_only], dim=0
|
411 |
+
)
|
412 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0)
|
413 |
+
else:
|
414 |
+
current_prompt_embeds = torch.cat(
|
415 |
+
[negative_prompt_embeds, prompt_embeds], dim=0
|
416 |
+
)
|
417 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
418 |
+
# predict the noise residual
|
419 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
420 |
+
noise_pred = self.unet(
|
421 |
+
latent_model_input,
|
422 |
+
t,
|
423 |
+
encoder_hidden_states=current_prompt_embeds,
|
424 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
425 |
+
added_cond_kwargs=added_cond_kwargs,
|
426 |
+
return_dict=False,
|
427 |
+
)[0]
|
428 |
+
|
429 |
+
# perform guidance
|
430 |
+
if do_classifier_free_guidance:
|
431 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
432 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
433 |
+
|
434 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
435 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
436 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
437 |
+
|
438 |
+
# compute the previous noisy sample x_t -> x_t-1
|
439 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
440 |
+
|
441 |
+
# call the callback, if provided
|
442 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
443 |
+
progress_bar.update()
|
444 |
+
if callback is not None and i % callback_steps == 0:
|
445 |
+
callback(i, t, latents)
|
446 |
|
447 |
# make sure the VAE is in float32 mode, as it overflows in float16
|
448 |
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|