# one-shot-talking-face/models/layers/activation_norm.py
from types import SimpleNamespace
import torch
try:
    from torch.nn import SyncBatchNorm
except ImportError:
    # Fall back to a regular BatchNorm2d on PyTorch builds without SyncBatchNorm.
    from torch.nn import BatchNorm2d as SyncBatchNorm
from torch import nn
from torch.nn import functional as F
from .conv import LinearBlock, Conv2dBlock, HyperConv2d, PartialConv2dBlock
from .misc import PartialSequential
import sync_batchnorm
class AdaptiveNorm(nn.Module):
r"""Adaptive normalization layer. The layer first normalizes the input, then
performs an affine transformation using parameters computed from the
conditional inputs.
Args:
num_features (int): Number of channels in the input tensor.
cond_dims (int): Number of channels in the conditional inputs.
weight_norm_type (str): Type of weight normalization.
``'none'``, ``'spectral'``, ``'weight'``, or ``'weight_demod'``.
projection (bool): If ``True``, project the conditional input to gamma
and beta using a fully connected layer, otherwise directly use
the conditional input as gamma and beta.
        separate_projection (bool): If ``True``, use two separate layers to
            compute gamma and beta; otherwise use a single layer. This only
            matters when a weight normalization is applied to this layer.
input_dim (int): Number of dimensions of the input tensor.
activation_norm_type (str):
Type of activation normalization.
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
activation_norm_params (obj, optional, default=None):
Parameters of activation normalization.
If not ``None``, ``activation_norm_params.__dict__`` will be used as
keyword arguments when initializing activation normalization.
"""
def __init__(self, num_features, cond_dims, weight_norm_type='',
projection=True,
separate_projection=False,
input_dim=2,
activation_norm_type='instance',
activation_norm_params=None):
super().__init__()
self.projection = projection
self.separate_projection = separate_projection
if activation_norm_params is None:
activation_norm_params = SimpleNamespace(affine=False)
self.norm = get_activation_norm_layer(num_features,
activation_norm_type,
input_dim,
**vars(activation_norm_params))
if self.projection:
if self.separate_projection:
self.fc_gamma = \
LinearBlock(cond_dims, num_features,
weight_norm_type=weight_norm_type)
self.fc_beta = \
LinearBlock(cond_dims, num_features,
weight_norm_type=weight_norm_type)
else:
self.fc = LinearBlock(cond_dims, num_features * 2,
weight_norm_type=weight_norm_type)
self.conditional = True
def forward(self, x, y, **kwargs):
r"""Adaptive Normalization forward.
Args:
x (N x C1 x * tensor): Input tensor.
y (N x C2 tensor): Conditional information.
Returns:
out (N x C1 x * tensor): Output tensor.
"""
if self.projection:
if self.separate_projection:
gamma = self.fc_gamma(y)
beta = self.fc_beta(y)
for _ in range(x.dim() - gamma.dim()):
gamma = gamma.unsqueeze(-1)
beta = beta.unsqueeze(-1)
else:
y = self.fc(y)
for _ in range(x.dim() - y.dim()):
y = y.unsqueeze(-1)
gamma, beta = y.chunk(2, 1)
else:
for _ in range(x.dim() - y.dim()):
y = y.unsqueeze(-1)
gamma, beta = y.chunk(2, 1)
x = self.norm(x) if self.norm is not None else x
out = x * (1 + gamma) + beta
return out
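
# Usage sketch (illustrative only; the tensor shapes and channel counts below
# are assumptions, not values from this repository):
#     norm = AdaptiveNorm(num_features=64, cond_dims=128)
#     x = torch.randn(4, 64, 32, 32)   # N x C1 x H x W input features
#     y = torch.randn(4, 128)          # N x C2 conditioning vector
#     out = norm(x, y)                 # normalized and modulated, same shape as x
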
class SpatiallyAdaptiveNorm(nn.Module):
r"""Spatially Adaptive Normalization (SPADE) initialization.
Args:
num_features (int) : Number of channels in the input tensor.
        cond_dims (int or list of int) : List of numbers of channels
            in the conditional inputs.
num_filters (int): Number of filters in SPADE.
kernel_size (int): Kernel size of the convolutional filters in
the SPADE layer.
weight_norm_type (str): Type of weight normalization.
``'none'``, ``'spectral'``, or ``'weight'``.
        separate_projection (bool): If ``True``, use two separate layers to
            compute gamma and beta; otherwise use a single layer. This only
            matters when a weight normalization is applied to this layer.
activation_norm_type (str):
Type of activation normalization.
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
``'layer'``, ``'layer_2d'``, ``'group'``.
        activation_norm_params (obj, optional, default=None):
            Parameters of activation normalization.
            If not ``None``, ``activation_norm_params.__dict__`` will be used as
            keyword arguments when initializing activation normalization.
        partial (bool or list of bool, optional, default=False): Whether to
            use partial convolutions for the conditional branches.
"""
def __init__(self,
num_features,
cond_dims,
num_filters=128,
kernel_size=3,
weight_norm_type='',
separate_projection=False,
activation_norm_type='sync_batch',
activation_norm_params=None,
partial=False):
super().__init__()
if activation_norm_params is None:
activation_norm_params = SimpleNamespace(affine=False)
padding = kernel_size // 2
self.separate_projection = separate_projection
self.mlps = nn.ModuleList()
self.gammas = nn.ModuleList()
self.betas = nn.ModuleList()
# Make cond_dims a list.
        if not isinstance(cond_dims, list):
            cond_dims = [cond_dims]
# Make num_filters a list.
if not isinstance(num_filters, list):
num_filters = [num_filters] * len(cond_dims)
else:
assert len(num_filters) >= len(cond_dims)
# Make partial a list.
if not isinstance(partial, list):
partial = [partial] * len(cond_dims)
else:
assert len(partial) >= len(cond_dims)
for i, cond_dim in enumerate(cond_dims):
mlp = []
conv_block = PartialConv2dBlock if partial[i] else Conv2dBlock
sequential = PartialSequential if partial[i] else nn.Sequential
if num_filters[i] > 0:
mlp += [conv_block(cond_dim,
num_filters[i],
kernel_size,
padding=padding,
weight_norm_type=weight_norm_type,
nonlinearity='relu')]
mlp_ch = cond_dim if num_filters[i] == 0 else num_filters[i]
if self.separate_projection:
if partial[i]:
raise NotImplementedError(
'Separate projection not yet implemented for ' +
'partial conv')
self.mlps.append(nn.Sequential(*mlp))
self.gammas.append(
conv_block(mlp_ch, num_features,
kernel_size,
padding=padding,
weight_norm_type=weight_norm_type))
self.betas.append(
conv_block(mlp_ch, num_features,
kernel_size,
padding=padding,
weight_norm_type=weight_norm_type))
else:
mlp += [conv_block(mlp_ch, num_features * 2, kernel_size,
padding=padding,
weight_norm_type=weight_norm_type)]
self.mlps.append(sequential(*mlp))
self.norm = get_activation_norm_layer(num_features,
activation_norm_type,
2,
**vars(activation_norm_params))
self.conditional = True
def forward(self, x, *cond_inputs, **kwargs):
r"""Spatially Adaptive Normalization (SPADE) forward.
Args:
x (N x C1 x H x W tensor) : Input tensor.
cond_inputs (list of tensors) : Conditional maps for SPADE.
Returns:
output (4D tensor) : Output tensor.
"""
output = self.norm(x) if self.norm is not None else x
for i in range(len(cond_inputs)):
if cond_inputs[i] is None:
continue
label_map = F.interpolate(cond_inputs[i], size=x.size()[2:],
mode='nearest')
if self.separate_projection:
hidden = self.mlps[i](label_map)
gamma = self.gammas[i](hidden)
beta = self.betas[i](hidden)
else:
affine_params = self.mlps[i](label_map)
gamma, beta = affine_params.chunk(2, dim=1)
output = output * (1 + gamma) + beta
return output
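
# Usage sketch (illustrative only; shapes and the 'instance' norm choice are
# assumptions, since the default 'sync_batch' expects a distributed GPU setup):
#     spade = SpatiallyAdaptiveNorm(num_features=64, cond_dims=3,
#                                   activation_norm_type='instance')
#     x = torch.randn(4, 64, 32, 32)      # features to modulate
#     seg = torch.randn(4, 3, 128, 128)   # conditional map, resized internally
#     out = spade(x, seg)                 # same shape as x
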
class HyperSpatiallyAdaptiveNorm(nn.Module):
r"""Spatially Adaptive Normalization (SPADE) initialization.
Args:
num_features (int) : Number of channels in the input tensor.
cond_dims (int or list of int) : List of numbers of channels
in the conditional input.
num_filters (int): Number of filters in SPADE.
kernel_size (int): Kernel size of the convolutional filters in
the SPADE layer.
weight_norm_type (str): Type of weight normalization.
``'none'``, ``'spectral'``, or ``'weight'``.
activation_norm_type (str):
Type of activation normalization.
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
``'layer'``, ``'layer_2d'``, ``'group'``.
is_hyper (bool): Whether to use hyper SPADE.
"""
def __init__(self, num_features, cond_dims,
num_filters=0, kernel_size=3,
weight_norm_type='',
activation_norm_type='sync_batch', is_hyper=True):
super().__init__()
padding = kernel_size // 2
self.mlps = nn.ModuleList()
        if not isinstance(cond_dims, list):
            cond_dims = [cond_dims]
for i, cond_dim in enumerate(cond_dims):
mlp = []
if not is_hyper or (i != 0):
if num_filters > 0:
mlp += [Conv2dBlock(cond_dim, num_filters, kernel_size,
padding=padding,
weight_norm_type=weight_norm_type,
nonlinearity='relu')]
mlp_ch = cond_dim if num_filters == 0 else num_filters
mlp += [Conv2dBlock(mlp_ch, num_features * 2, kernel_size,
padding=padding,
weight_norm_type=weight_norm_type)]
mlp = nn.Sequential(*mlp)
else:
if num_filters > 0:
raise ValueError('Multi hyper layer not supported yet.')
mlp = HyperConv2d(padding=padding)
self.mlps.append(mlp)
self.norm = get_activation_norm_layer(num_features,
activation_norm_type,
2,
affine=False)
self.conditional = True
def forward(self, x, *cond_inputs,
norm_weights=(None, None), **kwargs):
r"""Spatially Adaptive Normalization (SPADE) forward.
Args:
x (4D tensor) : Input tensor.
cond_inputs (list of tensors) : Conditional maps for SPADE.
norm_weights (5D tensor or list of tensors): conv weights or
[weights, biases].
Returns:
output (4D tensor) : Output tensor.
"""
output = self.norm(x)
for i in range(len(cond_inputs)):
if cond_inputs[i] is None:
continue
            if isinstance(cond_inputs[i], list):
cond_input, mask = cond_inputs[i]
mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear',
align_corners=False)
else:
cond_input = cond_inputs[i]
mask = None
label_map = F.interpolate(cond_input, size=x.size()[2:])
if norm_weights is None or norm_weights[0] is None or i != 0:
affine_params = self.mlps[i](label_map)
else:
affine_params = self.mlps[i](label_map,
conv_weights=norm_weights)
gamma, beta = affine_params.chunk(2, dim=1)
if mask is not None:
gamma = gamma * (1 - mask)
beta = beta * (1 - mask)
output = output * (1 + gamma) + beta
return output
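
# Usage sketch (illustrative only; shapes and keyword values are assumptions).
# With is_hyper=False the layer behaves like a plain SPADE block; with the
# default is_hyper=True the first conditional branch is a HyperConv2d whose
# weights are passed at call time through ``norm_weights``.
#     hyper_spade = HyperSpatiallyAdaptiveNorm(
#         64, cond_dims=3, is_hyper=False, activation_norm_type='instance')
#     x = torch.randn(2, 64, 32, 32)
#     seg = torch.randn(2, 3, 64, 64)
#     out = hyper_spade(x, seg)           # same shape as x
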
class LayerNorm2d(nn.Module):
r"""Layer Normalization as introduced in
https://arxiv.org/abs/1607.06450.
This is the usual way to apply layer normalization in CNNs.
Note that unlike the pytorch implementation which applies per-element
scale and bias, here it applies per-channel scale and bias, similar to
batch/instance normalization.
Args:
num_features (int): Number of channels in the input tensor.
eps (float, optional, default=1e-5): a value added to the
denominator for numerical stability.
        affine (bool, optional, default=True): If ``True``, performs
            affine transformation after normalization.
"""
def __init__(self, num_features, eps=1e-5, affine=True):
        super().__init__()
self.num_features = num_features
self.affine = affine
self.eps = eps
if self.affine:
self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
self.beta = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
r"""
Args:
x (tensor): Input tensor.
"""
shape = [-1] + [1] * (x.dim() - 1)
if x.size(0) == 1:
mean = x.view(-1).mean().view(*shape)
std = x.view(-1).std().view(*shape)
else:
mean = x.view(x.size(0), -1).mean(1).view(*shape)
std = x.view(x.size(0), -1).std(1).view(*shape)
x = (x - mean) / (std + self.eps)
if self.affine:
shape = [1, -1] + [1] * (x.dim() - 2)
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
return x
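
# Usage sketch (illustrative only; shapes are assumptions). Each sample is
# normalized over all of its elements, then the per-channel gamma/beta are
# applied:
#     layer_norm = LayerNorm2d(64)
#     x = torch.randn(8, 64, 16, 16)
#     out = layer_norm(x)                 # same shape as x
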
def get_activation_norm_layer(num_features, norm_type,
input_dim, **norm_params):
r"""Return an activation normalization layer.
Args:
num_features (int): Number of feature channels.
norm_type (str):
Type of activation normalization.
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
input_dim (int): Number of input dimensions.
norm_params: Arbitrary keyword arguments that will be used to
initialize the activation normalization.
"""
input_dim = max(input_dim, 1) # Norm1d works with both 0d and 1d inputs
if norm_type == 'none' or norm_type == '':
norm_layer = None
    elif norm_type == 'batch':
        # Use the bundled synchronized batch norm implementation instead of
        # the standard nn.BatchNorm layer.
        norm = getattr(sync_batchnorm, 'SynchronizedBatchNorm%dd' % input_dim)
        norm_layer = norm(num_features, **norm_params)
elif norm_type == 'instance':
affine = norm_params.pop('affine', True) # Use affine=True by default
norm = getattr(nn, 'InstanceNorm%dd' % input_dim)
norm_layer = norm(num_features, affine=affine, **norm_params)
elif norm_type == 'sync_batch':
        # There is a bug when using amp O1 with synchronized batch norm.
        # The lines below work around it.
affine = norm_params.pop('affine', True)
# Always call SyncBN with affine=True
norm_layer = SyncBatchNorm(num_features, affine=True, **norm_params)
norm_layer.weight.requires_grad = affine
norm_layer.bias.requires_grad = affine
elif norm_type == 'layer':
norm_layer = nn.LayerNorm(num_features, **norm_params)
elif norm_type == 'layer_2d':
norm_layer = LayerNorm2d(num_features, **norm_params)
elif norm_type == 'group':
norm_layer = nn.GroupNorm(num_channels=num_features, **norm_params)
elif norm_type == 'adaptive':
norm_layer = AdaptiveNorm(num_features, **norm_params)
elif norm_type == 'spatially_adaptive':
if input_dim != 2:
            raise ValueError('Spatially adaptive normalization layers '
                             'only support 2D input')
norm_layer = SpatiallyAdaptiveNorm(num_features, **norm_params)
elif norm_type == 'hyper_spatially_adaptive':
if input_dim != 2:
            raise ValueError('Spatially adaptive normalization layers '
                             'only support 2D input')
norm_layer = HyperSpatiallyAdaptiveNorm(num_features, **norm_params)
else:
raise ValueError('Activation norm layer %s '
'is not recognized' % norm_type)
return norm_layer
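
# Usage sketch (illustrative only; the argument values are assumptions):
#     inst_norm = get_activation_norm_layer(64, 'instance', input_dim=2)
#     group_norm = get_activation_norm_layer(64, 'group', input_dim=2,
#                                            num_groups=8)
#     out = inst_norm(torch.randn(4, 64, 32, 32))
# Note that 'none' (or '') returns None, so callers are expected to skip the
# layer themselves, as AdaptiveNorm and SpatiallyAdaptiveNorm do with their
# ``self.norm is not None`` checks.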