|
|
|
import math |
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
import torch.nn.functional as F |
|
|
|
|
|
def dynamic_batch_collate(batch):
    """
    Collate a list of TTS samples into zero-padded batch tensors.

    Samples are ordered by decreasing mel length and every variable-length
    field is padded with zeros up to the longest entry in the batch.

    Args:
        batch: A non-empty list of dictionaries, each containing the keys
            'id', 'phoneme_seq_encoded' (tensor whose first dimension is the
            sequence length), 'mel_spec' (tensor indexed as
            (num_mel_bins, frames)), 'mel_lengths' (number of frames of that
            sample) and 'stop_token_targets' (tensor whose first dimension is
            the number of frames). The list itself is not modified.

    Returns:
        A tuple (ids, phoneme_seq_padded, mel_specs_padded, mel_lengths,
        stop_token_targets_padded):
            ids: list of the 'id' values, longest mel first.
            phoneme_seq_padded: padded phoneme tensor, batch first.
            mel_specs_padded: tensor of shape (batch, num_mel_bins, max_len).
            mel_lengths: 1-D tensor with the 'mel_lengths' values.
            stop_token_targets_padded: tensor of shape (batch, max_len).
        All tensors are placed on the CUDA device when one is available,
        otherwise on the CPU.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("dynamic_batch_collate() received an empty batch")

    # sorted() keeps the caller's list untouched; the sort is stable, so
    # samples of equal length keep their relative order.
    ordered = sorted(batch, key=lambda item: item['mel_lengths'], reverse=True)

    # NOTE(review): creating CUDA tensors inside a collate function is
    # presumably only safe with DataLoader(num_workers=0) -- confirm against
    # how the DataLoader is configured by the caller.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ids = [item['id'] for item in ordered]
    phoneme_seqs = [item['phoneme_seq_encoded'] for item in ordered]
    mel_specs = [item['mel_spec'] for item in ordered]
    stop_token_targets = [item['stop_token_targets'] for item in ordered]
    mel_lengths = torch.tensor([item['mel_lengths'] for item in ordered], device=device)

    phoneme_seq_padded = torch.nn.utils.rnn.pad_sequence(
        phoneme_seqs, batch_first=True, padding_value=0
    ).to(device)

    # Tensor.max() reduces on-device instead of iterating the tensor
    # element by element through Python's built-in max().
    max_len = int(mel_lengths.max().item())

    # The number of mel bins is taken from the data rather than fixed at 80,
    # so spectrograms with any bin count can be collated.
    num_mel_bins = mel_specs[0].shape[0]

    mel_specs_padded = torch.zeros((len(mel_specs), num_mel_bins, max_len), device=device)
    for i, mel in enumerate(mel_specs):
        mel_len = mel.shape[1]
        mel_specs_padded[i, :, :mel_len] = mel.to(device)

    stop_token_targets_padded = torch.zeros((len(stop_token_targets), max_len), device=device)
    for i, stop in enumerate(stop_token_targets):
        stop_len = stop.size(0)
        stop_token_targets_padded[i, :stop_len] = stop.to(device)

    return ids, phoneme_seq_padded, mel_specs_padded, mel_lengths, stop_token_targets_padded
|
|
|
|
|
class EncoderPrenet(torch.nn.Module):
    """
    Module for the encoder prenet in the Transformer-based TTS system.

    A stack of ``num_layers`` blocks, each made of a 1-D convolution
    (kernel size 3, padding 1, so the sequence length is preserved), batch
    normalization, ReLU and dropout, followed by a linear projection.

    Parameters:
        input_dim (int): Dimension of the input features. Defaults to 512.
        hidden_dim (int): Dimension of the hidden layers and of the output.
            Defaults to 512.
        num_layers (int): Number of convolutional blocks. Defaults to 3.
        dropout (float): Dropout probability. Defaults to 0.2.

    Inputs:
        x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim).

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, seq_len, hidden_dim).
    """

    def __init__(self, input_dim=512, hidden_dim=512, num_layers=3, dropout=0.2):
        super().__init__()

        conv_layers = []
        # Only the first convolution sees the raw input width; every later
        # layer maps hidden_dim -> hidden_dim. Previously input_dim was
        # ignored and the first layer also expected hidden_dim channels.
        in_channels = input_dim
        for _ in range(num_layers):
            conv_layers.append(nn.Conv1d(in_channels, hidden_dim, kernel_size=3, padding=1))
            conv_layers.append(nn.BatchNorm1d(hidden_dim))
            conv_layers.append(nn.ReLU())
            conv_layers.append(nn.Dropout(dropout))
            in_channels = hidden_dim
        self.conv_layers = nn.Sequential(*conv_layers)

        # in_channels is still input_dim when num_layers == 0, so the
        # projection then consumes the raw features directly.
        self.projection = nn.Linear(in_channels, hidden_dim)

    def forward(self, x):
        """Apply the convolutional stack and the final projection to ``x``."""
        # Conv1d / BatchNorm1d expect channels in dim 1: (batch, channels, seq_len).
        x = x.transpose(1, 2)
        x = self.conv_layers(x)
        # Back to (batch, seq_len, channels) for the per-position linear layer.
        x = x.transpose(1, 2)
        x = self.projection(x)
        return x
|
|
|
|
|
class DecoderPrenet(torch.nn.Module):
    """
    Module for the decoder prenet in the Transformer-based TTS system.

    Two fully connected layers, each followed by a ReLU activation, and a
    final linear projection to the desired output dimension.

    Parameters:
        input_dim (int): Dimension of the input features. Defaults to 80.
        hidden_dim (int): Dimension of the hidden layers. Defaults to 256.
        output_dim (int): Dimension of the output features. Defaults to 512.

    Inputs:
        x (torch.Tensor): Input tensor of shape (batch_size, input_dim, seq_len).
            The feature axis comes before the time axis; ``forward`` swaps
            dims 1 and 2 before applying the linear layers.

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, seq_len, output_dim).
    """

    def __init__(self, input_dim=80, hidden_dim=256, output_dim=512):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.projection = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map a (batch, input_dim, seq_len) tensor to (batch, seq_len, output_dim)."""
        # nn.Linear acts on the last dimension, so move the features there.
        frames = x.transpose(1, 2)
        hidden = F.relu(self.fc1(frames))
        hidden = F.relu(self.fc2(hidden))
        return self.projection(hidden)
