fffiloni commited on
Commit
32ad50d
1 Parent(s): 6b14aab

Create inversion.py

Browse files
Files changed (1) hide show
  1. stylegan2/inversion.py +206 -0
stylegan2/inversion.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import optim
3
+ from torch.nn import functional as FF
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ import dataclasses
8
+
9
+ from .lpips import util
10
+
11
+
12
+ def noise_regularize(noises):
13
+ loss = 0
14
+
15
+ for noise in noises:
16
+ size = noise.shape[2]
17
+
18
+ while True:
19
+ loss = (
20
+ loss
21
+ + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
22
+ + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
23
+ )
24
+
25
+ if size <= 8:
26
+ break
27
+
28
+ noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
29
+ noise = noise.mean([3, 5])
30
+ size //= 2
31
+
32
+ return loss
33
+
34
+
35
+ def noise_normalize_(noises):
36
+ for noise in noises:
37
+ mean = noise.mean()
38
+ std = noise.std()
39
+
40
+ noise.data.add_(-mean).div_(std)
41
+
42
+
43
+ def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
44
+ lr_ramp = min(1, (1 - t) / rampdown)
45
+ lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
46
+ lr_ramp = lr_ramp * min(1, t / rampup)
47
+
48
+ return initial_lr * lr_ramp
49
+
50
+
51
+ def latent_noise(latent, strength):
52
+ noise = torch.randn_like(latent) * strength
53
+
54
+ return latent + noise
55
+
56
+
57
+ def make_image(tensor):
58
+ return (
59
+ tensor.detach()
60
+ .clamp_(min=-1, max=1)
61
+ .add(1)
62
+ .div_(2)
63
+ .mul(255)
64
+ .type(torch.uint8)
65
+ .permute(0, 2, 3, 1)
66
+ .to("cpu")
67
+ .numpy()
68
+ )
69
+
70
+
71
+ @dataclasses.dataclass
72
+ class InverseConfig:
73
+ lr_warmup = 0.05
74
+ lr_decay = 0.25
75
+ lr = 0.1
76
+ noise = 0.05
77
+ noise_decay = 0.75
78
+ step = 1000
79
+ noise_regularize = 1e5
80
+ mse = 0
81
+ w_plus = False,
82
+
83
+
84
+ def inverse_image(
85
+ g_ema,
86
+ image,
87
+ image_size=256,
88
+ config=InverseConfig()
89
+ ):
90
+ device = "cuda"
91
+ args = config
92
+
93
+ n_mean_latent = 10000
94
+
95
+ resize = min(image_size, 256)
96
+
97
+ transform = transforms.Compose(
98
+ [
99
+ transforms.Resize(resize),
100
+ transforms.CenterCrop(resize),
101
+ transforms.ToTensor(),
102
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
103
+ ]
104
+ )
105
+
106
+ imgs = []
107
+ img = transform(image)
108
+ imgs.append(img)
109
+
110
+ imgs = torch.stack(imgs, 0).to(device)
111
+
112
+ with torch.no_grad():
113
+ noise_sample = torch.randn(n_mean_latent, 512, device=device)
114
+ latent_out = g_ema.style(noise_sample)
115
+
116
+ latent_mean = latent_out.mean(0)
117
+ latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5
118
+
119
+ percept = util.PerceptualLoss(
120
+ model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
121
+ )
122
+
123
+ noises_single = g_ema.make_noise()
124
+ noises = []
125
+ for noise in noises_single:
126
+ noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())
127
+
128
+ latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)
129
+
130
+ if args.w_plus:
131
+ latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)
132
+
133
+ latent_in.requires_grad = True
134
+
135
+ for noise in noises:
136
+ noise.requires_grad = True
137
+
138
+ optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
139
+
140
+ pbar = tqdm(range(args.step))
141
+ latent_path = []
142
+
143
+ for i in pbar:
144
+ t = i / args.step
145
+ lr = get_lr(t, args.lr)
146
+ optimizer.param_groups[0]["lr"] = lr
147
+ noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
148
+ latent_n = latent_noise(latent_in, noise_strength.item())
149
+
150
+ latent, noise = g_ema.prepare([latent_n], input_is_latent=True, noise=noises)
151
+ img_gen, F = g_ema.generate(latent, noise)
152
+
153
+ batch, channel, height, width = img_gen.shape
154
+
155
+ if height > 256:
156
+ factor = height // 256
157
+
158
+ img_gen = img_gen.reshape(
159
+ batch, channel, height // factor, factor, width // factor, factor
160
+ )
161
+ img_gen = img_gen.mean([3, 5])
162
+
163
+ p_loss = percept(img_gen, imgs).sum()
164
+ n_loss = noise_regularize(noises)
165
+ mse_loss = FF.mse_loss(img_gen, imgs)
166
+
167
+ loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss
168
+
169
+ optimizer.zero_grad()
170
+ loss.backward()
171
+ optimizer.step()
172
+
173
+ noise_normalize_(noises)
174
+
175
+ if (i + 1) % 100 == 0:
176
+ latent_path.append(latent_in.detach().clone())
177
+
178
+ pbar.set_description(
179
+ (
180
+ f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
181
+ f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
182
+ )
183
+ )
184
+
185
+ latent, noise = g_ema.prepare([latent_path[-1]], input_is_latent=True, noise=noises)
186
+ img_gen, F = g_ema.generate(latent, noise)
187
+
188
+ img_ar = make_image(img_gen)
189
+
190
+ i = 0
191
+
192
+ noise_single = []
193
+ for noise in noises:
194
+ noise_single.append(noise[i: i + 1])
195
+
196
+ result = {
197
+ "latent": latent,
198
+ "noise": noise_single,
199
+ 'F': F,
200
+ "sample": img_gen,
201
+ }
202
+
203
+ pil_img = Image.fromarray(img_ar[i])
204
+ pil_img.save('project.png')
205
+
206
+ return result