Spaces:
Sleeping
Sleeping
update LatentNoiseTrainer class to include iteration callbacks
Browse files- training/trainer.py +4 -1
training/trainer.py
CHANGED
@@ -51,6 +51,7 @@ class LatentNoiseTrainer:
|
|
51 |
prompt: str,
|
52 |
optimizer: torch.optim.Optimizer,
|
53 |
save_dir: Optional[str] = None,
|
|
|
54 |
) -> Tuple[PIL.Image.Image, Dict[str, float], Dict[str, float]]:
|
55 |
logging.info(f"Optimizing latents for prompt '{prompt}'.")
|
56 |
best_loss = torch.inf
|
@@ -120,6 +121,8 @@ class LatentNoiseTrainer:
|
|
120 |
image_numpy = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
121 |
image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
|
122 |
image_pil.save(f"{save_dir}/{iteration}.png")
|
|
|
|
|
123 |
image_numpy = best_image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
124 |
image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
|
125 |
-
return image_pil, initial_rewards, best_rewards
|
|
|
51 |
prompt: str,
|
52 |
optimizer: torch.optim.Optimizer,
|
53 |
save_dir: Optional[str] = None,
|
54 |
+
progress_callback=None,
|
55 |
) -> Tuple[PIL.Image.Image, Dict[str, float], Dict[str, float]]:
|
56 |
logging.info(f"Optimizing latents for prompt '{prompt}'.")
|
57 |
best_loss = torch.inf
|
|
|
121 |
image_numpy = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
122 |
image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
|
123 |
image_pil.save(f"{save_dir}/{iteration}.png")
|
124 |
+
if progress_callback:
|
125 |
+
progress_callback(iteration + 1)
|
126 |
image_numpy = best_image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
127 |
image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
|
128 |
+
return image_pil, initial_rewards, best_rewards
|