# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""Parallel WaveGAN Modules."""

import logging
import math

import torch
from torch import nn

from modules.vocoder.parallel_wavegan.layers import Conv1d
from modules.vocoder.parallel_wavegan.layers import Conv1d1x1
from modules.vocoder.parallel_wavegan.layers import ResidualBlock
from modules.vocoder.parallel_wavegan.layers import upsample
from modules.vocoder.parallel_wavegan import models
from modules.vocoder.parallel_wavegan.models import SourceModuleCycNoise_v1
from utils.commons.hparams import hparams
import numpy as np


class ParallelWaveGANGenerator(torch.nn.Module):
    """Parallel WaveGAN Generator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=30,
                 stacks=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 aux_channels=80,
                 aux_context_window=2,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 upsample_conditional_features=True,
                 upsample_net="ConvInUpsampleNetwork",
                 upsample_params={"upsample_scales": [4, 4, 4, 4]},
                 use_pitch_embed=False,
                 use_nsf=False,
                 sample_rate=22050,
                 ):
        """Initialize Parallel WaveGAN Generator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks, i.e., dilation cycles.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            aux_channels (int): Number of channels for auxiliary feature conv.
            aux_context_window (int): Context window size for auxiliary feature.
            dropout (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv layer.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal structure.
            upsample_conditional_features (bool): Whether to use upsampling network.
            upsample_net (str): Upsampling network architecture.
            upsample_params (dict): Upsampling network parameters.
            use_pitch_embed (bool): Whether to fuse a pitch embedding into the
                auxiliary features before upsampling.
            use_nsf (bool): Whether to add an upsampled, convolved f0 signal to the
                upsampled auxiliary features (NSF-style conditioning).
            sample_rate (int): Audio sample rate used by the NSF source module.

""" super(ParallelWaveGANGenerator, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.aux_channels = aux_channels self.layers = layers self.stacks = stacks self.kernel_size = kernel_size # check the number of layers and stacks assert layers % stacks == 0 layers_per_stack = layers // stacks # define first convolution self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True) # define conv + upsampling network self.aux_context_window = aux_context_window if upsample_conditional_features: upsample_params.update({ "use_causal_conv": use_causal_conv, }) if upsample_net == "MelGANGenerator": assert aux_context_window == 0 upsample_params.update({ "use_weight_norm": False, # not to apply twice "use_final_nonlinear_activation": False, }) self.upsample_net = getattr(models, upsample_net)(**upsample_params) else: if upsample_net == "ConvInUpsampleNetwork": upsample_params.update({ "aux_channels": aux_channels, "aux_context_window": aux_context_window, }) self.upsample_net = getattr(upsample, upsample_net)(**upsample_params) else: self.upsample_net = None # define residual blocks self.conv_layers = torch.nn.ModuleList() for layer in range(layers): dilation = 2 ** (layer % layers_per_stack) conv = ResidualBlock( kernel_size=kernel_size, residual_channels=residual_channels, gate_channels=gate_channels, skip_channels=skip_channels, aux_channels=aux_channels, dilation=dilation, dropout=dropout, bias=bias, use_causal_conv=use_causal_conv, ) self.conv_layers += [conv] # define output layers self.last_conv_layers = torch.nn.ModuleList([ torch.nn.ReLU(inplace=True), Conv1d1x1(skip_channels, skip_channels, bias=True), torch.nn.ReLU(inplace=True), Conv1d1x1(skip_channels, out_channels, bias=True), ]) self.use_pitch_embed = use_pitch_embed if use_pitch_embed: self.pitch_embed = nn.Embedding(300, aux_channels, 0) self.c_proj = nn.Linear(2 * aux_channels, aux_channels) self.use_nsf = use_nsf if use_nsf: self.harmonic_num = 8 hop_size = np.prod(upsample_params['upsample_scales']) self.f0_upsamp = torch.nn.Upsample(scale_factor=hop_size) self.m_source = SourceModuleCycNoise_v1(sample_rate, 0.003) self.nsf_conv = nn.Sequential(nn.Conv1d(1, aux_channels, 1), torch.nn.Tanh()) # apply weight norm if use_weight_norm: self.apply_weight_norm() def forward(self, x, c=None, pitch=None, f0=None, **kwargs): """Calculate forward propagation. Args: x (Tensor): Input noise signal (B, C_in, T). c (Tensor): Local conditioning auxiliary features (B, C ,T'). pitch (Tensor): Local conditioning pitch (B, T'). 
        Returns:
            Tensor: Output tensor (B, C_out, T).

        """
        # perform upsampling
        if c is not None and self.upsample_net is not None:
            if self.use_pitch_embed:
                p = self.pitch_embed(pitch)
                c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
            c = self.upsample_net(c)
            if self.use_nsf:
                f0_upsample = self.f0_upsamp(
                    f0[:, None, :][:, :, self.aux_context_window:-self.aux_context_window])
                f0_upsample = self.nsf_conv(f0_upsample)
                c = c + f0_upsample
            if x is None:
                x = torch.randn([c.size(0), 1, c.size(-1)]).to(c.device)
            assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))

        # encode to hidden representation
        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, c)
            skips += h
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        return x

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the conv layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size,
                                  dilation=lambda x: 2 ** x):
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self):
        """Return receptive field size."""
        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)


class ParallelWaveGANDiscriminator(torch.nn.Module):
    """Parallel WaveGAN Discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=10,
                 conv_channels=64,
                 dilation_factor=1,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 bias=True,
                 use_weight_norm=True,
                 ):
        """Initialize Parallel WaveGAN Discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of conv layers.
            conv_channels (int): Number of conv channels.
            dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
                the dilation will be 2, 4, 8, ..., and so on.
            nonlinear_activation (str): Nonlinear function after each conv.
            nonlinear_activation_params (dict): Nonlinear function parameters.
            bias (bool): Whether to use bias parameter in conv.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.

        """
        super(ParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Kernel size must be odd."
        assert dilation_factor > 0, "Dilation factor must be > 0."
        self.conv_layers = torch.nn.ModuleList()
        conv_in_channels = in_channels
        for i in range(layers - 1):
            if i == 0:
                dilation = 1
            else:
                dilation = i if dilation_factor == 1 else dilation_factor ** i
                conv_in_channels = conv_channels
            padding = (kernel_size - 1) // 2 * dilation
            conv_layer = [
                Conv1d(conv_in_channels, conv_channels,
                       kernel_size=kernel_size, padding=padding,
                       dilation=dilation, bias=bias),
                getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
            ]
            self.conv_layers += conv_layer
        padding = (kernel_size - 1) // 2
        last_conv_layer = Conv1d(
            conv_in_channels, out_channels,
            kernel_size=kernel_size, padding=padding, bias=bias)
        self.conv_layers += [last_conv_layer]

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x, cond=None):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).
            cond (Tensor): Frame-level conditioning features (B, H, T_frame).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        cond_layer_i = len(self.conv_layers) // 2
        for i, f in enumerate(self.conv_layers):
            if i == cond_layer_i and cond is not None:
                # trim the auxiliary context window, then repeat each frame
                # hop_size times so the conditioning matches the sample rate
                aux_context_window = hparams['aux_context_window']
                cond = cond[:, :, aux_context_window:-aux_context_window]
                cond = cond[:, :, :, None].repeat([1, 1, 1, hparams['hop_size']]).reshape(
                    cond.shape[0], cond.shape[1], -1)
                x = x + cond
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the conv layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)


class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
    """Residual Parallel WaveGAN Discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=30,
                 stacks=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 ):
        """Initialize Residual Parallel WaveGAN Discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks, i.e., dilation cycles.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            dropout (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal structure.
            nonlinear_activation (str): Nonlinear function after each conv.
            nonlinear_activation_params (dict): Nonlinear function parameters.

        """
        super(ResidualParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Kernel size must be odd."
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        self.first_conv = torch.nn.Sequential(
            Conv1d1x1(in_channels, residual_channels, bias=True),
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
        )

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=-1,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=use_causal_conv,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList([
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
            Conv1d1x1(skip_channels, skip_channels, bias=True),
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
            Conv1d1x1(skip_channels, out_channels, bias=True),
        ])

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        x = self.first_conv(x)

        skips = 0
        for f in self.conv_layers:
            x, h = f(x, None)
            skips += h
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the conv layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
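

# Minimal usage sketch: a smoke test of the generator and discriminator on
# random tensors. It assumes the default "ConvInUpsampleNetwork" upsampler
# (from modules.vocoder.parallel_wavegan.layers.upsample) trims
# 2 * aux_context_window frames and then upsamples by prod(upsample_scales);
# the shapes in the comments follow from that assumption. Passing x=None lets
# the generator draw its own input noise, so no length bookkeeping is needed.
if __name__ == "__main__":
    batch_size, aux_channels, n_frames = 2, 80, 36  # 36 = 32 + 2 * aux_context_window
    c = torch.randn(batch_size, aux_channels, n_frames)  # fake mel-spectrogram

    generator = ParallelWaveGANGenerator()  # defaults: hop size 4 * 4 * 4 * 4 = 256
    with torch.no_grad():
        wav = generator(None, c)  # expected (B, 1, T) with T = (36 - 4) * 256 = 8192
    print("generator output:", wav.shape)

    discriminator = ParallelWaveGANDiscriminator()
    with torch.no_grad():
        score = discriminator(wav)  # (B, 1, T) per-sample scores; cond is unused here
    print("discriminator output:", score.shape)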