Krisshvamsi committed
Commit: ce88638
Parent: 4d0d969

Upload 3 files

Files changed:
- TTSModel.py       +173 -0
- hyperparams.yaml  +173 -0
- label_encoder.txt +46 -0
TTSModel.py
ADDED
@@ -0,0 +1,173 @@
import logging

import torch
from speechbrain.inference.interfaces import Pretrained
from speechbrain.inference.text import GraphemeToPhoneme

logger = logging.getLogger(__name__)


class TTSModel(Pretrained):
    """A ready-to-use wrapper for Transformer TTS (text -> mel_spec).

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)
    """

    HPARAMS_NEEDED = ["model", "blank_index", "padding_mask", "lookahead_mask", "mel_spec_feats", "label_encoder"]
    MODULES_NEEDED = ["modules"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.label_encoder = self.hparams.label_encoder
        # self.label_encoder.update_from_iterable(self.hparams["lexicon"], sequence_input=False)
        # Grapheme-to-phoneme front end used to turn raw text into ARPAbet phonemes
        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")

    def text_to_phoneme(self, text):
        """Generates a phoneme sequence for the given text using a
        Grapheme-to-Phoneme (G2P) model.

        Arguments
        ---------
        text : str
            The input text.

        Returns
        -------
        The encoded phoneme sequence as a LongTensor, and its length.
        """
        abbreviation_expansions = {
            "Mr.": "Mister",
            "Mrs.": "Misess",
            "Dr.": "Doctor",
            "No.": "Number",
            "St.": "Saint",
            "Co.": "Company",
            "Jr.": "Junior",
            "Maj.": "Major",
            "Gen.": "General",
            "Drs.": "Doctors",
            "Rev.": "Reverend",
            "Lt.": "Lieutenant",
            "Hon.": "Honorable",
            "Sgt.": "Sergeant",
            "Capt.": "Captain",
            "Esq.": "Esquire",
            "Ltd.": "Limited",
            "Col.": "Colonel",
            "Ft.": "Fort",
        }

        # Expand abbreviations before running G2P
        for abbreviation, expansion in abbreviation_expansions.items():
            text = text.replace(abbreviation, expansion)

        phonemes = self.g2p(text)
        phonemes = self.label_encoder.encode_sequence(phonemes)
        phoneme_seq = torch.LongTensor(phonemes)

        return phoneme_seq, len(phoneme_seq)

    def encode_batch(self, texts):
        """Computes mel-spectrograms for a list of texts.

        Texts must be sorted in decreasing order of length.

        Arguments
        ---------
        texts : List[str]
            Texts to be encoded into spectrograms.

        Returns
        -------
        Tensor of output spectrograms.
        """
        with torch.no_grad():
            phoneme_seqs = [self.text_to_phoneme(text)[0] for text in texts]
            phoneme_seqs_padded, input_lengths = self.pad_sequences(phoneme_seqs)

            # Encoder: phoneme embedding + prenet + scaled positional encoding
            encoded_phoneme = self.mods.encoder_emb(phoneme_seqs_padded)
            encoder_emb = self.mods.enc_pre_net(encoded_phoneme)
            pos_emb_enc = self.mods.pos_emb_enc(encoder_emb)
            encoder_emb = encoder_emb + pos_emb_enc

            # Autoregressive decoding, starting from an all-zero mel frame
            stop_generated = False
            decoder_input = torch.zeros(1, 80, 1, device=self.device)
            stop_tokens_logits = []
            max_generation_length = 1000
            sequence_length = 0

            result = [decoder_input]

            src_mask = torch.zeros(encoder_emb.size(1), encoder_emb.size(1), device=self.device)
            src_key_padding_mask = self.hparams.padding_mask(encoder_emb, self.hparams.blank_index)

            while not stop_generated and sequence_length < max_generation_length:
                encoded_mel = self.mods.dec_pre_net(decoder_input)
                pos_emb_dec = self.mods.pos_emb_dec(encoded_mel)
                decoder_emb = encoded_mel + pos_emb_dec

                decoder_output = self.mods.Seq2SeqTransformer(
                    encoder_emb, decoder_emb, src_mask=src_mask,
                    src_key_padding_mask=src_key_padding_mask)

                mel_output = self.mods.mel_lin(decoder_output)
                stop_token_logit = self.mods.stop_lin(decoder_output).squeeze(-1)

                # Postnet adds a residual refinement on top of the raw mel prediction
                post_mel_outputs = self.mods.postnet(mel_output.to(self.device))
                refined_mel_output = mel_output + post_mel_outputs.to(self.device)
                refined_mel_output = refined_mel_output.transpose(1, 2)

                stop_tokens_logits.append(stop_token_logit)
                stop_token_probs = torch.sigmoid(stop_token_logit)

                # Stop once the last frame's stop probability crosses the threshold
                if torch.any(stop_token_probs[:, -1] >= self.hparams.stop_threshold):
                    stop_generated = True

                decoder_input = refined_mel_output
                result.append(decoder_input)
                sequence_length += 1

            results = torch.cat(result, dim=2)
            stop_tokens_logits = torch.cat(stop_tokens_logits, dim=1)

        return results

    def pad_sequences(self, sequences):
        """Pads sequences to the maximum-length sequence in the batch.

        Arguments
        ---------
        sequences : List[torch.Tensor]
            The sequences to pad.

        Returns
        -------
        Padded sequences and original lengths.
        """
        max_length = max(len(seq) for seq in sequences)
        padded_seqs = torch.zeros(len(sequences), max_length, dtype=torch.long)
        lengths = []
        for i, seq in enumerate(sequences):
            length = len(seq)
            padded_seqs[i, :length] = seq
            lengths.append(length)
        return padded_seqs, torch.tensor(lengths)

    def encode_text(self, text):
        """Runs inference for a single text string."""
        return self.encode_batch([text])

    def forward(self, texts):
        """Encodes the input texts."""
        return self.encode_batch(texts)
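
For orientation, a minimal usage sketch of the interface above (a sketch, not part of the commit: it assumes this repo's hyperparams.yaml and model.ckpt are available locally, and that a models.py providing EncoderPrenet, DecoderPrenet, PostNet and ScaledPositionalEncoding is importable; the savedir name is an arbitrary choice):

    # Hypothetical usage sketch; source/savedir paths are assumptions.
    from TTSModel import TTSModel

    tts = TTSModel.from_hparams(
        source=".",                      # folder containing hyperparams.yaml + model.ckpt
        hparams_file="hyperparams.yaml",
        savedir="pretrained_tts",
    )

    # encode_text wraps the single string in a list and calls encode_batch
    mel = tts.encode_text("Dr. Smith lives near St. John street.")
    print(mel.shape)  # (1, 80, n_frames): batch x mel channels x decoded frames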
hyperparams.yaml
ADDED
@@ -0,0 +1,173 @@
################################
# Audio Parameters             #
################################
sample_rate: 22050
hop_length: 256
win_length: 1024
n_mel_channels: 80
n_fft: 1024
mel_fmin: 0.0
mel_fmax: 8000.0
power: 1
normalized: False
min_max_energy_norm: True
norm: "slaney"
mel_scale: "slaney"
dynamic_range_compression: True
mel_normalized: False
min_f0: 65 # (torchaudio pyin values)
max_f0: 2093 # (torchaudio pyin values)

positive_weight: 5.0
lexicon:
    - AA
    - AE
    - AH
    - AO
    - AW
    - AY
    - B
    - CH
    - D
    - DH
    - EH
    - ER
    - EY
    - F
    - G
    - HH
    - IH
    - IY
    - JH
    - K
    - L
    - M
    - N
    - NG
    - OW
    - OY
    - P
    - R
    - S
    - SH
    - T
    - TH
    - UH
    - UW
    - V
    - W
    - Y
    - Z
    - ZH
    - ' '
n_symbols: 42 # fixed by the symbols in the lexicon, +1 for a dummy symbol used for padding
padding_idx: 0

# Define model architecture
d_model: 512
nhead: 8
num_encoder_layers: 6
num_decoder_layers: 6
dim_feedforward: 2048
dropout: 0.2
blank_index: 0 # This special token is for padding
bos_index: 1
eos_index: 2
stop_weight: 0.45
stop_threshold: 0.5

################### PRENET #######################
enc_pre_net: !new:models.EncoderPrenet
dec_pre_net: !new:models.DecoderPrenet

encoder_emb: !new:torch.nn.Embedding
    num_embeddings: 128
    embedding_dim: !ref <d_model>
    padding_idx: !ref <blank_index>

pos_emb_enc: !new:models.ScaledPositionalEncoding
    d_model: !ref <d_model>

decoder_emb: !new:torch.nn.Embedding
    num_embeddings: 128
    embedding_dim: !ref <d_model>
    padding_idx: !ref <blank_index>

pos_emb_dec: !new:models.ScaledPositionalEncoding
    d_model: !ref <d_model>

Seq2SeqTransformer: !new:torch.nn.Transformer
    d_model: !ref <d_model>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_encoder_layers>
    num_decoder_layers: !ref <num_decoder_layers>
    dim_feedforward: !ref <dim_feedforward>
    dropout: !ref <dropout>
    batch_first: True

postnet: !new:models.PostNet
    mel_channels: !ref <n_mel_channels>
    postnet_channels: 512
    kernel_size: 5
    postnet_layers: 5

mel_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: !ref <n_mel_channels>

stop_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: 1

mel_spec_feats: !name:speechbrain.lobes.models.FastSpeech2.mel_spectogram
    sample_rate: !ref <sample_rate>
    hop_length: !ref <hop_length>
    win_length: !ref <win_length>
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mel_channels>
    f_min: !ref <mel_fmin>
    f_max: !ref <mel_fmax>
    power: !ref <power>
    normalized: !ref <normalized>
    min_max_energy_norm: !ref <min_max_energy_norm>
    norm: !ref <norm>
    mel_scale: !ref <mel_scale>
    compression: !ref <dynamic_range_compression>

modules:
    enc_pre_net: !ref <enc_pre_net>
    encoder_emb: !ref <encoder_emb>
    pos_emb_enc: !ref <pos_emb_enc>

    dec_pre_net: !ref <dec_pre_net>
    #decoder_emb: !ref <decoder_emb>
    pos_emb_dec: !ref <pos_emb_dec>

    Seq2SeqTransformer: !ref <Seq2SeqTransformer>
    postnet: !ref <postnet>
    mel_lin: !ref <mel_lin>
    stop_lin: !ref <stop_lin>
    model: !ref <model>

lookahead_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_lookahead_mask
padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask

model: !new:torch.nn.ModuleList
    - [!ref <enc_pre_net>, !ref <encoder_emb>, !ref <pos_emb_enc>, !ref <dec_pre_net>, !ref <pos_emb_dec>, !ref <Seq2SeqTransformer>, !ref <postnet>, !ref <mel_lin>, !ref <stop_lin>]

label_encoder: !new:speechbrain.dataio.encoder.TextEncoder

pretrained_path: /content/

pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    loadables:
        model: !ref <model>
        label_encoder: !ref <label_encoder>
    paths:
        model: !ref <pretrained_path>/model.ckpt
        label_encoder: !ref <pretrained_path>/label_encoder.txt
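
For reference, a minimal sketch of loading this file directly through HyperPyYAML (a sketch under assumptions: models.py with the referenced prenet/postnet classes must be on the Python path, and model.ckpt / label_encoder.txt must exist under pretrained_path):

    # Sketch only; load_hyperpyyaml instantiates each !new: object and resolves !ref links.
    from hyperpyyaml import load_hyperpyyaml

    with open("hyperparams.yaml") as f:
        hparams = load_hyperpyyaml(f)

    # The pretrainer fetches the listed files and copies the checkpoint
    # tensors into the freshly built modules.
    hparams["pretrainer"].collect_files()
    hparams["pretrainer"].load_collected()
    print(hparams["model"])  # torch.nn.ModuleList holding the nine submodules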
label_encoder.txt
ADDED
@@ -0,0 +1,46 @@
'AA' => 0
'AE' => 40
'AH' => 41
'AO' => 3
'AW' => 4
'AY' => 5
'B' => 6
'CH' => 7
'D' => 8
'DH' => 9
'EH' => 10
'ER' => 11
'EY' => 12
'F' => 13
'G' => 14
'HH' => 15
'IH' => 16
'IY' => 17
'JH' => 18
'K' => 19
'L' => 20
'M' => 21
'N' => 22
'NG' => 23
'OW' => 24
'OY' => 25
'P' => 26
'R' => 27
'S' => 28
'SH' => 29
'T' => 30
'TH' => 31
'UH' => 32
'UW' => 33
'V' => 34
'W' => 35
'Y' => 36
'Z' => 37
'ZH' => 38
' ' => 39
'<bos>' => 1
'<eos>' => 2
================
'starting_index' => 0
'bos_label' => '<bos>'
'eos_label' => '<eos>'
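
This is the on-disk format written by speechbrain.dataio.encoder.TextEncoder: one 'label' => index pair per line, with encoder metadata after the ================ separator. A small sketch of restoring and using it (treat the method names as an assumption; they follow SpeechBrain's CategoricalEncoder API):

    # Sketch: restore the phoneme-to-index mapping saved above.
    from speechbrain.dataio.encoder import TextEncoder

    label_encoder = TextEncoder()
    label_encoder.load("label_encoder.txt")  # parses the 'label' => index lines

    # Encode an ARPAbet sequence; expected output per the table: [15, 41, 20, 24]
    print(label_encoder.encode_sequence(["HH", "AH", "L", "OW"]))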