praeclarumjj3's picture
:zap: Build App
9eae6e7
raw
history blame
2.7 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.nn.functional import conv2d
class VGG16(nn.Module):
    """Frozen three-stage feature extractor cut from torchvision's pretrained VGG16.

    The `features` stack of the pretrained network is split into three
    consecutive stages: `enc_1` (layers 0-4), `enc_2` (layers 5-9) and
    `enc_3` (layers 10-16). All of their parameters have gradients disabled.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg16(pretrained=True).features
        self.enc_1 = nn.Sequential(*backbone[:5])
        self.enc_2 = nn.Sequential(*backbone[5:10])
        self.enc_3 = nn.Sequential(*backbone[10:17])

        # The encoder is fixed: none of its weights should receive gradients.
        for stage in (self.enc_1, self.enc_2, self.enc_3):
            for weight in stage.parameters():
                weight.requires_grad = False

    def forward(self, image):
        """Run `image` through the three stages in sequence.

        Each stage is moved to the device of `image` before it is applied,
        and consumes the output of the previous stage.

        Returns:
            A list `[enc_1 output, enc_2 output, enc_3 output]`.
        """
        outputs = []
        current = image
        for stage in (self.enc_1, self.enc_2, self.enc_3):
            current = stage.to(image.device)(current)
            outputs.append(current)
        return outputs
class VGG19(nn.Module):
    """Frozen pretrained VGG19 that exposes 16 intermediate feature maps.

    The `features` stack of torchvision's pretrained VGG19 is cut into 16
    consecutive slices, registered as sub-modules named `relu1_1`,
    `relu1_2`, `relu2_1`, ..., `relu5_4` (the names are kept in order in
    `self.relus`). All parameters have gradients disabled.

    Args:
        resize_input: when true, `forward` bilinearly resizes the normalised
            input to 256x256 before feature extraction.
    """

    def __init__(self, resize_input=False):
        super(VGG19, self).__init__()
        features = models.vgg19(pretrained=True).features
        self.resize_input = resize_input

        # Per-channel normalisation constants for the pretrained network.
        # They used to be moved to CUDA unconditionally, which made the
        # constructor fail on CPU-only hosts; `forward` moves them to the
        # input's device anyway, so fall back to CPU when CUDA is absent.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.mean = torch.Tensor([0.485, 0.456, 0.406]).to(device)
        self.std = torch.Tensor([0.229, 0.224, 0.225]).to(device)

        # (block, index) pairs that form the slice names "relu<block>_<index>".
        prefix = [1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5]
        posfix = [1, 2, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]
        # Indices into `features` that make up each slice, in the same order.
        nums = [[0, 1], [2, 3], [4, 5, 6], [7, 8],
                [9, 10, 11], [12, 13], [14, 15], [16, 17],
                [18, 19, 20], [21, 22], [23, 24], [25, 26],
                [27, 28, 29], [30, 31], [32, 33], [34, 35]]

        self.relus = []
        for pre, pos, indices in zip(prefix, posfix, nums):
            name = 'relu{}_{}'.format(pre, pos)
            block = torch.nn.Sequential()
            for num in indices:
                # Keep the original layer index as the sub-module name.
                block.add_module(str(num), features[num])
            setattr(self, name, block)
            self.relus.append(name)

        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the output of every slice for input `x`.

        `x` is first mapped through `(x + 1) / 2` (presumably inputs are in
        [-1, 1] -- confirm with callers), then normalised per channel with
        `self.mean` / `self.std`, which are viewed as (1, 3, 1, 1), so `x`
        must have three channels on dim 1. It is optionally resized to
        256x256. Each slice is moved to the device of its input before being
        applied.

        Returns:
            dict mapping each name in `self.relus` to that slice's output,
            in slice order.
        """
        # resize and normalize input for pretrained vgg19
        x = (x + 1.0) / 2.0
        mean = self.mean.view(1, 3, 1, 1).to(x.device)
        std = self.std.view(1, 3, 1, 1).to(x.device)
        x = (x - mean) / std
        if self.resize_input:
            x = F.interpolate(
                x, size=(256, 256), mode='bilinear', align_corners=True)

        out = {}
        for layer in self.relus:
            x = getattr(self, layer).to(x.device)(x)
            out[layer] = x
        return out