|
|
|
import re |
|
import logging |
|
import torch |
|
import torchaudio |
|
import random |
|
import speechbrain as sb |
|
import torch as nn |
|
from speechbrain.utils.fetching import fetch |
|
from speechbrain.inference.interfaces import Pretrained |
|
from speechbrain.inference.text import GraphemeToPhoneme |
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
class TTSModel(Pretrained):
    """A ready-to-use wrapper for Transformer TTS (text -> mel_spec).

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)
    """

    HPARAMS_NEEDED = [
        "model",
        "blank_index",
        "padding_mask",
        "lookahead_mask",
        "mel_spec_feats",
        "label_encoder",
    ]
    # NOTE(review): encode_batch also reads ``hparams.stop_threshold``, which
    # is not listed above -- confirm every hparams file defines it.
    MODULES_NEEDED = ["modules"]

    # Abbreviations that are spelled out before the text reaches the G2P
    # model. They are applied in this order with plain, case-sensitive
    # substring replacement.
    _ABBREVIATION_EXPANSIONS = {
        "Mr.": "Mister",
        "Mrs.": "Misess",
        "Dr.": "Doctor",
        "No.": "Number",
        "St.": "Saint",
        "Co.": "Company",
        "Jr.": "Junior",
        "Maj.": "Major",
        "Gen.": "General",
        "Drs.": "Doctors",
        "Rev.": "Reverend",
        "Lt.": "Lieutenant",
        "Hon.": "Honorable",
        "Sgt.": "Sergeant",
        "Capt.": "Captain",
        "Esq.": "Esquire",
        "Ltd.": "Limited",
        "Col.": "Colonel",
        "Ft.": "Fort",
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.label_encoder = self.hparams.label_encoder
        # The grapheme-to-phoneme front end is loaded from a fixed source.
        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")

    def text_to_phoneme(self, text):
        """Converts a text into a tensor of encoded phoneme indices.

        Known abbreviations are expanded first, the text is then passed to the
        Grapheme-to-Phoneme (G2P) model and the resulting phonemes are mapped
        to indices with the label encoder.

        Arguments
        ---------
        text: str
            The input text.

        Returns
        -------
        phoneme_seq: torch.LongTensor
            The encoded phoneme indices.
        length: int
            The number of entries in ``phoneme_seq``.
        """
        for abbreviation, expansion in self._ABBREVIATION_EXPANSIONS.items():
            text = text.replace(abbreviation, expansion)

        phonemes = self.g2p(text)
        encoded = self.label_encoder.encode_sequence(phonemes)
        phoneme_seq = torch.LongTensor(encoded)

        return phoneme_seq, len(phoneme_seq)

    def encode_batch(self, texts, max_generation_length=1000):
        """Computes mel-spectrograms for a list of texts.

        NOTE(review): the previous docstring required texts sorted by
        decreasing length; nothing in this method relies on that order --
        confirm whether the requirement still applies.

        Arguments
        ---------
        texts: List[str]
            Texts to be encoded into spectrograms. A single string is
            treated as a batch of one text.
        max_generation_length: int
            Maximum number of decoding steps to run before giving up on the
            stop token (default: 1000).

        Returns
        -------
        torch.Tensor
            The initial all-zero decoder input followed by the refined output
            of every decoding step, concatenated along dimension 2.
        """
        # Iterating a bare string would process it character by character.
        if isinstance(texts, str):
            texts = [texts]

        with torch.no_grad():
            phoneme_seqs = [self.text_to_phoneme(text)[0] for text in texts]
            phoneme_seqs_padded, _ = self.pad_sequences(phoneme_seqs)
            # pad_sequences builds its tensor on the CPU; the decoder seed
            # below lives on self.device, so the inputs must match it.
            phoneme_seqs_padded = phoneme_seqs_padded.to(self.device)

            encoded_phoneme = self.mods.encoder_emb(phoneme_seqs_padded)
            encoder_emb = self.mods.enc_pre_net(encoded_phoneme)
            encoder_emb = encoder_emb + self.mods.pos_emb_enc(encoder_emb)

            src_len = encoder_emb.size(1)
            src_mask = torch.zeros(src_len, src_len, device=self.device)
            # NOTE(review): the mask is computed from the embeddings rather
            # than from the padded phoneme ids -- confirm this is intended.
            src_key_padding_mask = self.hparams.padding_mask(
                encoder_emb, self.hparams.blank_index
            )

            stop_threshold = self.hparams.stop_threshold

            # All-zero seed, one per text in the batch.
            # The second dimension (80) is hard-coded; presumably the number
            # of mel channels -- TODO confirm against the model config.
            decoder_input = torch.zeros(
                phoneme_seqs_padded.size(0), 80, 1, device=self.device
            )
            frames = [decoder_input]

            for _ in range(max_generation_length):
                encoded_mel = self.mods.dec_pre_net(decoder_input)
                decoder_emb = encoded_mel + self.mods.pos_emb_dec(encoded_mel)

                decoder_output = self.mods.Seq2SeqTransformer(
                    encoder_emb,
                    decoder_emb,
                    src_mask=src_mask,
                    src_key_padding_mask=src_key_padding_mask,
                )

                mel_output = self.mods.mel_lin(decoder_output)
                stop_token_logit = self.mods.stop_lin(decoder_output).squeeze(-1)

                # The postnet output is added to the linear projection as a
                # residual refinement.
                refined_mel_output = mel_output + self.mods.postnet(mel_output)
                refined_mel_output = refined_mel_output.transpose(1, 2)

                # The refined output replaces (does not extend) the decoder
                # input for the next step.
                decoder_input = refined_mel_output
                frames.append(decoder_input)

                # Generation ends as soon as any item in the batch predicts
                # a stop at its last position.
                stop_token_probs = torch.sigmoid(stop_token_logit)
                if torch.any(stop_token_probs[:, -1] >= stop_threshold):
                    break

            return torch.cat(frames, dim=2)

    def pad_sequences(self, sequences):
        """Pads sequences to the length of the longest sequence in the batch.

        Arguments
        ---------
        sequences: List[torch.Tensor]
            The sequences to pad.

        Returns
        -------
        padded_seqs: torch.Tensor
            ``torch.long`` tensor with one row per sequence, padded on the
            right with zeros.
        lengths: torch.Tensor
            The original length of every sequence.

        Raises
        ------
        ValueError
            If ``sequences`` is empty.
        """
        lengths = [len(seq) for seq in sequences]
        if not lengths:
            raise ValueError("pad_sequences() requires at least one sequence")

        # NOTE(review): padding uses index 0 -- confirm this matches
        # hparams.blank_index.
        padded_seqs = torch.zeros(len(lengths), max(lengths), dtype=torch.long)
        for row, (seq, length) in enumerate(zip(sequences, lengths)):
            padded_seqs[row, :length] = seq

        return padded_seqs, torch.tensor(lengths)

    def encode_text(self, text):
        """Runs inference for a single text str."""
        # Wrap the text so encode_batch sees a batch of one text instead of
        # iterating over the characters of the string.
        return self.encode_batch([text])

    def forward(self, texts):
        """Encodes the input texts."""
        return self.encode_batch(texts)
|
|