|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from generative.losses import PatchAdversarialLoss |
|
|
|
# Scalar weights for the individual loss terms. `adv_weight` scales the
# adversarial term in both generator_loss and discriminator_loss;
# `perceptual_weight` and `kl_weight` scale the perceptual and KL terms
# in generator_loss.
adv_weight = 0.5
perceptual_weight = 1.0
kl_weight = 1e-6

# Loss callables shared by the functions below.
# L1 criterion used for the reconstruction term in generator_loss.
intensity_loss = torch.nn.L1Loss()
# Adversarial criterion configured with the "least_squares" option.
adv_loss = PatchAdversarialLoss(criterion="least_squares")
|
|
|
|
|
def compute_kl_loss(z_mu, z_sigma):
    """Return the batch-averaged KL divergence of N(z_mu, z_sigma^2) from N(0, 1).

    For every sample the term ``0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1)``
    is summed over all dimensions except the first; the per-sample values
    are then summed and divided by the size of the first dimension.

    Args:
        z_mu: Tensor of latent means.
        z_sigma: Tensor of latent standard deviations, same shape as
            ``z_mu`` (assumes dim 0 is the batch dimension -- TODO confirm
            against the caller).

    Returns:
        A scalar tensor holding the mean KL value over the first dimension.
    """
    # Floor for the variance inside the log: an exactly-zero sigma would
    # otherwise give log(0) = -inf, i.e. an infinite loss and NaN gradients.
    # Variances above the floor are left untouched, so regular inputs yield
    # the same result as the unclamped formula.
    min_variance = 1e-10

    variance = z_sigma.pow(2)
    log_variance = torch.log(torch.clamp(variance, min=min_variance))

    # Reduce over every dimension except the first one.
    non_batch_dims = list(range(1, len(z_sigma.shape)))
    kl_loss = 0.5 * torch.sum(
        z_mu.pow(2) + variance - log_variance - 1,
        dim=non_batch_dims,
    )
    return torch.sum(kl_loss) / kl_loss.shape[0]
|
|
|
|
|
def generator_loss(gen_images, real_images, z_mu, z_sigma, disc_net, loss_perceptual):
    """Combine reconstruction, KL, perceptual and adversarial terms for the generator.

    Args:
        gen_images: Images produced by the generator.
        real_images: Target images the generated ones are compared against.
        z_mu: Latent means, forwarded to ``compute_kl_loss``.
        z_sigma: Latent standard deviations, forwarded to ``compute_kl_loss``.
        disc_net: Discriminator callable; only the last element of its
            output is used as the logits.
        loss_perceptual: Callable returning a perceptual loss for a pair of
            images; it receives both inputs cast with ``.float()``.

    Returns:
        ``recons + kl_weight * kl + perceptual_weight * perceptual
        + adv_weight * adversarial``.
    """
    # Non-adversarial terms.
    reconstruction_term = intensity_loss(gen_images, real_images)
    kl_term = compute_kl_loss(z_mu, z_sigma)
    perceptual_term = loss_perceptual(gen_images.float(), real_images.float())
    total = reconstruction_term + kl_weight * kl_term + perceptual_weight * perceptual_term

    # Adversarial term: generated images are scored against the "real"
    # target. No detach here, so gradients reach the generator.
    fake_logits = disc_net(gen_images)[-1]
    adversarial_term = adv_loss(fake_logits, target_is_real=True, for_discriminator=False)

    return total + adv_weight * adversarial_term
|
|
|
|
|
def discriminator_loss(gen_images, real_images, disc_net):
    """Compute the weighted adversarial loss for the discriminator.

    Args:
        gen_images: Images produced by the generator.
        real_images: Real reference images.
        disc_net: Discriminator callable; only the last element of its
            output is used as the logits.

    Returns:
        ``adv_weight`` times the average of the fake-image and real-image
        adversarial losses.
    """
    # Both inputs are detached so no gradient flows back through them;
    # only the discriminator's own computation stays in the graph.
    fake_logits = disc_net(gen_images.contiguous().detach())[-1]
    fake_term = adv_loss(fake_logits, target_is_real=False, for_discriminator=True)

    real_logits = disc_net(real_images.contiguous().detach())[-1]
    real_term = adv_loss(real_logits, target_is_real=True, for_discriminator=True)

    averaged = (fake_term + real_term) * 0.5
    return adv_weight * averaged
|
|