Spaces:
Build error
Build error
""" | |
Taken from ESPnet
""" | |
import torch | |
class PostNet(torch.nn.Module):
    """
    From Tacotron2

    Postnet module for Spectrogram prediction network.

    This is a module of Postnet in Spectrogram prediction network,
    which is described in `Natural TTS Synthesis by
    Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The Postnet refines the predicted
    Mel-filterbank of the decoder,
    which helps to compensate the detail structure of spectrogram.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884
    """

    def __init__(self, idim, odim, n_layers=5, n_chans=512, n_filts=5, dropout_rate=0.5, use_batch_norm=True):
        """
        Initialize postnet module.

        The stack consists of ``n_layers - 1`` blocks of
        Conv1d -> (GroupNorm) -> Tanh -> Dropout followed by one final block of
        Conv1d -> (GroupNorm) -> Dropout that projects back to ``odim`` channels.

        Args:
            idim (int): Dimension of the inputs. Accepted for interface
                compatibility but not used: the first convolution is built
                with ``odim`` input channels.
            odim (int): Dimension of the outputs (and of the tensor fed to
                the first convolution).
            n_layers (int, optional): The number of layers.
            n_chans (int, optional): The number of filter channels.
            n_filts (int, optional): The filter (kernel) size.
            dropout_rate (float, optional): Dropout rate.
            use_batch_norm (bool, optional): Whether to insert a normalization
                layer after each convolution. Note that the layer actually
                used is ``torch.nn.GroupNorm``, not batch normalization.
        """
        super(PostNet, self).__init__()
        padding = (n_filts - 1) // 2
        self.postnet = torch.nn.ModuleList()

        # Hidden blocks: every one of them outputs n_chans channels; only the
        # first one reads odim channels.
        for layer in range(n_layers - 1):
            ichans = odim if layer == 0 else n_chans
            hidden = [torch.nn.Conv1d(ichans, n_chans, n_filts, stride=1, padding=padding, bias=False)]
            if use_batch_norm:
                hidden.append(
                    torch.nn.GroupNorm(num_groups=self._num_groups(n_chans, 32), num_channels=n_chans)
                )
            hidden.append(torch.nn.Tanh())
            hidden.append(torch.nn.Dropout(dropout_rate))
            self.postnet.append(torch.nn.Sequential(*hidden))

        # Final projection block: no Tanh, maps back to odim channels.
        ichans = n_chans if n_layers != 1 else odim
        final = [torch.nn.Conv1d(ichans, odim, n_filts, stride=1, padding=padding, bias=False)]
        if use_batch_norm:
            final.append(torch.nn.GroupNorm(num_groups=self._num_groups(odim, 20), num_channels=odim))
        final.append(torch.nn.Dropout(dropout_rate))
        self.postnet.append(torch.nn.Sequential(*final))

    @staticmethod
    def _num_groups(num_channels, preferred):
        """
        Return the largest group count <= ``preferred`` that divides ``num_channels``.

        GroupNorm requires the channel count to be divisible by the number of
        groups. When ``preferred`` already divides ``num_channels`` it is
        returned unchanged, so the default configuration keeps its group counts.
        """
        for groups in range(min(preferred, num_channels), 0, -1):
            if num_channels % groups == 0:
                return groups
        return 1

    def forward(self, xs):
        """
        Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the sequences of padded input tensors (B, odim, Tmax).
                The channel dimension must be ``odim`` because the first
                convolution is constructed with ``odim`` input channels.

        Returns:
            Tensor: Batch of padded output tensor. (B, odim, Tmax).
        """
        for block in self.postnet:
            xs = block(xs)
        return xs