# NPRC24 / DH-AISP / 2 / vgg_loss.py
# Page header: Artyom · dh-aisp · commit bd1c686 (verified) · 7.51 kB
import torch
import torch.nn as nn
from torchvision.models import vgg19, vgg16
import os.path as osp
#-------------------------------------------------------------------------------------------------------------------#
#-------------------------------------------------------------------------------------------------------------------#
# **** VGG loss module (perceptual loss computed on VGG feature maps)
class VGG_LOSS(nn.Module):
    """Perceptual loss: summed distance between VGG feature maps of two images.

    Both images are rescaled to [0, 1] (``reset_range``), normalised by
    ``MeanShift`` and pushed through a torchvision VGG ``features`` stack that is
    truncated after the deepest requested layer. The distances of all requested
    layers are summed.

    All parameters are frozen and the module is put in eval mode at construction.
    The module is built on the CPU. Move it as a whole (``.to(device)`` /
    ``.cuda()``) so that ``mean_shift`` and ``network`` share the inputs' device.
    """

    def __init__(self, model_type='vgg19', layer_names=('conv_1_1', 'conv_2_1'), loss_type='l1',
                 weights_path=None):
        """
        Args:
            model_type: 'vgg16' or 'vgg19'.
            layer_names: layer names understood by ``vgg_all_layers`` (e.g. 'conv_3_1',
                'relu_3_1', 'pool_3') whose outputs are compared. The special name
                'input' compares the normalised input image itself.
            loss_type: 'l1' (``nn.L1Loss``) or 'l2' (``nn.MSELoss``). Any other value
                falls back to L1 and emits a warning.
            weights_path: optional path to the torchvision state-dict file. When None,
                '../vgg16-397923af.pth' / '../vgg19-dcbb9e9d.pth' is tried relative to
                the current working directory first, then relative to this file.

        Raises:
            ValueError: if ``model_type`` is neither 'vgg16' nor 'vgg19'.
            KeyError: if a layer name does not exist for ``model_type``.
        """
        super(VGG_LOSS, self).__init__()
        import warnings

        if isinstance(layer_names, str):
            layer_names = (layer_names,)
        # **** 'input' is not a VGG layer. Split it off so it does not reach the
        # layer-index lookup, which would raise KeyError for it.
        self.lname_input = 'input' if ('input' in layer_names) else None
        vgg_layer_names = tuple(name for name in layer_names if name != 'input')

        # **** Layer names <-> layer indices. These are validated before the
        # (large) weight file is loaded, so bad arguments fail fast.
        self.layer_names = get_layer_name_id(model_type, vgg_layer_names)
        self.layer_ids = inverse_dict(self.layer_names)
        self.lid_list = list(self.layer_names.values())

        # **** Load VGG and keep only the layers up to the deepest requested one.
        # This is an empty stack when only 'input' was requested.
        vgg_model = self._load_backbone(model_type, weights_path)
        lid_max = max(self.lid_list, default=-1)
        self.network = vgg_model.features[:lid_max + 1]

        # **** Input normalisation layer (RGB in [0, 1] -> VGG input statistics).
        self.mean_shift = MeanShift()

        # **** Feature-distance function.
        loss_funs = {'l1': nn.L1Loss, 'l2': nn.MSELoss}
        if loss_type not in loss_funs:
            warnings.warn("Unknown loss_type {!r}; falling back to 'l1'.".format(loss_type))
        self.loss_fun = loss_funs.get(loss_type, nn.L1Loss)()

        # **** Freeze all parameters.
        self.set_not_requires_grad()

    @staticmethod
    def _load_backbone(model_type, weights_path=None):
        """Build a torchvision VGG on the CPU and load its weights from a local .pth file."""
        if model_type == 'vgg16':
            builder, fname = vgg16, 'vgg16-397923af.pth'
        elif model_type == 'vgg19':
            builder, fname = vgg19, 'vgg19-dcbb9e9d.pth'
        else:
            raise ValueError('Vgg network type should be either vgg16 or vgg19.')
        if weights_path is None:
            # Historical default: relative to the current working directory.
            weights_path = '../' + fname
            if not osp.exists(weights_path):
                # Fallback: the same relative location, anchored at this source file.
                mdir = osp.dirname(osp.realpath(__file__))
                candidate = osp.join(mdir, '..', fname)
                if osp.exists(candidate):
                    weights_path = candidate
        vgg_model = builder(pretrained=False)
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines.
        # The caller moves the whole module to the target device afterwards.
        pre_trained = torch.load(weights_path, map_location='cpu')
        vgg_model.load_state_dict(pre_trained)
        return vgg_model

    def forward(self, img_gt, img_infer, img_range=(-1.0, 1.0)):
        """Return the VGG feature loss between ``img_gt`` and ``img_infer``.

        Both images must use the value range ``img_range`` (default [-1, 1]).
        The per-layer losses are summed. The result is the integer 0 when no
        layers were requested.
        """
        feas_gt = self.get_feas(img_gt, img_range)
        feas_infer = self.get_feas(img_infer, img_range)
        loss_total = 0
        for lname, gt in feas_gt.items():
            loss_total = loss_total + self.loss_fun(gt, feas_infer[lname])
        return loss_total

    def get_feas(self, xx, in_range):
        """Return ``{layer_name: feature}`` for every requested layer.

        ``xx`` is rescaled from ``in_range`` to [0, 1] and normalised by
        ``mean_shift`` (a 3-channel conv, so ``xx`` must have 3 channels).
        Stored activations are cloned in case a later layer operates in place.
        """
        # **** Prepare the input.
        xx = self.mean_shift(reset_range(xx, in_range))
        # **** Collect intermediate features.
        out_feas = dict()
        if self.lname_input is not None:
            out_feas[self.lname_input] = xx.clone()
        for lid, layer in enumerate(self.network):
            xx = layer(xx)
            if lid in self.layer_ids:
                out_feas[self.layer_ids[lid]] = xx.clone()
        return out_feas

    def set_not_requires_grad(self):
        """Freeze every parameter of this module and switch it to eval mode."""
        for para in self.parameters():
            para.requires_grad = False
        self.eval()
