|
import math |
|
import torch |
|
from torch import nn |
|
from torch.nn import TransformerEncoder |
|
import torch.nn.functional as F |
|
from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock |
|
|
|
class ASRCNN(nn.Module):
    """Convolutional acoustic encoder with a CTC head and an attention s2s head.

    Pipeline: ``MFCC`` -> strided ``ConvNorm`` (stride=2) -> ``n_layers`` stages
    of ``ConvBlock`` + ``GroupNorm`` -> channel-halving ``ConvNorm`` projection.
    The projected features feed a frame-level CTC classifier and, when text is
    supplied, an ``ASRS2S`` attention decoder.
    """

    def __init__(self,
                 input_dim=80,
                 hidden_dim=256,
                 n_token=35,
                 n_layers=6,
                 token_embedding_dim=256,
                 ):
        """
        Args:
            input_dim (int): feature size; the first conv takes ``input_dim // 2``
                channels (assumes ``MFCC()`` emits that many channels — confirm
                in ``layers``).
            hidden_dim (int): channel width of the convolutional stack.
            n_token (int): vocabulary size shared by the CTC and s2s heads.
            n_layers (int): number of ConvBlock + GroupNorm stages.
            token_embedding_dim (int): token embedding size of the s2s decoder.
        """
        super().__init__()
        self.n_token = n_token
        # init_cnn below uses stride=2; presumably n_down records that single
        # 2x downsampling for callers — confirm against users of this attribute.
        self.n_down = 1
        self.to_mfcc = MFCC()
        self.init_cnn = ConvNorm(input_dim // 2, hidden_dim, kernel_size=7, padding=3, stride=2)
        self.cnns = nn.Sequential(
            *[nn.Sequential(
                ConvBlock(hidden_dim),
                nn.GroupNorm(num_groups=1, num_channels=hidden_dim),
            ) for _ in range(n_layers)])
        self.projection = ConvNorm(hidden_dim, hidden_dim // 2)
        self.ctc_linear = nn.Sequential(
            LinearNorm(hidden_dim // 2, hidden_dim),
            nn.ReLU(),
            LinearNorm(hidden_dim, n_token))
        # The s2s decoder attends over the projected features, so its hidden
        # size is set to the projection's hidden_dim // 2 output channels.
        self.asr_s2s = ASRS2S(
            embedding_dim=token_embedding_dim,
            hidden_dim=hidden_dim // 2,
            n_token=n_token)

    def forward(self, x, src_key_padding_mask=None, text_input=None):
        """Encode ``x`` and return CTC logits, plus s2s outputs if text is given.

        Args:
            x: input passed directly to ``MFCC`` (expected shape depends on
                that layer).
            src_key_padding_mask: memory mask forwarded to the s2s attention; it
                presumably must match the encoder's downsampled time axis — verify.
            text_input: optional (B, T) token ids for teacher-forced decoding.

        Returns:
            ``ctc_logit`` with ``n_token`` scores on its last dim when
            ``text_input`` is None, otherwise ``(ctc_logit, s2s_logit, s2s_attn)``.
        """
        x = self.to_mfcc(x)
        x = self.init_cnn(x)
        x = self.cnns(x)
        x = self.projection(x)
        # Swap dims 1 and 2 so the linear CTC head acts on the channel axis
        # (assumes the conv stack is channels-first).
        x = x.transpose(1, 2)
        ctc_logit = self.ctc_linear(x)
        if text_input is None:
            return ctc_logit
        _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
        return ctc_logit, s2s_logit, s2s_attn

    def get_feature(self, x):
        """Return projected encoder features without the CTC/s2s heads.

        Unlike ``forward``, dim 1 of ``x`` is squeezed before ``MFCC`` and the
        result is returned without transposing (channels-first conv output).
        """
        x = self.to_mfcc(x.squeeze(1))
        x = self.init_cnn(x)
        x = self.cnns(x)
        x = self.projection(x)
        return x

    def length_to_mask(self, lengths):
        """Build a padding mask from per-sequence lengths.

        Args:
            lengths (torch.Tensor): 1-D tensor of sequence lengths.

        Returns:
            torch.BoolTensor of shape (len(lengths), lengths.max()) on
            ``lengths.device``; ``mask[b, t]`` is True iff ``t + 1 > lengths[b]``,
            i.e. True marks padded positions.
        """
        # Create the positions directly on the target device instead of
        # building on the CPU and copying.
        positions = torch.arange(lengths.max(), device=lengths.device).type_as(lengths)
        # Broadcasting (1, L) against (B, 1) yields the (B, L) mask.
        return torch.gt(positions.unsqueeze(0) + 1, lengths.unsqueeze(1))

    def get_future_mask(self, out_length, unmask_future_steps=0):
        """Build a square mask that hides future time steps.

        Args:
            out_length (int): returned mask shape is (out_length, out_length).
            unmask_future_steps (int): number of future steps left unmasked.

        Returns:
            torch.BoolTensor (created on the CPU) with
            ``mask[i, j] = True if j > i + unmask_future_steps else False``,
            i.e. position j lies more than ``unmask_future_steps`` steps after i.
        """
        # index_tensor[i, j] == j, and index_tensor.T[i, j] == i.
        index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
        return torch.gt(index_tensor, index_tensor.T + unmask_future_steps)
|
|
|
class ASRS2S(nn.Module):
    """Attention LSTM decoder over encoder features, trained with teacher forcing.

    At each step the decoder consumes the previous token embedding plus the
    previous attention context. The attention layer is conditioned on both the
    previous and the cumulative attention weights.
    """

    def __init__(self,
                 embedding_dim=256,
                 hidden_dim=512,
                 n_location_filters=32,
                 location_kernel_size=63,
                 n_token=40,
                 unk_index=3,
                 random_mask=0.1):
        """
        Args:
            embedding_dim (int): token embedding size.
            hidden_dim (int): decoder LSTM / attention hidden size.
            n_location_filters (int): forwarded to ``Attention``.
            location_kernel_size (int): forwarded to ``Attention``.
            n_token (int): vocabulary size.
            unk_index (int): token id substituted for randomly masked
                teacher-forcing inputs during training.
            random_mask (float): per-token probability of that substitution
                (training mode only).
        """
        super().__init__()
        self.embedding = nn.Embedding(n_token, embedding_dim)
        # Uniform init whose range is scaled by hidden_dim (not embedding_dim).
        val_range = math.sqrt(6 / hidden_dim)
        self.embedding.weight.data.uniform_(-val_range, val_range)

        self.decoder_rnn_dim = hidden_dim
        self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
        self.attention_layer = Attention(
            self.decoder_rnn_dim,
            hidden_dim,
            hidden_dim,
            n_location_filters,
            location_kernel_size
        )
        # Input is [token embedding ; attention context]; the context width is
        # assumed to equal hidden_dim (the memory's last dim).
        self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim)
        self.project_to_hidden = nn.Sequential(
            LinearNorm(self.decoder_rnn_dim * 2, hidden_dim),
            nn.Tanh())
        # Start/end token ids; only sos is used within this class.
        self.sos = 1
        self.eos = 2
        # Token-dropout settings for teacher forcing, fixed at construction.
        self.unk_index = unk_index
        self.random_mask = random_mask

    def initialize_decoder_states(self, memory, mask):
        """Reset recurrent and attention state before decoding a batch.

        Args:
            memory: encoder outputs, memory.shape = (B, L, H) =
                (batch size, max time steps, hidden dim).
            mask: memory mask, stored and passed to the attention layer at
                every step.
        """
        B, L, H = memory.shape
        # new_zeros matches memory's dtype and device without a host copy.
        self.decoder_hidden = memory.new_zeros((B, self.decoder_rnn_dim))
        self.decoder_cell = memory.new_zeros((B, self.decoder_rnn_dim))
        self.attention_weights = memory.new_zeros((B, L))
        self.attention_weights_cum = memory.new_zeros((B, L))
        self.attention_context = memory.new_zeros((B, H))
        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def _mask_tokens(self, text_input):
        """Randomly replace tokens with ``unk_index`` while training.

        Each token is replaced independently with probability ``random_mask``.
        In eval mode (or when ``random_mask <= 0``) the input is returned as-is,
        so decoding is deterministic. ``text_input`` is never modified in place.
        """
        if not self.training or self.random_mask <= 0:
            return text_input
        drop = torch.rand(text_input.shape, device=text_input.device) < self.random_mask
        return text_input.masked_fill(drop, self.unk_index)

    def forward(self, memory, memory_mask, text_input):
        """Teacher-forced decoding of ``text_input`` over ``memory``.

        Args:
            memory: memory.shape = (B, L, H) = (batch size, max time steps, hidden dim).
            memory_mask: memory_mask.shape = (B, L).
            text_input: text_input.shape = (B, T) token ids.

        Returns:
            (hidden_outputs, logit_outputs, alignments), each batch-first with
            T + 1 decoder steps (an SOS step is prepended to the inputs):
            hidden (B, T+1, hidden_dim), logits (B, T+1, n_token), and the
            per-step attention weights.
        """
        self.initialize_decoder_states(memory, memory_mask)

        # (B, T) -> (B, T, E) -> time-major (T, B, E).
        decoder_inputs = self.embedding(self._mask_tokens(text_input)).transpose(0, 1)
        start_tokens = torch.full(
            (decoder_inputs.size(1),), self.sos,
            dtype=torch.long, device=decoder_inputs.device)
        start_embedding = self.embedding(start_tokens)
        decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)

        hidden_outputs, logit_outputs, alignments = [], [], []
        # Iterating a tensor yields its dim-0 slices: one (B, E) input per step.
        for decoder_input in decoder_inputs:
            hidden, logit, attention_weights = self.decode(decoder_input)
            hidden_outputs.append(hidden)
            logit_outputs.append(logit)
            alignments.append(attention_weights)

        return self.parse_decoder_outputs(hidden_outputs, logit_outputs, alignments)

    def decode(self, decoder_input):
        """Run one decoder step, updating the stored recurrent/attention state.

        Args:
            decoder_input: (B, embedding_dim) embedding of the previous token.

        Returns:
            hidden: ``project_to_hidden`` output (before dropout).
            logit: (B, n_token) token logits.
            attention_weights: this step's attention weights from the
                attention layer (presumably (B, L)).
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            cell_input,
            (self.decoder_hidden, self.decoder_cell))

        # Attention is conditioned on the previous and cumulative weights,
        # stacked as two channels: (B, 2, L).
        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)

        self.attention_context, self.attention_weights = self.attention_layer(
            self.decoder_hidden,
            self.memory,
            self.processed_memory,
            attention_weights_cat,
            self.mask)

        self.attention_weights_cum += self.attention_weights

        hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
        hidden = self.project_to_hidden(hidden_and_context)

        # Dropout (p=0.5) only in training mode, applied to the logit path only.
        logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))

        return hidden, logit, self.attention_weights

    def parse_decoder_outputs(self, hidden, logit, alignments):
        """Stack per-step lists into batch-first tensors.

        Each argument is a list (one entry per step) of tensors with batch on
        dim 0; the result has shape (B, steps, ...).
        """
        alignments = torch.stack(alignments).transpose(0, 1)
        logit = torch.stack(logit).transpose(0, 1).contiguous()
        hidden = torch.stack(hidden).transpose(0, 1).contiguous()
        return hidden, logit, alignments
|
|