import math
import random
from typing import Optional, Any

import torch
from torch import nn, Tensor
import torch.nn.functional as F
import torchaudio
import torchaudio.functional as audio_F

random.seed(0)


def _get_activation_fn(activ):
    if activ == 'relu':
        return nn.ReLU()
    elif activ == 'lrelu':
        return nn.LeakyReLU(0.2)
    elif activ == 'swish':
        # nn.SiLU computes x * sigmoid(x) (i.e. swish) and, unlike a lambda, is an
        # nn.Module, so it can be placed inside the nn.Sequential blocks below.
        return nn.SiLU()
    else:
        raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal


class CausualConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1,
                 dilation=1, bias=True, w_init_gain='linear', param=None):
        super(CausualConv, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            # Pad on both sides, then trim the right side in forward() to stay causal.
            self.padding = int(dilation * (kernel_size - 1) / 2) * 2
        else:
            self.padding = padding * 2
        self.conv = nn.Conv1d(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride,
                              padding=self.padding,
                              dilation=dilation,
                              bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, x):
        x = self.conv(x)
        if self.padding > 0:
            # Drop the trailing frames introduced by the symmetric padding so each
            # output step only depends on current and past inputs.
            x = x[:, :, :-self.padding]
        return x
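

# Illustrative sketch (not part of the original module): with kernel_size=3 and
# padding equal to the dilation, CausualConv pads by 2*dilation, convolves, and
# then trims 2*dilation trailing frames, so the output length matches the input.
#
#     conv = CausualConv(64, 64, kernel_size=3, padding=2, dilation=2)
#     y = conv(torch.randn(8, 64, 100))   # -> torch.Size([8, 64, 100])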


class CausualBlock(nn.Module):
    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
        super(CausualBlock, self).__init__()
        self.blocks = nn.ModuleList([
            self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
            for i in range(n_conv)])

    def forward(self, x):
        for block in self.blocks:
            res = x
            x = block(x)
            x += res
        return x

    def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
        layers = [
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
            _get_activation_fn(activ),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=dropout_p),
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p)
        ]
        return nn.Sequential(*layers)


class ConvBlock(nn.Module):
    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
        super().__init__()
        self._n_groups = 8
        self.blocks = nn.ModuleList([
            self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
            for i in range(n_conv)])

    def forward(self, x):
        for block in self.blocks:
            res = x
            x = block(x)
            x += res
        return x

    def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
        layers = [
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
            _get_activation_fn(activ),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p)
        ]
        return nn.Sequential(*layers)
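

# Illustrative sketch (not part of the original module): ConvBlock operates on
# (batch, hidden_dim, time) tensors, and because each sub-block uses
# nn.GroupNorm(num_groups=8, ...), hidden_dim must be divisible by 8.
#
#     block = ConvBlock(hidden_dim=256, n_conv=3, dropout_p=0.2, activ='relu')
#     y = block(torch.randn(4, 256, 120))   # -> torch.Size([4, 256, 120])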


class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1,
                                      dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights
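

# Illustrative usage sketch (shapes only, not part of the original module):
#
#     attn = Attention(attention_rnn_dim=1024, embedding_dim=512, attention_dim=128,
#                      attention_location_n_filters=32, attention_location_kernel_size=31)
#     B, T = 4, 150
#     query = torch.randn(B, 1024)                   # attention RNN hidden state
#     memory = torch.randn(B, T, 512)                # encoder outputs
#     processed_memory = attn.memory_layer(memory)   # (B, T, 128)
#     weights_cat = torch.zeros(B, 2, T)             # prev. + cumulative weights
#     mask = torch.zeros(B, T, dtype=torch.bool)     # True at padded positions
#     context, weights = attn(query, memory, processed_memory, weights_cat, mask)
#     # context: (B, 512), weights: (B, T)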


class ForwardAttentionV2(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(ForwardAttentionV2, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float(1e20)

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask, log_alpha):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        log_alpha: forward-attention log weights from the previous step (B, max_time)
        """
        log_energy = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            log_energy.data.masked_fill_(mask, self.score_mask_value)

        # Combine the previous log_alpha with a copy shifted one step to the right
        # (padded with the mask value), so probability mass can either stay on the
        # same encoder step or advance by one.
        log_alpha_shift_padded = []
        max_time = log_energy.size(1)
        for sft in range(2):
            shifted = log_alpha[:, :max_time - sft]
            shift_padded = F.pad(shifted, (sft, 0), 'constant', self.score_mask_value)
            log_alpha_shift_padded.append(shift_padded.unsqueeze(2))

        biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)

        log_alpha_new = biased + log_energy

        attention_weights = F.softmax(log_alpha_new, dim=1)

        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights, log_alpha_new
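

# Illustrative note (not part of the original module): unlike Attention above,
# forward() also takes and returns log_alpha of shape (B, max_time); the
# stay-or-advance recursion biases the alignment toward monotonic left-to-right
# movement. A common initialization (assumed here, not defined in this file) puts
# almost all probability mass on the first encoder step, e.g.
#
#     log_alpha = torch.full((B, T), -1e4)
#     log_alpha[:, 0] = 0.0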


class PhaseShuffle2d(nn.Module):
    def __init__(self, n=2):
        super(PhaseShuffle2d, self).__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # Circularly shift the last (time) axis by `move` samples, drawn uniformly
        # from [-n, n] when not given explicitly.
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x
        else:
            left = x[:, :, :, :move]
            right = x[:, :, :, move:]
            shuffled = torch.cat([right, left], dim=3)
            return shuffled


class PhaseShuffle1d(nn.Module):
    def __init__(self, n=2):
        super(PhaseShuffle1d, self).__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x
        else:
            left = x[:, :, :move]
            right = x[:, :, move:]
            shuffled = torch.cat([right, left], dim=2)
            return shuffled
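

# Illustrative sketch (not part of the original module): PhaseShuffle1d circularly
# shifts the time axis by a random offset in [-n, n] (here forced with move=1).
#
#     ps = PhaseShuffle1d(n=2)
#     x = torch.arange(6.).view(1, 1, 6)   # [[[0, 1, 2, 3, 4, 5]]]
#     ps(x, move=1)                        # [[[1, 2, 3, 4, 5, 0]]]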


class MFCC(nn.Module):
    def __init__(self, n_mfcc=40, n_mels=80):
        super(MFCC, self).__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = 'ortho'
        dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
        self.register_buffer('dct_mat', dct_mat)

    def forward(self, mel_specgram):
        if len(mel_specgram.shape) == 2:
            mel_specgram = mel_specgram.unsqueeze(0)
            unsqueezed = True
        else:
            unsqueezed = False

        # (B, n_mels, T) -> (B, T, n_mels) @ (n_mels, n_mfcc) -> (B, n_mfcc, T)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        if unsqueezed:
            mfcc = mfcc.squeeze(0)
        return mfcc
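

# Illustrative usage sketch (not part of the original module): the module applies an
# orthonormal DCT to the mel axis, so an (n_mels, time) or (batch, n_mels, time)
# mel spectrogram maps to (n_mfcc, time) or (batch, n_mfcc, time) MFCCs.
#
#     mfcc = MFCC(n_mfcc=40, n_mels=80)
#     mfcc(torch.randn(80, 200)).shape      # torch.Size([40, 200])
#     mfcc(torch.randn(4, 80, 200)).shape   # torch.Size([4, 40, 200])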