Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,464 Bytes
a64b7d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.spectral_norm import spectral_norm
from basicsr.utils.registry import ARCH_REGISTRY
from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
from .vgg_arch import VGGFeatureExtractor
class SFTUpBlock(nn.Module):
    """Upsampling block modulated by a spatial feature transform (SFT).

    The input goes through ``Blur`` and a spectrally-normalized convolution,
    is multiplied by a scale map and offset by a shift map (both predicted
    from ``updated_feat``), and is then upsampled by a factor of 2.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        kernel_size (int): Kernel size in convolutions. Default: 3.
        padding (int): Padding in convolutions. Default: 1.
    """

    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
        super().__init__()

        # Main branch applied to the incoming feature map.
        # The official codes use two LeakyReLU here, so 0.04 for equivalent
        head_layers = [
            Blur(in_channel),
            spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
            nn.LeakyReLU(0.04, True),
        ]
        self.conv1 = nn.Sequential(*head_layers)

        # Bilinear 2x upsampling followed by a refinement convolution.
        up_layers = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
            nn.LeakyReLU(0.2, True),
        ]
        self.convup = nn.Sequential(*up_layers)

        def sft_convs():
            # Two 3x3 / stride 1 / padding 1 convolutions shared in shape by
            # the scale and shift predictors.
            return [
                spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
                nn.LeakyReLU(0.2, True),
                spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
            ]

        # for SFT scale and shift
        self.scale_block = nn.Sequential(*sft_convs())
        # The official codes use sigmoid for shift block, do not know why
        self.shift_block = nn.Sequential(*sft_convs(), nn.Sigmoid())

    def forward(self, x, updated_feat):
        """Modulate ``x`` with SFT parameters from ``updated_feat`` and upsample.

        Args:
            x (Tensor): Feature map fed to the main convolution branch.
            updated_feat (Tensor): Feature map from which the scale and
                shift maps are predicted.

        Returns:
            Tensor: Output of the upsampling branch.
        """
        feat = self.conv1(x)
        # SFT: element-wise affine modulation of the main-branch features.
        modulated = feat * self.scale_block(updated_feat) + self.shift_block(updated_feat)
        # upsample
        return self.convup(modulated)
@ARCH_REGISTRY.register()
class DFDNet(nn.Module):
    """DFDNet: Deep Face Dictionary Network.

    It only processes faces with 512x512 size.

    Args:
        num_feat (int): Number of feature channels.
        dict_path (str): Path to the facial component dictionary.
    """

    def __init__(self, num_feat, dict_path):
        super().__init__()
        self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
        # part_sizes: [80, 80, 50, 110]
        channel_sizes = [128, 256, 512, 512]
        self.feature_sizes = np.array([256, 128, 64, 32])
        self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
        self.flag_dict_device = False

        # dict
        # Load onto CPU so that a dictionary saved from a GPU process can
        # still be opened on a CPU-only host; put_dict_to_device() moves the
        # tensors to the input's device/dtype on the first forward pass.
        # NOTE(security): torch.load unpickles the file, so only load
        # dictionaries from a trusted source.
        self.dict = torch.load(dict_path, map_location='cpu')

        # vgg face extractor
        self.vgg_extractor = VGGFeatureExtractor(
            layer_name_list=self.vgg_layers,
            vgg_type='vgg19',
            use_input_norm=True,
            range_norm=True,
            requires_grad=False)

        # attention block for fusing dictionary features and input features
        # (one block per facial part and per feature size)
        self.attn_blocks = nn.ModuleDict()
        for idx, feat_size in enumerate(self.feature_sizes):
            for name in self.parts:
                self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])

        # multi scale dilation block
        self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])

        # upsampling and reconstruction
        self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
        self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
        self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
        self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
        self.upsample4 = nn.Sequential(
            spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
            UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())

    def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
        """Swap the features of one facial part with features from the dictionary.

        Args:
            vgg_feat (Tensor): Original VGG feature map (4-D).
            updated_feat (Tensor): Feature map that is updated in place in
                the part region and returned.
            dict_feat (Tensor): Dictionary features for this part at this
                feature size; indexed along dim 0 to pick one entry.
            location (Tensor): Four indices; dim 2 is sliced with
                ``location[1]:location[3]`` and dim 3 with
                ``location[0]:location[2]`` (presumably x1, y1, x2, y2).
            part_name (str): Name of the facial part (one of ``self.parts``).
            f_size (int): Feature size, used to select the attention block.

        Returns:
            Tensor: ``updated_feat`` with the part region replaced.
        """
        # get the original vgg features
        part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
        # resize original vgg features
        part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
        # use adaptive instance normalization to adjust color and illuminations
        dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
        # get similarity scores: each dictionary entry acts as a conv kernel
        # with the same spatial size as the resized part features
        similarity_score = F.conv2d(part_resize_feat, dict_feat)
        similarity_score = F.softmax(similarity_score.view(-1), dim=0)
        # select the most similar features in the dict (after norm)
        select_idx = torch.argmax(similarity_score)
        # resize the selected entry back to the size of the part region
        swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
        # attention (key format matches the one used in __init__)
        attn = self.attn_blocks[f'{part_name}_{f_size}'](swap_feat - part_feat)
        attn_feat = attn * swap_feat
        # update features
        updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
        return updated_feat

    def put_dict_to_device(self, x):
        """Move every dictionary tensor to the device and dtype of ``x``.

        This is done only once: after the first call ``flag_dict_device`` is
        set and later calls are no-ops.

        Args:
            x (Tensor): Reference tensor passed to ``Tensor.to``.
        """
        if self.flag_dict_device is False:
            for k, v in self.dict.items():
                for kk, vv in v.items():
                    self.dict[k][kk] = vv.to(x)
            self.flag_dict_device = True

    def forward(self, x, part_locations):
        """Restore a face using the component dictionary.

        Now only support testing with batch size = 1 (only the sample at
        batch index 0 of ``part_locations`` is used).

        Args:
            x (Tensor): Input faces with shape (b, c, 512, 512).
            part_locations (list[Tensor]): Part locations, one entry per
                part in the order of ``self.parts``.

        Returns:
            Tensor: Output of the final reconstruction branch.
        """
        self.put_dict_to_device(x)
        # extract vggface features
        vgg_features = self.vgg_extractor(x)
        # update vggface features using the dictionary for each part
        updated_vgg_features = []
        batch = 0  # index of the only supported sample (batch size = 1)
        for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
            dict_features = self.dict[f'{f_size}']
            vgg_feat = vgg_features[vgg_layer]
            updated_feat = vgg_feat.clone()

            # swap features from dictionary
            for part_idx, part_name in enumerate(self.parts):
                # rescale the locations from the 512 input grid to this feature size
                location = (part_locations[part_idx][batch] // (512 / f_size)).int()
                updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
                                              f_size)

            updated_vgg_features.append(updated_feat)

        vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
        # use updated vgg features to modulate the upsampled features with
        # SFT (Spatial Feature Transform) scaling and shifting manner.
        # The updated features are consumed from the smallest feature size
        # (last in the list) to the largest.
        upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
        upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
        upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
        upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
        out = self.upsample4(upsampled_feat)

        return out
|