|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
import numbers |
|
from typing import Any, List, Tuple, Union |
|
|
|
import torch |
|
from torch import Tensor, nn |
|
from torch.nn import functional as F |
|
|
|
from modules.general.scaling import ActivationBalancer |
|
from modules.general.scaling import BasicNorm as _BasicNorm |
|
|
|
|
|
# Accepted forms of a normalized shape: a single int, a list of ints, or a torch.Size.
_shape_t = Union[int, List[int], torch.Size]


class LayerNorm(nn.Module):
    """Layer normalization that can also carry an embedding through untouched.

    Normalizes over the trailing ``normalized_shape`` dimensions with
    ``torch.nn.functional.layer_norm``. When ``elementwise_affine`` is true, a
    learnable ``weight`` (initialized to ones) and ``bias`` (initialized to
    zeros), both of shape ``normalized_shape``, are applied.

    ``forward`` accepts either a tensor, returning the normalized tensor, or an
    ``(input, embedding)`` tuple, returning ``(normalized_input, embedding)``.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super(LayerNorm, self).__init__()

        # A bare integer means "normalize over the last dimension of this size".
        if isinstance(normalized_shape, numbers.Integral):
            shape = (normalized_shape,)
        else:
            shape = normalized_shape
        self.normalized_shape = tuple(shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        # device/dtype only affect the (optional) affine parameters.
        tensor_kwargs = dict(device=device, dtype=dtype)
        if not elementwise_affine:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        else:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **tensor_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **tensor_kwargs)
            )

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set ``weight`` to ones and ``bias`` to zeros (no-op without affine)."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Normalize ``input``; an ``(input, embedding)`` tuple is also accepted.

        Returns the normalized tensor, or ``(normalized, embedding)`` when
        ``input`` is a tuple. Passing ``embedding`` alongside a plain tensor
        trips an assertion.
        """
        paired = isinstance(input, tuple)
        if paired:
            input, embedding = input
        else:
            assert embedding is None

        normalized = F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return (normalized, embedding) if paired else normalized

    def extra_repr(self) -> str:
        return (
            f"{self.normalized_shape}, eps={self.eps}, "
            f"elementwise_affine={self.elementwise_affine}"
        )
|
|
|
|
|
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    Normalizes ``input`` with the wrapped ``norm`` module, then applies a scale
    and shift predicted from ``embedding`` by a linear projection::

        weight, bias = split(project_layer(embedding), d_model, dim=-1)
        output = weight * norm(input) + bias

    Args:
        d_model: Feature size; ``project_layer`` maps ``d_model -> 2 * d_model``
            and its output is split into two ``d_model``-sized halves.
        norm: Normalization module applied to ``input``. If it exposes an
            ``eps`` attribute, that object is mirrored as ``self.eps``;
            otherwise ``self.eps`` is ``None``.
    """

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        # Not every norm module exposes `eps` (e.g. torch.nn.Identity); reading
        # it unconditionally made such norms impossible to wrap. For norms that
        # do have it, the very same object is assigned as before.
        self.eps = getattr(self.norm, "eps", None)

    def forward(
        self, input: Tensor, embedding: Union[Tensor, None] = None
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Apply the embedding-conditioned normalization.

        Args:
            input: Tensor to normalize, or an ``(input, embedding)`` tuple.
            embedding: Conditioning tensor fed to ``project_layer``; ignored
                (overwritten) when ``input`` is a tuple.

        Returns:
            ``weight * norm(input) + bias``, or ``(that, embedding)`` when
            ``input`` was given as a tuple.

        Raises:
            TypeError: if no embedding is available.
        """
        paired = isinstance(input, tuple)
        if paired:
            input, embedding = input
        if embedding is None:
            # Previously this surfaced as an opaque TypeError from F.linear.
            raise TypeError(
                "AdaptiveLayerNorm requires an `embedding` tensor, either as "
                "the second argument or as the second item of an "
                "(input, embedding) tuple"
            )

        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        output = weight * self.norm(input) + bias
        return (output, embedding) if paired else output
|
|
|
|
|
class BasicNorm(_BasicNorm):
    """``_BasicNorm`` whose forward also accepts an ``(input, embedding)`` pair.

    ``device`` and ``dtype`` are accepted but not used: only ``d_model`` and
    ``eps`` are passed to the base-class constructor.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super().__init__(d_model, eps=eps)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Normalize via ``_BasicNorm.forward``.

        A plain tensor yields the normalized tensor (``embedding`` must then be
        ``None``); a tuple yields ``(normalized, embedding)``.
        """
        if not isinstance(input, tuple):
            assert embedding is None
            return super().forward(input)

        features, embedding = input
        return (super().forward(features), embedding)
|
|
|
|
|
class BalancedBasicNorm(nn.Module):
    """``BasicNorm`` applied to the output of an ``ActivationBalancer``.

    The balancer is constructed on the last dimension (``channel_dim=-1``) with
    ``min_positive=0.45``, ``max_positive=0.55`` and ``max_abs=6.0``; see
    ``modules.general.scaling`` for what those constraints do. ``eps``,
    ``device`` and ``dtype`` are forwarded to ``BasicNorm``.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super().__init__()
        # Registration order (balancer, then norm) is kept for state_dict order.
        self.balancer = ActivationBalancer(
            d_model,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Balance then normalize; ``(input, embedding)`` tuples pass through.

        For a tuple input the result is whatever ``self.norm`` returns for the
        ``(balanced, embedding)`` pair; otherwise ``embedding`` must be ``None``.
        """
        paired = isinstance(input, tuple)
        if paired:
            input, embedding = input
        else:
            assert embedding is None

        balanced = self.balancer(input)
        return self.norm((balanced, embedding) if paired else balanced)
|
|
|
|
|
class IdentityNorm(nn.Module):
    """No-op normalization sharing the call interface of the other norms here.

    ``forward`` returns its input unchanged; an ``(input, embedding)`` tuple is
    returned as the very same tuple object.

    Args:
        d_model: Feature size; accepted for interface compatibility, unused.
        eps: Stored as ``self.eps`` without affecting the output, so wrappers
            that read ``norm.eps`` (such as ``AdaptiveLayerNorm``) can wrap it.
        device, dtype: Accepted for interface compatibility, unused.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super(IdentityNorm, self).__init__()
        # Previously discarded. Kept as a plain float attribute (not a
        # parameter or buffer), so the state_dict is unchanged.
        self.eps = eps

    def forward(
        self, input: Tensor, embedding: Any = None
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        """Return ``input`` unchanged.

        A tuple input is returned as-is; for a plain tensor ``embedding`` must
        be ``None`` (asserted, matching the sibling norm classes).
        """
        if isinstance(input, tuple):
            return input

        assert embedding is None
        return input
|
|