from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F

from .utils import activate_add_tanh_sigmoid_multiply


class LayerNorm(nn.Module):
    """Layer normalization over the channel dimension of (B, C, T) tensors."""

    def __init__(self, channels: int, eps: float = 1e-5):
        super(LayerNorm, self).__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor):
        # Move channels last, normalize, then restore the (B, C, T) layout.
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class WN(torch.nn.Module):
    """WaveNet-like stack of dilated convolutions with gated activations,
    residual/skip connections, and optional global conditioning."""

    def __init__(
        self,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        gin_channels: int = 0,
        p_dropout: float = 0.0,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = float(p_dropout)

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(float(p_dropout))

        if gin_channels != 0:
            # Single projection of the global conditioning for all layers.
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # The last layer needs no residual branch, only skip channels.
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def __call__(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Typed wrapper so callers see the forward signature; dispatches
        # through nn.Module.__call__.
        return super().__call__(x, x_mask, g=g)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Accumulator for the skip connections.
        output = torch.zeros_like(x)

        if g is not None:
            # Project global conditioning once for all layers.
            g = self.cond_layer(g)

        for i, (in_layer, res_skip_layer) in enumerate(
            zip(self.in_layers, self.res_skip_layers)
        ):
            x_in: torch.Tensor = in_layer(x)
            if g is not None:
                # Slice out the conditioning channels belonging to layer i.
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            # Gated tanh/sigmoid activation over the 2 * hidden_channels.
            acts = activate_add_tanh_sigmoid_multiply(x_in, g_l, self.hidden_channels)
            acts = self.drop(acts)

            res_skip_acts: torch.Tensor = res_skip_layer(acts)
            if i < self.n_layers - 1:
                # First half feeds the residual path, second half the skip sum.
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                # Final layer only produces skip channels.
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)

    def __prepare_scriptable__(self):
        # Strip weight_norm pre-hooks before TorchScript export.
        if self.gin_channels != 0:
            for hook in self.cond_layer._forward_pre_hooks.values():
                if (
                    hook.__module__ == "torch.nn.utils.weight_norm"
                    and hook.__class__.__name__ == "WeightNorm"
                ):
                    torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            for hook in l._forward_pre_hooks.values():
                if (
                    hook.__module__ == "torch.nn.utils.weight_norm"
                    and hook.__class__.__name__ == "WeightNorm"
                ):
                    torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            for hook in l._forward_pre_hooks.values():
                if (
                    hook.__module__ == "torch.nn.utils.weight_norm"
                    and hook.__class__.__name__ == "WeightNorm"
                ):
                    torch.nn.utils.remove_weight_norm(l)
        return self