2d7efb8
1
2
3
4
5
6
7
8
9
import torch

# Shared mean-squared-error criterion; reduction='mean' averages the squared
# differences over every element, so the loss comes out as a scalar tensor.
l2_criterion = torch.nn.MSELoss(reduction='mean')


def l2_loss(
    real_images: torch.Tensor,
    generated_images: torch.Tensor,
) -> torch.Tensor:
    """Return the mean squared error between two image tensors.

    Args:
        real_images: Tensor passed to the criterion as its first argument.
        generated_images: Tensor passed to the criterion as its second
            argument.

    Returns:
        The result of ``l2_criterion(real_images, generated_images)``: a
        zero-dimensional tensor holding the element-wise squared difference
        averaged over all elements.

    Note:
        Assumes both tensors have the same shape; nothing here validates
        that -- TODO confirm against callers.
    """
    return l2_criterion(real_images, generated_images)