def reset_range(indata, in_range):
    """Linearly map ``indata`` from the interval ``in_range = (min, max)`` onto [0, 1].

    Works for anything that supports ``-`` and ``*`` with the interval bounds
    (Python numbers, tensors, arrays). Equal Python-number bounds raise
    ZeroDivisionError.
    """
    lo, hi = in_range
    inv_span = 1.0 / (hi - lo)
    shifted = indata - lo
    return shifted * inv_span
def get_layer_name_id(vgg_type, lnames):
    """Map each requested layer name to its index in the VGG ``features`` stack.

    Args:
        vgg_type: 'vgg16' or 'vgg19' (passed through to ``vgg_all_layers``).
        lnames: iterable of layer names such as 'conv_1_1' or 'relu_3_1'.

    Returns:
        dict ``{name: index}`` in the order the names were given.

    Raises:
        KeyError: if a name is not a layer of ``vgg_type``.
        ValueError: if ``vgg_type`` is unsupported (raised by ``vgg_all_layers``).
    """
    index_of = vgg_all_layers(vgg_type)
    return {name: index_of[name] for name in lnames}
def vgg_all_layers(vgg_type):
    """Return ``{layer_name: index}`` for every layer of a VGG ``features`` stack.

    Names are ``conv_<block>_<k>``, ``relu_<block>_<k>`` and ``pool_<block>`` for
    blocks 1-5. Each conv is immediately followed by its ReLU, so the ReLU index is
    the conv index + 1. Each block ends with one pooling layer.
    Supported: 'vgg16' (2-2-3-3-3 convs per block) and 'vgg19' (2-2-4-4-4).

    Raises:
        ValueError: for any other ``vgg_type``.
    """
    if vgg_type == 'vgg16':
        convs_per_block = (2, 2, 3, 3, 3)
    elif vgg_type == 'vgg19':
        convs_per_block = (2, 2, 4, 4, 4)
    else:
        raise ValueError('Vgg network type should be either vgg16 or vgg19.')

    layer_index = {}
    pos = 0
    for block, n_convs in enumerate(convs_per_block, start=1):
        for k in range(1, n_convs + 1):
            layer_index['conv_{}_{}'.format(block, k)] = pos
            layer_index['relu_{}_{}'.format(block, k)] = pos + 1
            pos += 2
        layer_index['pool_{}'.format(block)] = pos
        pos += 1
    return layer_index
