Arnaudding001 commited on
Commit
cd6cde5
1 Parent(s): 355bf1a

Create vgg.py

Browse files
Files changed (1) hide show
  1. vgg.py +60 -0
vgg.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision
4
+
5
+ # VGG architecter, used for the perceptual loss using a pretrained VGG network
6
+ class VGG19(torch.nn.Module):
7
+ def __init__(self, requires_grad=False):
8
+ super().__init__()
9
+ vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
10
+ self.slice1 = torch.nn.Sequential()
11
+ self.slice2 = torch.nn.Sequential()
12
+ self.slice3 = torch.nn.Sequential()
13
+ self.slice4 = torch.nn.Sequential()
14
+ self.slice5 = torch.nn.Sequential()
15
+ self.slice6 = torch.nn.Sequential()
16
+ for x in range(2):
17
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
18
+ for x in range(2, 7):
19
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
20
+ for x in range(7, 12):
21
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
22
+ for x in range(12, 21):
23
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
24
+ for x in range(21, 32):
25
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
26
+ for x in range(32, 36):
27
+ self.slice6.add_module(str(x), vgg_pretrained_features[x])
28
+ if not requires_grad:
29
+ for param in self.parameters():
30
+ param.requires_grad = False
31
+
32
+ self.pool = nn.AdaptiveAvgPool2d(output_size=1)
33
+
34
+ self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1,-1, 1, 1).cuda() * 2 - 1
35
+ self.std = torch.tensor([0.229, 0.224, 0.225]).view(1,-1, 1, 1).cuda() * 2
36
+
37
+ def forward(self, X): # relui_1
38
+ X = (X-self.mean)/self.std
39
+ h_relu1 = self.slice1(X)
40
+ h_relu2 = self.slice2(h_relu1)
41
+ h_relu3 = self.slice3(h_relu2)
42
+ h_relu4 = self.slice4(h_relu3)
43
+ h_relu5 = self.slice5[:-2](h_relu4)
44
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
45
+ return out
46
+
47
+ # Perceptual loss that uses a pretrained VGG network
48
+ class VGGLoss(nn.Module):
49
+ def __init__(self):
50
+ super(VGGLoss, self).__init__()
51
+ self.vgg = VGG19().cuda()
52
+ self.criterion = nn.L1Loss()
53
+ self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
54
+
55
+ def forward(self, x, y):
56
+ x_vgg, y_vgg = self.vgg(x), self.vgg(y)
57
+ loss = 0
58
+ for i in range(len(x_vgg)):
59
+ loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
60
+ return loss