File size: 14,577 Bytes
94ada0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from email import generator

from cv2 import DescriptorMatcher
import training
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from torch_utils import training_stats
from torch_utils import misc
from torch_utils.ops import conv2d_gradfix

#----------------------------------------------------------------------------

class Loss:
    """Base class for training losses.

    Concrete losses subclass this and override ``accumulate_gradients``.
    """

    def accumulate_gradients(self, **kwargs):
        """Abstract hook; every concrete loss must supply its own version."""
        raise NotImplementedError

#----------------------------------------------------------------------------

class StyleGAN2Loss(Loss):
    """StyleGAN2 non-saturating GAN loss with R1 and path-length regularization.

    Beyond the standard StyleGAN2 loss, this class supports:
      * an optional image encoder (``G_encoder``) whose reconstructions are
        scored with smooth-L1 and LPIPS (see ``run_G``);
      * cameras estimated by the discriminator (``get_estimated_camera``);
      * a curriculum factor ``alpha`` pushed into the networks (``set_alpha``);
      * label smoothing on the generator loss and on D's real-image term;
      * auxiliary ``*_loss`` entries returned by the networks (``get_loss``).
    """

    def __init__(
        self, device, G_mapping, G_synthesis, D,
        G_encoder=None, augment_pipe=None, D_ema=None,
        style_mixing_prob=0.9, r1_gamma=10,
        pl_batch_shrink=2, pl_decay=0.01, pl_weight=2, other_weights=None,
        curriculum=None, alpha_start=0.0, cycle_consistency=False, label_smooth=0,
        generator_mode='random_z_random_c'):
        """Store the networks and hyper-parameters.

        ``other_weights`` optionally maps auxiliary loss names (keys ending in
        ``'_loss'``) to weights used by ``get_loss``. ``curriculum`` is either
        None, the string ``'upsample'``, or a ``(start, end)`` pair consumed by
        ``set_alpha``. ``generator_mode`` selects the default ``run_G`` mode.
        """
        super().__init__()
        self.device            = device
        self.G_mapping         = G_mapping
        self.G_synthesis       = G_synthesis
        self.G_encoder         = G_encoder
        self.D                 = D
        self.D_ema             = D_ema
        self.augment_pipe      = augment_pipe
        self.style_mixing_prob = style_mixing_prob
        self.r1_gamma          = r1_gamma
        self.pl_batch_shrink   = pl_batch_shrink
        self.pl_decay          = pl_decay
        self.pl_weight         = pl_weight
        self.other_weights     = other_weights
        self.pl_mean           = torch.zeros([], device=device)  # running mean of path lengths
        self.curriculum        = curriculum
        self.alpha_start       = alpha_start
        self.alpha             = None
        self.cycle_consistency = cycle_consistency
        self.label_smooth      = label_smooth
        self.generator_mode    = generator_mode

        if self.G_encoder is not None:
            # LPIPS is only needed for the encoder's reconstruction loss.
            import lpips
            self.lpips_loss      = lpips.LPIPS(net='vgg').to(device=device)

    def set_alpha(self, steps):
        """Update the curriculum factor and push it into the networks.

        Args:
            steps: training progress; divided by 1e3 before being compared to
                the ``(start, end)`` curriculum pair.

        ``self.alpha`` becomes None without a curriculum, 0.0 for the
        ``'upsample'`` curriculum, and otherwise a linear ramp from 0 at
        ``start`` to 1 at ``end`` (clamped to [0, 1]), remapped to
        ``[alpha_start, 1]`` when ``alpha_start > 0``. Every submodule of
        ``G_synthesis``, ``D`` and (if present) ``G_encoder`` that defines
        ``set_alpha`` / ``set_steps`` / ``set_resolution`` receives the new
        values.
        """
        alpha = None
        if self.curriculum is not None:
            if self.curriculum == 'upsample':
                alpha = 0.0
            else:
                assert len(self.curriculum) == 2, "currently support one stage for now"
                start, end = self.curriculum
                progress = steps / 1e3
                if end == start:
                    # Zero-length ramp: switch instantly instead of dividing by zero.
                    alpha = 1. if progress >= start else 0.
                else:
                    alpha = min(1., max(0., (progress - start) / (end - start)))
                if self.alpha_start > 0:
                    alpha = self.alpha_start + (1 - self.alpha_start) * alpha
        self.alpha = alpha
        self.steps = steps
        self.curr_status = None

        def _apply(m):
            if hasattr(m, "set_alpha") and m != self:
                m.set_alpha(alpha)
            if hasattr(m, "set_steps") and m != self:
                m.set_steps(steps)
            if hasattr(m, "set_resolution") and m != self:
                # Read lazily: None while G_synthesis is visited, the current
                # resolution for D and G_encoder (set between the apply calls).
                m.set_resolution(self.curr_status)

        self.G_synthesis.apply(_apply)
        self.curr_status = self.resolution
        self.D.apply(_apply)
        if self.G_encoder is not None:
            self.G_encoder.apply(_apply)

    def run_G(self, z, c, sync, img=None, mode=None, get_loss=True):
        """Run the generator and return ``(out, ws)``.

        Args:
            z, c: latent codes and labels for ``G_mapping`` (``random_z_*`` modes).
            sync: forwarded to ``misc.ddp_sync`` for the networks run here.
            img: reference images, required by every mode except
                ``'random_z_random_c'``; a tensor or a dict with an ``'img'``
                entry.
            mode: overrides ``self.generator_mode`` when not None.
            get_loss: in the ``image_z_*`` modes, add ``consist_l1_loss`` and
                ``consist_lpips_loss`` entries to ``out``.

        Returns:
            ``(out, ws)``: the ``G_synthesis`` output and the latents fed to it.

        Raises:
            NotImplementedError: for an unknown generator mode.
        """
        synthesis_kwargs = {'camera_mode': 'random'}
        generator_mode   = self.generator_mode if mode is None else mode

        if generator_mode in ('image_z_random_c', 'image_z_image_c'):
            assert (self.G_encoder is not None) and (img is not None)
            with misc.ddp_sync(self.G_encoder, sync):
                ws  = self.G_encoder(img)['ws']
            if generator_mode == 'image_z_image_c':
                with misc.ddp_sync(self.D, False):
                    # get_func returns a callable: call it, then take the first output.
                    synthesis_kwargs['camera_RT'] = misc.get_func(self.D, 'get_estimated_camera')(img)[0]
            with misc.ddp_sync(self.G_synthesis, sync):
                out = self.G_synthesis(ws, **synthesis_kwargs)
            if get_loss:  # consistency loss given the image predicted camera (train the image encoder jointly)
                # accumulate_gradients unwraps dict inputs to tensors, so accept both forms.
                target = img['img'] if isinstance(img, dict) else img
                out['consist_l1_loss']    = F.smooth_l1_loss(out['img'], target) * 2.0   # TODO: DEBUG
                out['consist_lpips_loss'] = self.lpips_loss(out['img'], target) * 10.0   # TODO: DEBUG

        elif generator_mode in ('random_z_random_c', 'random_z_image_c'):
            with misc.ddp_sync(self.G_mapping, sync):
                ws  = self.G_mapping(z, c)
                if self.style_mixing_prob > 0:
                    with torch.autograd.profiler.record_function('style_mixing'):
                        cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                        cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                        ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
            if generator_mode == 'random_z_image_c':
                assert img is not None
                with torch.no_grad():
                    D = self.D_ema if self.D_ema is not None else self.D
                    with misc.ddp_sync(D, sync):
                        estimated_c = misc.get_func(D, 'get_estimated_camera')(img)[0].detach()
                        # The last dimension decides how the estimate is interpreted.
                        if estimated_c.size(-1) == 16:
                            synthesis_kwargs['camera_RT'] = estimated_c
                        if estimated_c.size(-1) == 3:
                            synthesis_kwargs['camera_UV'] = estimated_c
            with misc.ddp_sync(self.G_synthesis, sync):
                out = self.G_synthesis(ws, **synthesis_kwargs)
        else:
            raise NotImplementedError(f'wrong generator_mode {generator_mode}')
        return out, ws

    def run_D(self, img, c, sync):
        """Run the discriminator (with ``augment_pipe``) and return its output."""
        with misc.ddp_sync(self.D, sync):
            logits = self.D(img, c, aug_pipe=self.augment_pipe)
        return logits

    def get_loss(self, outputs, module='D'):
        """Sum and strip the auxiliary ``*_loss`` entries of a network output.

        When ``outputs`` is a dict, every key ending in ``'_loss'`` is removed
        from it in place and reported as ``Loss/{module}/{key}``. Its mean,
        scaled by ``other_weights[key]`` when such a weight exists, is added
        to the returned total. Non-dict outputs contribute 0.
        """
        if not isinstance(outputs, dict):
            return 0
        reg_loss = 0
        for key in [k for k in outputs if k.endswith('_loss')]:
            value = outputs.pop(key)
            if (self.other_weights is not None) and (key in self.other_weights):
                reg_loss = reg_loss + value.mean() * self.other_weights[key]
            else:
                reg_loss = reg_loss + value.mean()
            training_stats.report(f'Loss/{module}/{key}', value)
        return reg_loss

    @property
    def resolution(self):
        """Last entry of ``G_synthesis``'s ``get_current_resolution()``."""
        return misc.get_func(self.G_synthesis, 'get_current_resolution')()[-1]

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, fake_img, sync, gain, scaler=None):
        """Run one training phase: forward passes, stats reporting and backward.

        Args:
            phase: one of 'Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'.
            real_img: real images (tensor, or dict with an 'img' entry). It is
                not modified; R1 works on a detached copy.
            real_c: labels for the real images.
            gen_z, gen_c: latents and labels for generated samples.
            fake_img: reference images passed to ``run_G`` as ``img`` (a dict
                is reduced to its 'img' entry).
            sync: forwarded to ``misc.ddp_sync`` via ``run_G`` / ``run_D``.
            gain: multiplier applied to every loss before backward.
            scaler: optional gradient scaler (presumably a
                ``torch.cuda.amp.GradScaler``); when given, ``scaler.scale``
                is applied before ``backward()``.

        Returns:
            dict mapping 'Gmain' / 'Gpl' / 'Dgen' / 'Dr1' to the gain-weighted
            losses (before scaler) backpropagated in this call. 'Dr1' also
            holds the real-image term when the phase includes Dmain.
        """
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        do_Gmain = (phase in ['Gmain', 'Gboth'])
        do_Dmain = (phase in ['Dmain', 'Dboth'])
        do_Gpl   = (phase in ['Greg', 'Gboth'])
        do_Dr1   = (phase in ['Dreg', 'Dboth'])
        losses   = {}

        # Gmain: Maximize logits for generated images.
        loss_Gmain, reg_loss = 0, 0
        if isinstance(fake_img, dict): fake_img = fake_img['img']
        if do_Gmain:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl), img=fake_img)   # May get synced by Gpl.
                reg_loss  += self.get_loss(gen_img, 'G')
                gen_logits = self.run_D(gen_img, gen_c, sync=False)
                reg_loss  += self.get_loss(gen_logits, 'G')
                if isinstance(gen_logits, dict):
                    gen_logits = gen_logits['logits']

                loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
                if self.label_smooth > 0:
                    loss_Gmain = loss_Gmain * (1 - self.label_smooth) + torch.nn.functional.softplus(gen_logits) * self.label_smooth

                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                training_stats.report('Loss/G/loss', loss_Gmain)

            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain  = loss_Gmain + reg_loss
                losses['Gmain'] = loss_Gmain.mean().mul(gain)
                loss = scaler.scale(losses['Gmain']) if scaler is not None else losses['Gmain']
                loss.backward()

        # Gpl: Apply path length regularization.
        if do_Gpl and (self.pl_weight != 0):
            with torch.autograd.profiler.record_function('Gpl_forward'):
                batch_size = max(1, gen_z.shape[0] // self.pl_batch_shrink)
                gen_img, gen_ws = self.run_G(
                    gen_z[:batch_size], gen_c[:batch_size], sync=sync,
                    img=fake_img[:batch_size] if fake_img is not None else None)
                if isinstance(gen_img, dict):  gen_img = gen_img['img']
                pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True, allow_unused=True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report('Loss/pl_penalty', pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report('Loss/G/reg', loss_Gpl)

            with torch.autograd.profiler.record_function('Gpl_backward'):
                # `gen_img[...] * 0` keeps gen_img in the graph that backward() traverses.
                losses['Gpl'] = (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain)
                loss = scaler.scale(losses['Gpl']) if scaler is not None else losses['Gpl']
                loss.backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen, reg_loss = 0, 0
        if do_Dmain:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img    = self.run_G(gen_z, gen_c, sync=False, img=fake_img)[0]
                reg_loss  += self.get_loss(gen_img, 'D')
                gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
                reg_loss  += self.get_loss(gen_logits, 'D')
                if isinstance(gen_logits, dict):
                    gen_logits = gen_logits['logits']

                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake',  gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))

            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen  = loss_Dgen + reg_loss
                losses['Dgen'] = loss_Dgen.mean().mul(gain)
                loss = scaler.scale(losses['Dgen']) if scaler is not None else losses['Dgen']
                loss.backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if do_Dmain or (do_Dr1 and (self.r1_gamma != 0)):
            name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                # Work on a detached copy: requires_grad_ is in-place, and the
                # caller's batch (and dict) must not be altered across phases.
                if isinstance(real_img, dict):
                    real_img = dict(real_img)
                    real_img['img'] = real_img['img'].detach().requires_grad_(do_Dr1)
                else:
                    real_img = real_img.detach().requires_grad_(do_Dr1)
                real_logits = self.run_D(real_img, real_c, sync=sync)
                if isinstance(real_logits, dict):
                    real_logits = real_logits['logits']

                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real',  real_logits.sign())

                loss_Dreal = 0
                if do_Dmain:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
                    if self.label_smooth > 0:
                        loss_Dreal = loss_Dreal * (1 - self.label_smooth) + torch.nn.functional.softplus(real_logits) * self.label_smooth

                    training_stats.report('Loss/D/loss', loss_Dgen.mean() + loss_Dreal.mean())

                loss_Dr1 = 0
                if do_Dr1:
                    with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                        real_img_tmp = real_img['img'] if isinstance(real_img, dict) else real_img
                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
                    r1_penalty = r1_grads.square().sum([1,2,3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                losses['Dr1'] = (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain)
                loss = scaler.scale(losses['Dr1']) if scaler is not None else losses['Dr1']
                loss.backward()

        return losses

#----------------------------------------------------------------------------