|
|
|
|
|
class ScaledPositionalEncoding(nn.Module):
    """
    Module for adding scaled positional encoding to input sequences.

    A sinusoidal table is precomputed for ``max_len`` positions and
    multiplied by a single learnable scalar before being added to the input.

    Parameters:
        d_model (int): Dimensionality of the model. It must match the
            embedding dimension of the input.
        max_len (int): Maximum length of the input sequence. Defaults to 5000.

    Inputs:
        x (torch.Tensor): Input tensor of shape (batch_size, seq_len, embedding_dim).

    Returns:
        torch.Tensor: Output tensor with scaled positional encoding added,
        shape (batch_size, seq_len, embedding_dim).
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.d_model = d_model

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        # The buffer keeps its (max_len, 1, d_model) layout so that state
        # dicts saved with this module continue to load.
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', pe)
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """
        Adds scaled positional encoding to input tensor x.

        Args:
            x: Tensor of shape [batch_size, seq_len, embedding_dim]

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If seq_len exceeds the precomputed table length.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        # Select one encoding per time step (dim 1 of x) and reshape to
        # (1, seq_len, d_model) so it broadcasts across the batch. Slicing by
        # x.size(0) would pick encodings by batch index instead of position.
        scaled_pe = self.pe[:seq_len].transpose(0, 1) * self.scale
        x = x + scaled_pe
        return x
|
|
|
|
|
class PostNet(torch.nn.Module):
    """
    Post-processing network for mel-spectrogram enhancement.

    A chain of 1-D convolution stages, each followed by batch normalization.
    Every stage except the last one also applies a ReLU. The first stage maps
    ``mel_channels`` to ``postnet_channels``, the last stage maps back to
    ``mel_channels``, and ``postnet_layers - 2`` stages in between keep
    ``postnet_channels``.

    Parameters:
        mel_channels (int): Number of mel channels in the input mel-spectrogram.
        postnet_channels (int): Number of channels in the postnet layers.
        kernel_size (int): Size of the convolutional kernel; each convolution
            is padded with ``kernel_size // 2`` on both sides.
        postnet_layers (int): Number of postnet layers.

    Inputs:
        x (torch.Tensor): Input tensor of shape (batch_size, seq_len, mel_channels).

    Returns:
        torch.Tensor: Output tensor with refined mel-spectrogram,
        shape (batch_size, seq_len, mel_channels).
    """

    def __init__(self, mel_channels, postnet_channels, kernel_size, postnet_layers):
        super().__init__()
        same_pad = kernel_size // 2

        def conv_bn(in_ch, out_ch):
            # Convolution + batch norm pair shared by every stage.
            return [
                nn.Conv1d(in_ch, out_ch, kernel_size, padding=same_pad),
                nn.BatchNorm1d(out_ch),
            ]

        # Entry stage: mel_channels -> postnet_channels.
        stages = [nn.Sequential(*conv_bn(mel_channels, postnet_channels), nn.ReLU())]

        # Middle stages: postnet_channels -> postnet_channels.
        stages.extend(
            nn.Sequential(*conv_bn(postnet_channels, postnet_channels), nn.ReLU())
            for _ in range(postnet_layers - 2)
        )

        # Exit stage: back to mel_channels, without an activation.
        stages.append(nn.Sequential(*conv_bn(postnet_channels, mel_channels)))

        self.conv_layers = nn.ModuleList(stages)

    def forward(self, x):
        """Run ``x`` through every stage and return it in its original layout."""
        # Conv1d expects (batch, channels, seq_len).
        feats = x.transpose(1, 2)
        for stage in self.conv_layers:
            feats = stage(feats)
        return feats.transpose(1, 2)
|
|