Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import numpy as np | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class ConvNorm(torch.nn.Module):
    """1-D convolution with Xavier-uniform weight initialisation.

    When ``padding`` is None, symmetric "same" padding is derived from
    ``kernel_size`` and ``dilation`` so that, for stride 1, the output length
    equals the input length.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        """Build the convolution.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size; must be odd when
                ``padding`` is None.
            stride: convolution stride.
            padding: explicit padding, or None for automatic "same" padding.
            dilation: convolution dilation.
            bias: whether the convolution has a bias term.
            w_init_gain: nonlinearity name passed to
                ``torch.nn.init.calculate_gain`` for the Xavier init.

        Raises:
            ValueError: if ``padding`` is None and ``kernel_size`` is even.
        """
        super(ConvNorm, self).__init__()
        if padding is None:
            # Explicit check rather than ``assert`` so it survives ``python -O``;
            # symmetric "same" padding only exists for odd kernel sizes.
            if kernel_size % 2 != 1:
                raise ValueError(
                    f"kernel_size must be odd when padding is None, "
                    f"got {kernel_size}")
            # Integer arithmetic; exact because (kernel_size - 1) is even.
            padding = dilation * (kernel_size - 1) // 2

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        """Apply the convolution to a (batch, in_channels, time) tensor."""
        return self.conv(signal)
class Invertible1x1ConvLUS(torch.nn.Module):
    """Invertible 1x1 convolution whose c x c weight is LU-parameterised.

    The weight is assembled as W = P @ L @ U, where:

    * P is a fixed permutation buffer;
    * L is lower-triangular with a unit diagonal (``lower_diag`` is a buffer
      of ones);
    * U is upper-triangular with learnable diagonal ``upper_diag``.

    Hence log|det W| = sum(log|upper_diag|).
    """

    def __init__(self, c):
        """Initialise W from a random orthonormal matrix with determinant +1.

        Args:
            c: number of channels (the weight is c x c).
        """
        super(Invertible1x1ConvLUS, self).__init__()
        # Sample a random orthonormal matrix to initialize weights
        W, _ = torch.linalg.qr(torch.randn(c, c))
        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        # Factor W = P @ L @ U. torch.lu is deprecated, so use torch.linalg.lu
        # when the installed PyTorch provides it, else fall back.
        if hasattr(torch.linalg, "lu"):
            p, lower, upper = torch.linalg.lu(W)
        else:
            p, lower, upper = torch.lu_unpack(*torch.lu(W))
        self.register_buffer('p', p)

        # diagonals of lower will always be 1s anyway
        lower = torch.tril(lower, -1)
        lower_diag = torch.diag(torch.eye(c, c))
        self.register_buffer('lower_diag', lower_diag)

        self.lower = nn.Parameter(lower)
        self.upper_diag = nn.Parameter(torch.diag(upper))
        self.upper = nn.Parameter(torch.triu(upper, 1))

    def forward(self, z, reverse=False):
        """Apply W (forward) or its inverse (reverse) as a 1x1 convolution.

        Args:
            z: input with c channels in dim 1 (e.g. B x c x T).
            reverse: if True, apply W^-1 instead of W.

        Returns:
            ``(z_out, log_det_W)`` in the forward direction, where
            ``log_det_W`` is the scalar log|det W|; only ``z_out`` when
            ``reverse`` is True.
        """
        U = torch.triu(self.upper, 1) + torch.diag(self.upper_diag)
        L = torch.tril(self.lower, -1) + torch.diag(self.lower_diag)
        W = torch.mm(self.p, torch.mm(L, U))

        if reverse:
            # Recomputed on every call rather than cached on first use: a cached
            # inverse goes stale after parameter updates and keeps holding its
            # original autograd graph. The c x c inversion is cheap. Invert in
            # float32 for stability, then match the input dtype (half, float,
            # double, on any device).
            W_inverse = W.float().inverse().to(dtype=z.dtype)
            return F.conv1d(z, W_inverse[..., None], bias=None, stride=1,
                            padding=0)

        z = F.conv1d(z, W[..., None], bias=None, stride=1, padding=0)
        # L has a unit diagonal and |det P| = 1, so only diag(U) contributes.
        log_det_W = torch.sum(torch.log(torch.abs(self.upper_diag)))
        return z, log_det_W
class ConvAttention(torch.nn.Module):
    """Distance-based alignment attention between mel frames and text tokens.

    Text ``keys`` and mel ``queries`` are projected by small 1-D conv stacks.
    The score for every (query frame, key token) pair is a negative, scaled
    squared Euclidean distance between the projected vectors. Scores are
    normalised with a softmax over the key (T2) axis.
    """

    def __init__(self, n_mel_channels=80, n_speaker_dim=128,
                 n_text_channels=512, n_att_channels=80, temperature=1.0,
                 n_mel_convs=2, align_query_enc_type='3xconv',
                 use_query_proj=True):
        """Build the key and query projection networks.

        Args:
            n_mel_channels: channels of the query (mel) input.
            n_speaker_dim: accepted for API compatibility; not used here.
            n_text_channels: channels of the key (text) input.
            n_att_channels: channels of the shared attention space.
            temperature: stored on the module but not read by ``forward``,
                which uses a fixed 0.0005 scale.
            n_mel_convs: accepted for API compatibility; not used here.
            align_query_enc_type: ``"3xconv"`` (conv stack to
                ``n_att_channels``) or ``"inv_conv"`` (invertible 1x1 conv,
                which keeps ``n_mel_channels`` channels).
            use_query_proj: if False, ``forward`` compares raw queries.

        Raises:
            ValueError: if ``align_query_enc_type`` is not recognised.
        """
        super(ConvAttention, self).__init__()
        self.temperature = temperature
        self.att_scaling_factor = np.sqrt(n_att_channels)
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)
        # NOTE: this module is always replaced below. It is still constructed
        # because it draws random numbers, and removing it would change the
        # seeded initialisation of every module created after it.
        self.query_proj = Invertible1x1ConvLUS(n_mel_channels)
        # Not used by forward(); kept so the state_dict keys are unchanged.
        self.attn_proj = torch.nn.Conv2d(n_att_channels, 1, kernel_size=1)
        self.align_query_enc_type = align_query_enc_type
        self.use_query_proj = bool(use_query_proj)

        self.key_proj = nn.Sequential(
            ConvNorm(n_text_channels,
                     n_text_channels * 2,
                     kernel_size=3,
                     bias=True,
                     w_init_gain='relu'),
            torch.nn.ReLU(),
            ConvNorm(n_text_channels * 2,
                     n_att_channels,
                     kernel_size=1,
                     bias=True))

        if align_query_enc_type == "inv_conv":
            self.query_proj = Invertible1x1ConvLUS(n_mel_channels)
        elif align_query_enc_type == "3xconv":
            self.query_proj = nn.Sequential(
                ConvNorm(n_mel_channels,
                         n_mel_channels * 2,
                         kernel_size=3,
                         bias=True,
                         w_init_gain='relu'),
                torch.nn.ReLU(),
                ConvNorm(n_mel_channels * 2,
                         n_mel_channels,
                         kernel_size=1,
                         bias=True),
                torch.nn.ReLU(),
                ConvNorm(n_mel_channels,
                         n_att_channels,
                         kernel_size=1,
                         bias=True))
        else:
            raise ValueError("Unknown query encoder type specified")

    def run_padded_sequence(self, sorted_idx, unsort_idx, lens, padded_data,
                            recurrent_model):
        """Sort input data by the provided ordering, run the packed data
        through the recurrent model, and restore the original ordering.

        Args:
            sorted_idx (torch.tensor): 1D sorting index
            unsort_idx (torch.tensor): 1D unsorting index (inverse of
                sorted_idx)
            lens: lengths of input data (sorted in descending order), as a
                list or 1D tensor on any device
            padded_data (torch.tensor): input sequences (padded), batch in
                dim 1
            recurrent_model (nn.Module): recurrent model to run data through

        Returns:
            hidden_vectors (torch.tensor): outputs of the RNN, in the
            original, unsorted, ordering
        """
        # sort the data by decreasing length using provided index
        # we assume batch index is in dim=1
        padded_data = padded_data[:, sorted_idx]
        # pack_padded_sequence requires tensor lengths to be on the CPU.
        if torch.is_tensor(lens):
            lens = lens.cpu()
        padded_data = nn.utils.rnn.pack_padded_sequence(padded_data, lens)
        hidden_vectors = recurrent_model(padded_data)[0]
        hidden_vectors, _ = nn.utils.rnn.pad_packed_sequence(hidden_vectors)
        # unsort the results at dim=1 and return
        hidden_vectors = hidden_vectors[:, unsort_idx]
        return hidden_vectors

    def encode_query(self, query, query_lens):
        """Encode ``query`` with ``self.query_lstm`` using packed sequences.

        NOTE(review): this class never defines ``query_lstm``, so calling this
        method raises AttributeError unless that attribute is assigned
        elsewhere. ``forward`` does not call it.

        Args:
            query: tensor permuted from (batch, feature, seq_len) to
                (seq_len, batch, feature) before encoding.
            query_lens: 1D tensor of per-sequence lengths.

        Returns:
            Recurrent outputs permuted back to (batch, feature, seq_len), in
            the original batch order.
        """
        query = query.permute(2, 0, 1)  # seq_len, batch, feature dim
        lens, ids = torch.sort(query_lens, descending=True)
        # original_ids is the inverse permutation of ids.
        original_ids = [0] * lens.size(0)
        for i, sorted_pos in enumerate(ids.tolist()):
            original_ids[sorted_pos] = i

        query_encoded = self.run_padded_sequence(ids, original_ids, lens,
                                                 query, self.query_lstm)
        query_encoded = query_encoded.permute(1, 2, 0)
        return query_encoded

    def forward(self, queries, keys, query_lens, mask=None, key_lens=None,
                keys_encoded=None, attn_prior=None):
        """Attention mechanism for flowtron parallel.

        Unlike in Flowtron, there are no restrictions such as causality,
        since this is only needed during training.

        Args:
            queries (torch.tensor): B x C x T1 tensor
                (probably going to be mel data)
            keys (torch.tensor): B x C2 x T2 tensor (text data)
            query_lens: lengths for sorting the queries in descending order;
                not read by this method
            mask (torch.tensor): binary mask (bool or uint8) for variable
                length entries in the T2 domain; nonzero entries are
                excluded. Presumably B x T2 x 1 (it is permuted and unsqueezed
                to broadcast as B x 1 x 1 x T2) -- confirm with callers.
            key_lens, keys_encoded: accepted for API compatibility; unused.
            attn_prior (torch.tensor): optional prior, presumably
                B x T1 x T2, added to the log-softmaxed scores in log space.

        Returns:
            attn (torch.tensor): B x 1 x T1 x T2 attention weights; the final
                dim T2 sums to 1.
            attn_logprob (torch.tensor): B x 1 x T1 x T2 scores before
                masking and softmax.
        """
        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2

        # Beware: this only works because query_dim = attn_dim = n_mel_channels
        # when the queries are not projected to n_att_channels.
        if not self.use_query_proj:
            queries_enc = queries
        elif self.align_query_enc_type == "3xconv":
            queries_enc = self.query_proj(queries)
        else:
            # "inv_conv" returns (output, log|det W|); the log-det is unused.
            queries_enc, _ = self.query_proj(queries)

        # Simplistic isotropic-Gaussian attention (one Gaussian per phoneme):
        # squared distance between every query frame and every key token.
        # B x n_attn_dims x T1 x T2
        attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2
        # compute log likelihood from a gaussian (fixed 0.0005 scale)
        attn = -0.0005 * attn.sum(1, keepdim=True)
        if attn_prior is not None:
            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)

        attn_logprob = attn.clone()

        if mask is not None:
            # Out-of-place masked_fill (instead of ``attn.data.masked_fill_``)
            # keeps the op visible to autograd. ``.bool()`` lets the uint8
            # masks described above work, since masked_fill expects a bool
            # mask.
            attn = attn.masked_fill(
                mask.permute(0, 2, 1).unsqueeze(2).bool(), -float("inf"))

        attn = self.softmax(attn)  # Softmax along T2
        return attn, attn_logprob