# ReNO/training/trainer.py
import logging
import math
from typing import Dict, List, Optional, Tuple
import PIL
import PIL.Image
import torch
from diffusers import DiffusionPipeline
from rewards import clip_img_transform
from rewards.base_reward import BaseRewardLoss


class LatentNoiseTrainer:
"""Trainer for optimizing latents with reward losses."""
def __init__(
self,
reward_losses: List[BaseRewardLoss],
model: DiffusionPipeline,
n_iters: int,
n_inference_steps: int,
seed: int,
no_optim: bool = False,
regularize: bool = True,
regularization_weight: float = 0.01,
grad_clip: float = 0.1,
log_metrics: bool = True,
save_all_images: bool = False,
imageselect: bool = False,
device: torch.device = torch.device("cuda"),
):
self.reward_losses = reward_losses
self.model = model
self.n_iters = n_iters
self.n_inference_steps = n_inference_steps
self.seed = seed
self.no_optim = no_optim
self.regularize = regularize
self.regularization_weight = regularization_weight
self.grad_clip = grad_clip
self.log_metrics = log_metrics
self.save_all_images = save_all_images
self.imageselect = imageselect
self.device = device
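        # CLIP-style image preprocessing (resize to 224 px + normalization),
        # applied before images are scored by the reward models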
self.preprocess_fn = clip_img_transform(224)

    def train(
self,
latents: torch.Tensor,
prompt: str,
optimizer: torch.optim.Optimizer,
save_dir: Optional[str] = None,
multi_apply_fn=None,
progress_callback=None,
    ) -> Tuple[
        Optional[PIL.Image.Image], PIL.Image.Image,
        Optional[Dict[str, float]], Optional[Dict[str, float]],
    ]:
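        """Optimize `latents` for `prompt` and return a tuple of
        (initial multi-step image, best image, initial rewards, best rewards)."""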
logging.info(f"Optimizing latents for prompt '{prompt}'.")
best_loss = torch.inf
best_image = None
initial_image = None
initial_rewards = None
best_rewards = None
best_latents = None
latent_dim = math.prod(latents.shape[1:])
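        # dimensionality of the noise vector, used by the norm regularizer below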
for iteration in range(self.n_iters):
to_log = ""
rewards = {}
optimizer.zero_grad()
            generator = torch.Generator(device=self.device).manual_seed(self.seed)
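            # two modes: plain best-of-n selection over fresh random latents
            # (imageselect) or gradient-based optimization of the given latents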
if self.imageselect:
new_latents = torch.randn_like(
latents, device=self.device, dtype=latents.dtype
)
image = self.model.apply(
new_latents,
prompt,
generator=generator,
num_inference_steps=self.n_inference_steps,
)
else:
image = self.model.apply(
latents=latents,
prompt=prompt,
generator=generator,
num_inference_steps=self.n_inference_steps,
)
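            # on the first iteration, render a multi-step reference image from the
            # unoptimized latents (when a multi-step renderer is provided)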
if initial_image is None and multi_apply_fn is not None:
multi_step_image = multi_apply_fn(latents.detach(), prompt)
image_numpy = (
multi_step_image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
)
initial_image = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
if self.no_optim:
best_image = image
break
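            # score the generated image with every reward model; the weighted sum
            # is the objective minimized w.r.t. the latent noise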
total_loss = 0
preprocessed_image = self.preprocess_fn(image)
for reward_loss in self.reward_losses:
loss = reward_loss(preprocessed_image, prompt)
to_log += f"{reward_loss.name}: {loss.item():.4f}, "
total_loss += loss * reward_loss.weighting
rewards[reward_loss.name] = loss.item()
rewards["total"] = total_loss.item()
to_log += f"Total: {total_loss.item():.4f}"
total_reward_loss = total_loss.item()
if self.regularize:
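                # negative log-density of ||latents|| under a standard Gaussian prior
                # (chi distribution with latent_dim degrees of freedom); keeps the
                # norm near the typical shell of radius ~sqrt(latent_dim)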
# compute in fp32 to avoid overflow
latent_norm = torch.linalg.vector_norm(latents).to(torch.float32)
log_norm = torch.log(latent_norm)
regularization = self.regularization_weight * (
0.5 * latent_norm**2 - (latent_dim - 1) * log_norm
)
to_log += f", Latent norm: {latent_norm.item()}"
rewards["norm"] = latent_norm.item()
total_loss += regularization.to(total_loss.dtype)
if self.log_metrics:
logging.info(f"Iteration {iteration}: {to_log}")
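            # track the best candidate by reward loss alone; the regularizer is
            # excluded from model selection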
            if total_reward_loss < best_loss:
                best_loss = total_reward_loss
                best_image = image
                best_rewards = rewards
                # keep the latents that actually produced this image
                # (the freshly sampled ones when imageselect is enabled)
                best_latents = (
                    new_latents if self.imageselect else latents
                ).detach().cpu()
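            # skip the gradient step on the final iteration (its image is already
            # scored) and whenever we are only selecting among random latents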
if iteration != self.n_iters - 1 and not self.imageselect:
total_loss.backward()
torch.nn.utils.clip_grad_norm_(latents, self.grad_clip)
optimizer.step()
            if self.save_all_images and save_dir is not None:
image_numpy = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
image_pil.save(f"{save_dir}/{iteration}.png")
if initial_rewards is None:
initial_rewards = rewards
if progress_callback:
progress_callback(iteration + 1)
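        # convert the best one-step image to PIL; if a multi-step renderer is
        # available, re-render the best latents with it for the final output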
image_numpy = best_image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
best_image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
        if multi_apply_fn is not None and best_latents is not None:
            multi_step_image = multi_apply_fn(best_latents.to(self.device), prompt)
image_numpy = (
multi_step_image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
)
best_image_pil = DiffusionPipeline.numpy_to_pil(image_numpy)[0]
return initial_image, best_image_pil, initial_rewards, best_rewards
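

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). The
# helpers `load_one_step_model` and `build_reward_losses` are hypothetical
# stand-ins for the project's model/reward setup; shapes and hyperparameters
# are placeholders.
#
#   model = load_one_step_model("sdxl-turbo")             # wrapper exposing .apply(...)
#   reward_losses = build_reward_losses(["hps", "clip"])  # List[BaseRewardLoss]
#   trainer = LatentNoiseTrainer(
#       reward_losses=reward_losses,
#       model=model,
#       n_iters=50,
#       n_inference_steps=1,
#       seed=0,
#   )
#   latents = torch.randn(1, 4, 128, 128, device="cuda", requires_grad=True)
#   optimizer = torch.optim.SGD([latents], lr=5.0)
#   init_img, best_img, init_rewards, best_rewards = trainer.train(
#       latents, "a photo of a corgi", optimizer
#   )
# ---------------------------------------------------------------------------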