|
import torch |
|
import torch.nn as nn |
|
from torchvision.models import vgg19, vgg16 |
|
import os.path as osp |
|
|
|
|
|
|
|
|
|
|
|
class VGG_LOSS(nn.Module):
    """Perceptual loss computed on intermediate VGG16/VGG19 feature maps.

    Both images are rescaled to [0, 1] (``reset_range``), normalised by the
    fixed ``MeanShift`` convolution and fed through the leading part of
    ``vgg.features``. The loss is the sum, over the requested layers, of the
    L1 or L2 distance between the two feature maps.

    The module is created on the CPU. Move it as a whole with ``.to(device)``
    or ``.cuda()`` so the VGG trunk and the normalisation layer end up on the
    same device as the images.
    """

    # Default checkpoint locations per backbone. These are resolved relative
    # to the current working directory, exactly as before.
    _WEIGHT_PATHS = {
        'vgg16': '../vgg16-397923af.pth',
        'vgg19': '../vgg19-dcbb9e9d.pth',
    }

    def __init__(self, model_type='vgg19', layer_names=('conv_1_1', 'conv_2_1'), loss_type='l1',
                 weights_path=None):
        """Build the truncated VGG feature extractor and the distance function.

        Args:
            model_type: 'vgg16' or 'vgg19'.
            layer_names: names of the layers whose outputs are compared (keys
                of ``vgg_all_layers``). It may also contain the pseudo layer
                'input', which selects the normalised image itself.
            loss_type: 'l1' (``nn.L1Loss``) or 'l2' (``nn.MSELoss``). Any other
                value falls back to L1 and emits a warning.
            weights_path: path to a state dict for ``model_type``. If None,
                the default '../vgg*.pth' file for that backbone is used.

        Raises:
            ValueError: if ``model_type`` is neither 'vgg16' nor 'vgg19'.
            KeyError: if a layer name is unknown (raised by ``get_layer_name_id``).
        """
        import warnings

        super(VGG_LOSS, self).__init__()

        if model_type == 'vgg16':
            vgg_model = vgg16(pretrained=False)
        elif model_type == 'vgg19':
            vgg_model = vgg19(pretrained=False)
        else:
            raise ValueError('Vgg network type should be either vgg16 or vgg19.')

        if weights_path is None:
            weights_path = self._WEIGHT_PATHS[model_type]
        # Load onto the CPU so construction also works without CUDA. The
        # previous per-backbone .cuda() call left MeanShift on the CPU and
        # the trunk on the GPU, which caused a device mismatch in get_feas.
        vgg_model.load_state_dict(torch.load(weights_path, map_location='cpu'))

        # 'input' is not a VGG layer and has no index in the layer table, so
        # it is handled separately in get_feas and must not be looked up.
        vgg_layer_names = [name for name in layer_names if name != 'input']
        self.layer_names = get_layer_name_id(model_type, vgg_layer_names)  # name -> index
        self.layer_ids = inverse_dict(self.layer_names)  # index -> name
        self.lid_list = list(self.layer_names.values())
        self.lname_input = 'input' if ('input' in layer_names) else None

        # Keep only the layers up to the deepest requested one. When only
        # 'input' was requested, this slice is an empty Sequential.
        lid_max = max(self.lid_list, default=-1)
        self.network = vgg_model.features[:lid_max + 1]

        self.mean_shift = MeanShift()

        if loss_type == 'l2':
            self.loss_fun = nn.MSELoss()
        else:
            if loss_type != 'l1':
                # Keep the historical L1 fallback, but make it visible.
                warnings.warn("Unknown loss_type {!r}; falling back to 'l1'.".format(loss_type))
            self.loss_fun = nn.L1Loss()

        self.set_not_requires_grad()

    def forward(self, img_gt, img_infer, img_range=(-1.0, 1.0)):
        """Compute the VGG perceptual loss between two image batches.

        Args:
            img_gt: reference images.
            img_infer: predicted images, expected to match ``img_gt`` in shape.
            img_range: (min, max) value range of both inputs. Defaults to (-1, 1).

        Returns:
            The sum over all selected layers of ``loss_fun`` between the two
            feature maps. This is the int 0 when no layer was selected.
        """
        feas_gt = self.get_feas(img_gt, img_range)
        feas_infer = self.get_feas(img_infer, img_range)

        loss_total = 0
        for lname, fea_gt in feas_gt.items():
            loss_total = loss_total + self.loss_fun(fea_gt, feas_infer[lname])
        return loss_total

    def get_feas(self, xx, in_range):
        """Return the requested intermediate features of ``xx``.

        Args:
            xx: image batch with values in ``in_range``. MeanShift is a
                3-channel conv, so this presumably is (N, 3, H, W) RGB
                (TODO: confirm).
            in_range: (min, max) value range of ``xx``.

        Returns:
            A dict mapping each requested layer name (plus 'input' if it was
            requested) to a copy of that layer's output.
        """
        xx = self.mean_shift(reset_range(xx, in_range))

        out_feas = dict()
        if self.lname_input is not None:
            out_feas[self.lname_input] = xx.clone()
        for lid, layer in enumerate(self.network):
            xx = layer(xx)
            layer_name = self.layer_ids.get(lid)
            if layer_name is not None:
                # clone() is required: torchvision's VGG uses in-place ReLU,
                # which would overwrite a captured conv output on the next step.
                out_feas[layer_name] = xx.clone()
        return out_feas

    def set_not_requires_grad(self):
        """Freeze every parameter (VGG trunk and MeanShift) and switch to eval mode."""
        for para in self.parameters():
            para.requires_grad = False
        self.eval()
