# ReNO: training/trainer.py
import logging
import math
from typing import Dict, List, Optional, Tuple
import PIL
import PIL.Image
import torch
from diffusers import DiffusionPipeline
from rewards import clip_img_transform
from rewards.base_reward import BaseRewardLoss


class LatentNoiseTrainer:
    """Trainer for optimizing latents with reward losses."""

    def __init__(
        self,
        reward_losses: List[BaseRewardLoss],
        model: DiffusionPipeline,
        n_iters: int,
        n_inference_steps: int,
        seed: int,
        no_optim: bool = False,
        regularize: bool = True,
        regularization_weight: float = 0.01,
        grad_clip: float = 0.1,
        log_metrics: bool = True,
        save_all_images: bool = False,
        imageselect: bool = False,
        device: torch.device = torch.device("cuda"),
    ):
        self.reward_losses = reward_losses
        self.model = model
        self.n_iters = n_iters
        self.n_inference_steps = n_inference_steps
        self.seed = seed
        self.no_optim = no_optim
        self.regularize = regularize
        self.regularization_weight = regularization_weight
        self.grad_clip = grad_clip
        self.log_metrics = log_metrics
        self.save_all_images = save_all_images
        self.imageselect = imageselect
        self.device = device
        self.preprocess_fn = clip_img_transform(224)

    def train(
        self,
        latents: torch.Tensor,
        prompt: str,
        optimizer: torch.optim.Optimizer,
        save_dir: Optional[str] = None,
        multi_apply_fn=None,
        progress_callback=None,
    ) -> Tuple[
        Optional[PIL.Image.Image],
        PIL.Image.Image,
        Optional[Dict[str, float]],
        Dict[str, float],
    ]:
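        """Optimize `latents` for `prompt`, scoring each candidate image.

        Returns a tuple (initial_image, best_image, initial_rewards,
        best_rewards). `initial_image` is only populated when `multi_apply_fn`
        is given; `initial_rewards` stays None when `no_optim` short-circuits
        the loop.
        """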
logging.info(f"Optimizing latents for prompt '{prompt}'.")
best_loss = torch.inf
best_image = None
initial_image = None
initial_rewards = None
best_rewards = None
best_latents = None
latent_dim = math.prod(latents.shape[1:])
for iteration in range(self.n_iters):
to_log = ""
rewards = {}
optimizer.zero_grad()
generator = torch.Generator("cuda").manual_seed(self.seed)
            if self.imageselect:
                # Image-selection mode: sample fresh latents each iteration and
                # simply keep the best-scoring image (no gradient updates).
                new_latents = torch.randn_like(
                    latents, device=self.device, dtype=latents.dtype
                )
                image = self.model.apply(
                    new_latents,
                    prompt,
                    generator=generator,
                    num_inference_steps=self.n_inference_steps,
                )
            else:
                image = self.model.apply(
                    latents=latents,
                    prompt=prompt,
                    generator=generator,
                    num_inference_steps=self.n_inference_steps,
                )
            if initial_image is None and multi_apply_fn is not None:
                # Render the un-optimized latents once with the multi-step
                # pipeline so the caller can compare before/after.
                multi_step_image = multi_apply_fn(latents.detach(), prompt)
                image_numpy = (
                    multi_step_image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
                )
                initial_image = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
            if self.no_optim:
                # Baseline mode: keep the first generated image and skip
                # reward computation and optimization entirely.
                best_image = image
                break
            total_loss = 0
            preprocessed_image = self.preprocess_fn(image)
            for reward_loss in self.reward_losses:
                loss = reward_loss(preprocessed_image, prompt)
                to_log += f"{reward_loss.name}: {loss.item():.4f}, "
                total_loss += loss * reward_loss.weighting
                rewards[reward_loss.name] = loss.item()
            rewards["total"] = total_loss.item()
            to_log += f"Total: {total_loss.item():.4f}"
            # Track the reward part separately: best-image selection should not
            # depend on the regularization term added below.
            total_reward_loss = total_loss.item()
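            # The term below is the negative log-density of r = ||latents||
            # for a standard Gaussian latent: p(r) ∝ r^(d-1) * exp(-r^2 / 2),
            # so -log p(r) = 0.5 * r^2 - (d - 1) * log r + const. Minimizing
            # it keeps the noise close to its typical norm of about sqrt(d).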
            if self.regularize:
                # compute in fp32 to avoid overflow
                latent_norm = torch.linalg.vector_norm(latents).to(torch.float32)
                log_norm = torch.log(latent_norm)
                regularization = self.regularization_weight * (
                    0.5 * latent_norm**2 - (latent_dim - 1) * log_norm
                )
                to_log += f", Latent norm: {latent_norm.item()}"
                rewards["norm"] = latent_norm.item()
                total_loss += regularization.to(total_loss.dtype)
            if self.log_metrics:
                logging.info(f"Iteration {iteration}: {to_log}")
            if total_reward_loss < best_loss:
                best_loss = total_reward_loss
                best_image = image
                best_rewards = rewards
                best_latents = latents.detach().cpu()
            # Skip the update on the final iteration (its image has already
            # been scored) and in image-selection mode, which never
            # backpropagates.
            if iteration != self.n_iters - 1 and not self.imageselect:
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(latents, self.grad_clip)
                optimizer.step()
            if self.save_all_images and save_dir is not None:
                image_numpy = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
                image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
                image_pil.save(f"{save_dir}/{iteration}.png")
            if initial_rewards is None:
                initial_rewards = rewards
            if progress_callback:
                progress_callback(iteration + 1)
        image_numpy = best_image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        best_image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
        if multi_apply_fn is not None and best_latents is not None:
            # Re-render the best latents with the multi-step pipeline for the
            # final output.
            multi_step_image = multi_apply_fn(best_latents.to(self.device), prompt)
            image_numpy = (
                multi_step_image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
            )
            best_image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
        return initial_image, best_image_pil, initial_rewards, best_rewards
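

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. `_StubReward`
    # and `_StubModel` are illustrative stand-ins (duck-typed, not real
    # BaseRewardLoss / DiffusionPipeline subclasses) so the optimization loop
    # can be exercised on CPU without loading a diffusion pipeline. It assumes
    # clip_img_transform accepts the batched float tensor that `model.apply`
    # returns, as it does inside the trainer.
    class _StubReward:
        name = "stub"
        weighting = 1.0

        def __call__(self, image: torch.Tensor, prompt: str) -> torch.Tensor:
            # Toy objective: pull the mean pixel value toward 0.5.
            return (image.mean() - 0.5) ** 2

    class _StubModel:
        def apply(self, latents=None, prompt=None, generator=None,
                  num_inference_steps=None):
            # Map latents directly to an "image" in [0, 1], keeping gradients.
            return torch.sigmoid(latents)

    latents = torch.randn(1, 3, 224, 224, requires_grad=True)
    trainer = LatentNoiseTrainer(
        reward_losses=[_StubReward()],
        model=_StubModel(),
        n_iters=3,
        n_inference_steps=1,
        seed=0,
        device=torch.device("cpu"),
    )
    optimizer = torch.optim.SGD([latents], lr=1.0)
    _, best_img, _, best_rewards = trainer.train(latents, "a test prompt", optimizer)
    print(best_rewards)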