def inverse_dict(inp_dict):
    """Return a new dict with keys and values swapped.

    When several keys share a value, the key that comes last in iteration
    order wins.
    """
    return {val: key for key, val in inp_dict.items()}
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution that converts an RGB image in [0, 1] to VGG input format.

    Per channel it computes ``(x + sign * mean) / std``. With the default
    ``sign=-1`` this is the standard ``(x - mean) / std`` normalisation
    using the given ``rgb_mean`` / ``rgb_std`` (ImageNet statistics by default).
    """

    def __init__(self, rgb_mean=(0.485, 0.456, 0.406), rgb_std=(0.229, 0.224, 0.225), sign=-1):
        super(MeanShift, self).__init__(in_channels=3, out_channels=3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # Diagonal weight: each output channel is its own input channel scaled by 1/std.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        # Bias is sign * mean / std. The default sign=-1 subtracts the mean.
        # A positive bias would yield (x + mean) / std, which is not VGG's normalisation.
        self.bias.data = sign * torch.Tensor(rgb_mean) / std
        # Fixed layer: never trained.
        for para in self.parameters():
            para.requires_grad = False
#-------------------------------------------------------------------------------------------------------------------#
#-------------------------------------------------------------------------------------------------------------------#
if __name__ == '__main__':
    # Demo: save the VGG feature maps of one image as tiled grayscale PNGs
    # in a folder named after the image.
    import os
    import os.path as osp
    import sys

    import cv2
    import numpy as np
    import torchvision.utils as tv_utils

    def vgg_fea2img(vgg_fea):
        """Tile one sample's feature maps (a (1, C, H, W) tensor) into a uint8 image grid."""
        mid_feas = vgg_fea.detach()
        # (1, C, H, W) -> (C, 1, H, W): every channel becomes its own grid cell.
        mid_feas = torch.transpose(mid_feas, 0, 1)
        fea_nrow = round(mid_feas.shape[0] ** 0.5)
        fea_grid = tv_utils.make_grid(mid_feas, nrow=fea_nrow, normalize=True, scale_each=True)
        fea_grid = fea_grid.cpu().float().numpy().transpose((1, 2, 0))
        return (fea_grid * 255.0).round().clip(0, 255).astype(np.uint8)

    # **** Parameters: image path from the command line, else the original default.
    inp = sys.argv[1] if len(sys.argv) > 1 else r'D:\tmp\test\baboon.png'
    img_in = cv2.imread(inp)
    if img_in is None:
        # cv2.imread returns None instead of raising for missing/unreadable files.
        raise FileNotFoundError('Could not read image: {}'.format(inp))
    # HWC BGR uint8 -> CHW RGB float in [0, 1].
    # For a [-1, 1] input use (x - 127.5) / 127.5 and in_range=(-1.0, 1.0).
    img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
    img_in = np.float32(img_in) / 255.0

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    img_in_t = torch.from_numpy(img_in).unsqueeze(0).to(device)

    # **** Extract features. Move the whole module so that mean_shift and the VGG
    # layers live on the same device as the input.
    layer_names = ('conv_1_1', 'conv_2_1', 'conv_3_1', 'conv_4_1')
    vgg_test = VGG_LOSS(layer_names=layer_names).to(device)
    with torch.no_grad():
        feas = vgg_test.get_feas(img_in_t, in_range=(0.0, 1.0))

    # **** Save each layer's features as an image.
    out_dir = osp.splitext(inp)[0]
    os.makedirs(out_dir, exist_ok=True)
    bind = 0
    for layer_name, mid_feas in feas.items():
        out_fea = vgg_fea2img(mid_feas[bind:(bind + 1)])
        out_path = osp.join(out_dir, '{}_{}.png'.format(bind, layer_name))
        cv2.imwrite(out_path, out_fea)