|
|
|
def reset_range(indata, in_range):
    """Linearly rescale ``indata`` from the interval ``in_range`` to [0, 1].

    Args:
        indata: a number, or a tensor/array that supports ``-`` and ``*``.
        in_range: a (low, high) pair giving the current value range.

    Returns:
        ``(indata - low) / (high - low)``, computed as a multiplication by the
        reciprocal of the range width.
    """
    low, high = in_range
    inv_width = 1.0 / (high - low)
    shifted = indata - low
    return shifted * inv_width
|
|
|
def get_layer_name_id(vgg_type, lnames):
    """Map VGG layer names to their indices in ``vgg.features``.

    Args:
        vgg_type: 'vgg16' or 'vgg19'.
        lnames: iterable of layer names such as 'conv_1_1', 'relu_2_1' or 'pool_3'.

    Returns:
        A dict mapping each name in ``lnames`` to its layer index.

    Raises:
        ValueError: if ``vgg_type`` is unsupported (raised by ``vgg_all_layers``).
        KeyError: if a name is not a layer of that network. The message lists
            the valid names.
    """
    layer_id_dict = vgg_all_layers(vgg_type)

    out_dict = dict()
    for lname in lnames:
        try:
            out_dict[lname] = layer_id_dict[lname]
        except KeyError:
            raise KeyError('Unknown layer name {!r} for {}; valid names are: {}'.format(
                lname, vgg_type, ', '.join(layer_id_dict))) from None
    return out_dict
|
|
|
def vgg_all_layers(vgg_type):
    """Return a name -> index map for the layers of a torchvision VGG ``features`` stack.

    The base table lists each conv and pool layer. For every 'conv_*' entry,
    a matching 'relu_*' entry is added with the next index, since each conv
    is immediately followed by its ReLU. Insertion order puts each conv
    first, then its relu.

    Args:
        vgg_type: 'vgg16' or 'vgg19'.

    Raises:
        ValueError: for any other network type.
    """
    if vgg_type == 'vgg16':
        base = {
            'conv_1_1': 0, 'conv_1_2': 2, 'pool_1': 4,
            'conv_2_1': 5, 'conv_2_2': 7, 'pool_2': 9,
            'conv_3_1': 10, 'conv_3_2': 12, 'conv_3_3': 14, 'pool_3': 16,
            'conv_4_1': 17, 'conv_4_2': 19, 'conv_4_3': 21, 'pool_4': 23,
            'conv_5_1': 24, 'conv_5_2': 26, 'conv_5_3': 28, 'pool_5': 30,
        }
    elif vgg_type == 'vgg19':
        base = {
            'conv_1_1': 0, 'conv_1_2': 2, 'pool_1': 4,
            'conv_2_1': 5, 'conv_2_2': 7, 'pool_2': 9,
            'conv_3_1': 10, 'conv_3_2': 12, 'conv_3_3': 14, 'conv_3_4': 16, 'pool_3': 18,
            'conv_4_1': 19, 'conv_4_2': 21, 'conv_4_3': 23, 'conv_4_4': 25, 'pool_4': 27,
            'conv_5_1': 28, 'conv_5_2': 30, 'conv_5_3': 32, 'conv_5_4': 34, 'pool_5': 36,
        }
    else:
        raise ValueError('Vgg network type should be either vgg16 or vgg19.')

    pairs = []
    for name, index in base.items():
        pairs.append((name, index))
        if 'conv' in name:
            pairs.append((name.replace('conv', 'relu'), index + 1))
    return dict(pairs)
