File size: 5,858 Bytes
ca25718
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd8f929
c6e7b6c
ca25718
 
 
 
dd8f929
ca25718
 
dd8f929
ca25718
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd8f929
 
ca25718
 
 
dd8f929
 
 
 
 
 
ca25718
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd8f929
ca25718
 
 
 
 
 
 
 
dd8f929
 
c6e7b6c
 
ca25718
dd8f929
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import logging
import math
from typing import Dict, List, Optional, Tuple

import PIL
import PIL.Image
import torch
from diffusers import DiffusionPipeline

from rewards import clip_img_transform
from rewards.base_reward import BaseRewardLoss


class LatentNoiseTrainer:
    """Trainer for optimizing latents with reward losses.

    Each call to :meth:`train` repeatedly turns latents into an image via
    ``model.apply``, scores the image with every reward loss, and (unless
    ``no_optim`` or ``imageselect`` is set) takes a gradient step on the
    latents to reduce the weighted total loss. The image with the lowest
    total reward loss seen during the run is returned.
    """

    def __init__(
        self,
        reward_losses: List[BaseRewardLoss],
        model: DiffusionPipeline,
        n_iters: int,
        n_inference_steps: int,
        seed: int,
        no_optim: bool = False,
        regularize: bool = True,
        regularization_weight: float = 0.01,
        grad_clip: float = 0.1,
        log_metrics: bool = True,
        save_all_images: bool = False,
        imageselect: bool = False,
        device: torch.device = torch.device("cuda"),
    ):
        """Store the training configuration.

        Args:
            reward_losses: Losses applied to each generated image. Each is
                called as ``loss(preprocessed_image, prompt)`` and must expose
                ``name`` and ``weighting`` attributes.
            model: Object whose ``apply`` method turns latents into an image
                tensor.
            n_iters: Number of generate/score iterations per ``train`` call.
            n_inference_steps: Forwarded to ``model.apply`` as
                ``num_inference_steps``.
            seed: Seed for the generator passed to ``model.apply``; it is
                re-seeded on every iteration.
            no_optim: If True, generate one image and stop without scoring.
            regularize: If True, add a latent-norm penalty to the loss that
                is back-propagated (selection still uses the reward loss only).
            regularization_weight: Scale of the latent-norm penalty.
            grad_clip: Maximum gradient norm applied to the latents before
                each optimizer step.
            log_metrics: If True, log per-iteration losses.
            save_all_images: If True, save every iteration's image as
                ``{save_dir}/{iteration}.png``.
            imageselect: If True, draw fresh random latents each iteration
                and keep the best image instead of optimizing ``latents``.
            device: Device used for the generator, the sampled latents and
                the re-loaded best latents.
        """
        self.reward_losses = reward_losses
        self.model = model
        self.n_iters = n_iters
        self.n_inference_steps = n_inference_steps
        self.seed = seed
        self.no_optim = no_optim
        self.regularize = regularize
        self.regularization_weight = regularization_weight
        self.grad_clip = grad_clip
        self.log_metrics = log_metrics
        self.save_all_images = save_all_images
        self.imageselect = imageselect
        self.device = device
        # NOTE(review): 224 is presumably the input resolution expected by the
        # reward models (CLIP-style transform) — confirm in rewards module.
        self.preprocess_fn = clip_img_transform(224)

    @staticmethod
    def _to_pil(image: torch.Tensor) -> PIL.Image.Image:
        """Convert the first image of a (B, C, H, W) tensor to a PIL image."""
        image_numpy = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        return DiffusionPipeline.numpy_to_pil(image_numpy)[0]

    def train(
        self,
        latents: torch.Tensor,
        prompt: str,
        optimizer: torch.optim.Optimizer,
        save_dir: Optional[str] = None,
        multi_apply_fn=None,
        progress_callback=None,
    ) -> Tuple[
        Optional[PIL.Image.Image],
        PIL.Image.Image,
        Optional[Dict[str, float]],
        Optional[Dict[str, float]],
    ]:
        """Optimize (or select) latents for ``prompt`` and return the best image.

        Args:
            latents: Starting latents. Updated in place by ``optimizer`` unless
                ``no_optim`` or ``imageselect`` is set. The first dimension is
                excluded from ``latent_dim`` (presumably the batch dimension —
                TODO confirm).
            prompt: Text prompt passed to the model and the reward losses.
            optimizer: Optimizer over ``latents``.
            save_dir: Directory for per-iteration images; required when
                ``save_all_images`` is set.
            multi_apply_fn: Optional ``fn(latents, prompt) -> image tensor``
                used to render the initial and the best latents for output.
            progress_callback: Optional callable receiving the number of
                completed iterations.

        Returns:
            ``(initial_image, best_image, initial_rewards, best_rewards)``.
            ``initial_image`` is None unless ``multi_apply_fn`` is given; the
            reward dicts are None when ``no_optim`` is set.

        Raises:
            ValueError: If ``save_all_images`` is set but ``save_dir`` is None.
            RuntimeError: If no image was selected (``n_iters`` < 1, or every
                total reward loss was NaN/inf).
        """
        if self.save_all_images and save_dir is None:
            raise ValueError("save_dir must be provided when save_all_images is set.")

        logging.info(f"Optimizing latents for prompt '{prompt}'.")
        best_loss = torch.inf
        best_image = None
        initial_image = None
        initial_rewards = None
        best_rewards = None
        best_latents = None
        # Degrees of freedom for the norm prior (all dims except the first).
        latent_dim = math.prod(latents.shape[1:])
        for iteration in range(self.n_iters):
            to_log = ""
            rewards = {}
            optimizer.zero_grad()
            # Re-seeded every iteration, presumably so generator-driven noise
            # inside model.apply is identical across iterations.
            generator = torch.Generator(self.device).manual_seed(self.seed)
            if self.imageselect:
                # Best-of-N selection: sample fresh latents instead of optimizing.
                current_latents = torch.randn_like(
                    latents, device=self.device, dtype=latents.dtype
                )
            else:
                current_latents = latents
            image = self.model.apply(
                latents=current_latents,
                prompt=prompt,
                generator=generator,
                num_inference_steps=self.n_inference_steps,
            )
            if initial_image is None and multi_apply_fn is not None:
                initial_image = self._to_pil(multi_apply_fn(latents.detach(), prompt))
            if self.no_optim:
                best_image = image
                # Needed so multi_apply_fn can re-render the result below.
                best_latents = current_latents.detach().cpu().clone()
                break

            total_loss = 0
            preprocessed_image = self.preprocess_fn(image)
            for reward_loss in self.reward_losses:
                loss = reward_loss(preprocessed_image, prompt)
                to_log += f"{reward_loss.name}: {loss.item():.4f}, "
                total_loss += loss * reward_loss.weighting
                rewards[reward_loss.name] = loss.item()
            # Snapshot before the regularizer is added: selection uses this.
            total_reward_loss = total_loss.item()
            rewards["total"] = total_reward_loss
            to_log += f"Total: {total_reward_loss:.4f}"
            if self.regularize:
                # Negative log-density (up to a constant) of a chi distribution
                # with latent_dim degrees of freedom: r**2/2 - (d-1)*log(r).
                # dtype=float32 casts *before* reducing, so the sum of squares
                # cannot overflow in half precision.
                latent_norm = torch.linalg.vector_norm(
                    current_latents, dtype=torch.float32
                )
                log_norm = torch.log(latent_norm)
                regularization = self.regularization_weight * (
                    0.5 * latent_norm**2 - (latent_dim - 1) * log_norm
                )
                to_log += f", Latent norm: {latent_norm.item()}"
                rewards["norm"] = latent_norm.item()
                total_loss += regularization.to(total_loss.dtype)
            if self.log_metrics:
                logging.info(f"Iteration {iteration}: {to_log}")
            if total_reward_loss < best_loss:
                best_loss = total_reward_loss
                best_image = image
                best_rewards = rewards
                # clone(): .cpu() is a no-op for CPU tensors, and the in-place
                # optimizer step below would otherwise mutate this snapshot.
                best_latents = current_latents.detach().cpu().clone()
            # Skip the step on the last iteration: its result is never scored.
            if iteration != self.n_iters - 1 and not self.imageselect:
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(latents, self.grad_clip)
                optimizer.step()
            if self.save_all_images:
                self._to_pil(image).save(f"{save_dir}/{iteration}.png")
            if initial_rewards is None:
                initial_rewards = rewards
            if progress_callback:
                progress_callback(iteration + 1)

        if best_image is None:
            raise RuntimeError(
                "No image was selected: n_iters must be >= 1 and at least one "
                "total reward loss must be below infinity (not NaN/inf)."
            )
        if multi_apply_fn is not None:
            best_image_pil = self._to_pil(
                multi_apply_fn(best_latents.to(self.device), prompt)
            )
        else:
            best_image_pil = self._to_pil(best_image)
        return initial_image, best_image_pil, initial_rewards, best_rewards