|
|
|
def inverse_dict(inp_dict):
    """Return a new dict with the keys and values of ``inp_dict`` swapped.

    If several keys share a value, the key iterated last wins.
    """
    return {val: key for key, val in inp_dict.items()}
|
|
|
class MeanShift(nn.Conv2d):
    """Fixed-parameter 1x1 convolution that converts an RGB image in [0, 1] to VGG input format.

    The layer computes ``(x - mean) / std`` per channel. The weight is a
    diagonal matrix scaled by ``1 / std`` and the bias is ``-mean / std``.
    Its parameters are frozen.
    """

    def __init__(self, rgb_mean=(0.485, 0.456, 0.406), rgb_std=(0.229, 0.224, 0.225)):
        """
        Args:
            rgb_mean: per-channel mean that is subtracted.
            rgb_std: per-channel standard deviation used as the divisor.
        """
        super(MeanShift, self).__init__(in_channels=3, out_channels=3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        mean = torch.Tensor(rgb_mean)
        with torch.no_grad():
            # Diagonal weight: each output channel only sees its own input
            # channel, scaled by 1/std.
            self.weight.copy_(torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1))
            # The bias must be negative so that x/std - mean/std == (x - mean)/std.
            # The previous "+mean/std" added the mean instead of subtracting it.
            self.bias.copy_(-mean / std)
        for param in self.parameters():
            param.requires_grad = False
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo: dump tiled visualisations of selected VGG feature maps for one image.
    import os
    import os.path as osp
    import sys

    import cv2
    import numpy as np
    import torchvision.utils as tv_utils

    # The image to visualise; the first command-line argument overrides it.
    inp = sys.argv[1] if len(sys.argv) > 1 else r'D:\tmp\test\baboon.png'

    img_bgr = cv2.imread(inp)
    if img_bgr is None:
        # cv2.imread reports a missing or unreadable file by returning None.
        raise FileNotFoundError('Cannot read image: {}'.format(inp))
    # HWC BGR uint8 -> CHW RGB float32 in [0, 1], then add a batch dimension.
    img_in = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
    img_in = np.float32(img_in) / 255.0
    img_in_t = torch.from_numpy(img_in).unsqueeze(0)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    layer_names = ('conv_1_1', 'conv_2_1', 'conv_3_1', 'conv_4_1')
    # Move the whole module (VGG trunk and MeanShift) and the input to one
    # device, so every layer sees tensors on the same device.
    vgg_test = VGG_LOSS(layer_names=layer_names).to(device)
    with torch.no_grad():
        feas = vgg_test.get_feas(img_in_t.to(device), in_range=(0.0, 1.0))

    def vgg_fea2img(vgg_fea):
        """Tile the channels of a (1, C, H, W) feature map into a uint8 H x W x 3 image."""
        # Treat each channel as a single-channel image: (1, C, H, W) -> (C, 1, H, W).
        chans = vgg_fea.detach().transpose(0, 1)
        nrow = max(1, round(chans.shape[0] ** 0.5))
        grid = tv_utils.make_grid(chans, nrow=nrow, normalize=True, scale_each=True)
        grid = grid.cpu().float().numpy().transpose((1, 2, 0))
        return (grid * 255.0).round().clip(0, 255).astype(np.uint8)

    out_dir = osp.splitext(inp)[0]
    os.makedirs(out_dir, exist_ok=True)
    bind = 0  # index of the batch element to visualise
    for layer_name, mid_feas in feas.items():
        out_fea = vgg_fea2img(mid_feas[bind:bind + 1])
        out_path = osp.join(out_dir, '{}_{}.png'.format(bind, layer_name))
        cv2.imwrite(out_path, out_fea)