Florian Lux committed
Commit: 2cb106d
1 Parent(s): f9463cb
implement the cloning demo
This view is limited to 50 files because it contains too many changes.
- .gitignore +16 -0
- InferenceInterfaces/InferenceArchitectures/InferenceFastSpeech2.py +256 -0
- InferenceInterfaces/InferenceArchitectures/InferenceHiFiGAN.py +91 -0
- InferenceInterfaces/InferenceArchitectures/__init__.py +0 -0
- InferenceInterfaces/Meta_FastSpeech2.py +75 -0
- InferenceInterfaces/__init__.py +0 -0
- Layers/Attention.py +324 -0
- Layers/Conformer.py +144 -0
- Layers/Convolution.py +55 -0
- Layers/DurationPredictor.py +139 -0
- Layers/EncoderLayer.py +144 -0
- Layers/LayerNorm.py +36 -0
- Layers/LengthRegulator.py +62 -0
- Layers/MultiLayeredConv1d.py +87 -0
- Layers/MultiSequential.py +33 -0
- Layers/PositionalEncoding.py +166 -0
- Layers/PositionwiseFeedForward.py +26 -0
- Layers/PostNet.py +74 -0
- Layers/ResidualBlock.py +98 -0
- Layers/ResidualStack.py +51 -0
- Layers/STFT.py +118 -0
- Layers/Swish.py +18 -0
- Layers/VariancePredictor.py +65 -0
- Layers/__init__.py +0 -0
- Models/Aligner/__init__.py +0 -0
- Models/FastSpeech2_Meta/__init__.py +0 -0
- Models/HiFiGAN_combined/__init__.py +0 -0
- Preprocessing/ArticulatoryCombinedTextFrontend.py +323 -0
- Preprocessing/AudioPreprocessor.py +166 -0
- Preprocessing/ProsodicConditionExtractor.py +40 -0
- Preprocessing/__init__.py +0 -0
- Preprocessing/papercup_features.py +637 -0
- README.md +3 -3
- TrainingInterfaces/Text_to_Spectrogram/AutoAligner/Aligner.py +287 -0
- TrainingInterfaces/Text_to_Spectrogram/AutoAligner/AlignerDataset.py +211 -0
- TrainingInterfaces/Text_to_Spectrogram/AutoAligner/TinyTTS.py +36 -0
- TrainingInterfaces/Text_to_Spectrogram/AutoAligner/__init__.py +0 -0
- TrainingInterfaces/Text_to_Spectrogram/AutoAligner/autoaligner_train_loop.py +145 -0
- TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/DurationCalculator.py +31 -0
- TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/EnergyCalculator.py +86 -0
- TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeech2.py +379 -0
- TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeech2Loss.py +96 -0
- TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeechDatasetLanguageID.py +217 -0
- TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/PitchCalculator.py +121 -0
- TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/__init__.py +0 -0
- TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/fastspeech2_train_loop.py +201 -0
- TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/fastspeech2_train_loop_ctc.py +191 -0
- TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/meta_train_loop.py +211 -0
- TrainingInterfaces/Text_to_Spectrogram/__init__.py +0 -0
- TrainingInterfaces/__init__.py +0 -0
.gitignore
ADDED
@@ -0,0 +1,16 @@
.idea
*.pyc
*.png
*.pdf
*.pt
tensorboard_logs
Corpora
*_graph
*.out
*.wav
*.flac
audios/
*playground*
*.json
.tmp/
.vscode/
InferenceInterfaces/InferenceArchitectures/InferenceFastSpeech2.py
ADDED
@@ -0,0 +1,256 @@
from abc import ABC

import torch

from Layers.Conformer import Conformer
from Layers.DurationPredictor import DurationPredictor
from Layers.LengthRegulator import LengthRegulator
from Layers.PostNet import PostNet
from Layers.VariancePredictor import VariancePredictor
from Utility.utils import make_non_pad_mask
from Utility.utils import make_pad_mask


class FastSpeech2(torch.nn.Module, ABC):

    def __init__(self,  # network structure related
                 weights,
                 idim=66,
                 odim=80,
                 adim=384,
                 aheads=4,
                 elayers=6,
                 eunits=1536,
                 dlayers=6,
                 dunits=1536,
                 postnet_layers=5,
                 postnet_chans=256,
                 postnet_filts=5,
                 positionwise_conv_kernel_size=1,
                 use_scaled_pos_enc=True,
                 use_batch_norm=True,
                 encoder_normalize_before=True,
                 decoder_normalize_before=True,
                 encoder_concat_after=False,
                 decoder_concat_after=False,
                 reduction_factor=1,
                 # encoder / decoder
                 use_macaron_style_in_conformer=True,
                 use_cnn_in_conformer=True,
                 conformer_enc_kernel_size=7,
                 conformer_dec_kernel_size=31,
                 # duration predictor
                 duration_predictor_layers=2,
                 duration_predictor_chans=256,
                 duration_predictor_kernel_size=3,
                 # energy predictor
                 energy_predictor_layers=2,
                 energy_predictor_chans=256,
                 energy_predictor_kernel_size=3,
                 energy_predictor_dropout=0.5,
                 energy_embed_kernel_size=1,
                 energy_embed_dropout=0.0,
                 stop_gradient_from_energy_predictor=True,
                 # pitch predictor
                 pitch_predictor_layers=5,
                 pitch_predictor_chans=256,
                 pitch_predictor_kernel_size=5,
                 pitch_predictor_dropout=0.5,
                 pitch_embed_kernel_size=1,
                 pitch_embed_dropout=0.0,
                 stop_gradient_from_pitch_predictor=True,
                 # training related
                 transformer_enc_dropout_rate=0.2,
                 transformer_enc_positional_dropout_rate=0.2,
                 transformer_enc_attn_dropout_rate=0.2,
                 transformer_dec_dropout_rate=0.2,
                 transformer_dec_positional_dropout_rate=0.2,
                 transformer_dec_attn_dropout_rate=0.2,
                 duration_predictor_dropout_rate=0.2,
                 postnet_dropout_rate=0.5,
                 # additional features
                 utt_embed_dim=704,
                 connect_utt_emb_at_encoder_out=True,
                 lang_embs=100):
        super().__init__()
        self.idim = idim
        self.odim = odim
        self.reduction_factor = reduction_factor
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        embed = torch.nn.Sequential(torch.nn.Linear(idim, 100),
                                    torch.nn.Tanh(),
                                    torch.nn.Linear(100, adim))
        self.encoder = Conformer(idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers,
                                 input_layer=embed, dropout_rate=transformer_enc_dropout_rate,
                                 positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate,
                                 normalize_before=encoder_normalize_before, concat_after=encoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_enc_kernel_size, zero_triu=False,
                                 utt_embed=utt_embed_dim, connect_utt_emb_at_encoder_out=connect_utt_emb_at_encoder_out, lang_embs=lang_embs)
        self.duration_predictor = DurationPredictor(idim=adim, n_layers=duration_predictor_layers,
                                                    n_chans=duration_predictor_chans,
                                                    kernel_size=duration_predictor_kernel_size,
                                                    dropout_rate=duration_predictor_dropout_rate, )
        self.pitch_predictor = VariancePredictor(idim=adim, n_layers=pitch_predictor_layers,
                                                 n_chans=pitch_predictor_chans,
                                                 kernel_size=pitch_predictor_kernel_size,
                                                 dropout_rate=pitch_predictor_dropout)
        self.pitch_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=adim,
                                                               kernel_size=pitch_embed_kernel_size,
                                                               padding=(pitch_embed_kernel_size - 1) // 2),
                                               torch.nn.Dropout(pitch_embed_dropout))
        self.energy_predictor = VariancePredictor(idim=adim, n_layers=energy_predictor_layers,
                                                  n_chans=energy_predictor_chans,
                                                  kernel_size=energy_predictor_kernel_size,
                                                  dropout_rate=energy_predictor_dropout)
        self.energy_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=adim,
                                                                kernel_size=energy_embed_kernel_size,
                                                                padding=(energy_embed_kernel_size - 1) // 2),
                                                torch.nn.Dropout(energy_embed_dropout))
        self.length_regulator = LengthRegulator()
        self.decoder = Conformer(idim=0,
                                 attention_dim=adim,
                                 attention_heads=aheads,
                                 linear_units=dunits,
                                 num_blocks=dlayers,
                                 input_layer=None,
                                 dropout_rate=transformer_dec_dropout_rate,
                                 positional_dropout_rate=transformer_dec_positional_dropout_rate,
                                 attention_dropout_rate=transformer_dec_attn_dropout_rate,
                                 normalize_before=decoder_normalize_before,
                                 concat_after=decoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                                 macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=use_cnn_in_conformer,
                                 cnn_module_kernel=conformer_dec_kernel_size)
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
        self.postnet = PostNet(idim=idim,
                               odim=odim,
                               n_layers=postnet_layers,
                               n_chans=postnet_chans,
                               n_filts=postnet_filts,
                               use_batch_norm=use_batch_norm,
                               dropout_rate=postnet_dropout_rate)
        self.load_state_dict(weights)

    def _forward(self, text_tensors, text_lens, gold_speech=None, speech_lens=None,
                 gold_durations=None, gold_pitch=None, gold_energy=None,
                 is_inference=False, alpha=1.0, utterance_embedding=None, lang_ids=None):
        # forward encoder
        text_masks = self._source_mask(text_lens)

        encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)  # (B, Tmax, adim)

        # forward duration predictor and variance predictors
        duration_masks = make_pad_mask(text_lens, device=text_lens.device)

        if self.stop_gradient_from_pitch_predictor:
            pitch_predictions = self.pitch_predictor(encoded_texts.detach(), duration_masks.unsqueeze(-1))
        else:
            pitch_predictions = self.pitch_predictor(encoded_texts, duration_masks.unsqueeze(-1))

        if self.stop_gradient_from_energy_predictor:
            energy_predictions = self.energy_predictor(encoded_texts.detach(), duration_masks.unsqueeze(-1))
        else:
            energy_predictions = self.energy_predictor(encoded_texts, duration_masks.unsqueeze(-1))

        if is_inference:
            if gold_durations is not None:
                duration_predictions = gold_durations
            else:
                duration_predictions = self.duration_predictor.inference(encoded_texts, duration_masks)
            if gold_pitch is not None:
                pitch_predictions = gold_pitch
            if gold_energy is not None:
                energy_predictions = gold_energy
            pitch_embeddings = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
            energy_embeddings = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + energy_embeddings + pitch_embeddings
            encoded_texts = self.length_regulator(encoded_texts, duration_predictions, alpha)
        else:
            duration_predictions = self.duration_predictor(encoded_texts, duration_masks)

            # use groundtruth in training
            pitch_embeddings = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2)
            energy_embeddings = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + energy_embeddings + pitch_embeddings
            encoded_texts = self.length_regulator(encoded_texts, gold_durations)  # (B, Lmax, adim)

        # forward decoder
        if speech_lens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = speech_lens.new([olen // self.reduction_factor for olen in speech_lens])
            else:
                olens_in = speech_lens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(encoded_texts, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)

        return before_outs, after_outs, duration_predictions, pitch_predictions, energy_predictions

    @torch.no_grad()
    def forward(self,
                text,
                speech=None,
                durations=None,
                pitch=None,
                energy=None,
                utterance_embedding=None,
                return_duration_pitch_energy=False,
                lang_id=None):
        """
        Generate the sequence of features given the sequences of characters.

        Args:
            text: Input sequence of characters
            speech: Feature sequence to extract style
            durations: Groundtruth of duration
            pitch: Groundtruth of token-averaged pitch
            energy: Groundtruth of token-averaged energy
            return_duration_pitch_energy: whether to return the list of predicted durations for nicer plotting
            utterance_embedding: embedding of utterance wide parameters

        Returns:
            Mel Spectrogram

        """
        self.eval()
        # setup batch axis
        ilens = torch.tensor([text.shape[0]], dtype=torch.long, device=text.device)
        if speech is not None:
            gold_speech = speech.unsqueeze(0)
        else:
            gold_speech = None
        if durations is not None:
            durations = durations.unsqueeze(0)
        if pitch is not None:
            pitch = pitch.unsqueeze(0)
        if energy is not None:
            energy = energy.unsqueeze(0)
        if lang_id is not None:
            lang_id = lang_id.unsqueeze(0)

        before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(text.unsqueeze(0),
                                                                                               ilens,
                                                                                               gold_speech=gold_speech,
                                                                                               gold_durations=durations,
                                                                                               is_inference=True,
                                                                                               gold_pitch=pitch,
                                                                                               gold_energy=energy,
                                                                                               utterance_embedding=utterance_embedding.unsqueeze(0),
                                                                                               lang_ids=lang_id)
        self.train()
        if return_duration_pitch_energy:
            return after_outs[0], d_outs[0], pitch_predictions[0], energy_predictions[0]
        return after_outs[0]

    def _source_mask(self, ilens):
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)
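For orientation, a hedged sketch of the tensor shapes this inference architecture expects. The sizes below are illustrative assumptions derived from the defaults (idim=66, odim=80, utt_embed_dim=704, lang_embs=100); the actual call is left commented because the class needs a trained weights state dict to be constructed.

import torch

phones = torch.randn(11, 66)     # 11 phonemes as 66-dimensional articulatory feature vectors (idim=66)
utt_emb = torch.randn(704)       # utterance-wide prosodic condition, utt_embed_dim=704
lang_id = torch.LongTensor([0])  # index into the 100-entry language embedding table
# With a trained checkpoint loaded, the no-grad forward would be called roughly like this:
# mel = fastspeech2(phones, utterance_embedding=utt_emb, lang_id=lang_id)  # -> (frames, 80) mel spectrogram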
InferenceInterfaces/InferenceArchitectures/InferenceHiFiGAN.py
ADDED
@@ -0,0 +1,91 @@
import torch

from Layers.ResidualBlock import HiFiGANResidualBlock as ResidualBlock


class HiFiGANGenerator(torch.nn.Module):

    def __init__(self,
                 path_to_weights,
                 in_channels=80,
                 out_channels=1,
                 channels=512,
                 kernel_size=7,
                 upsample_scales=(8, 6, 4, 4),
                 upsample_kernel_sizes=(16, 12, 8, 8),
                 resblock_kernel_sizes=(3, 7, 11),
                 resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)],
                 use_additional_convs=True,
                 bias=True,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.1},
                 use_weight_norm=True, ):
        super().__init__()
        assert kernel_size % 2 == 1, "Kernel size must be odd number."
        assert len(upsample_scales) == len(upsample_kernel_sizes)
        assert len(resblock_dilations) == len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_kernel_sizes)
        self.num_blocks = len(resblock_kernel_sizes)
        self.input_conv = torch.nn.Conv1d(in_channels,
                                          channels,
                                          kernel_size,
                                          1,
                                          padding=(kernel_size - 1) // 2, )
        self.upsamples = torch.nn.ModuleList()
        self.blocks = torch.nn.ModuleList()
        for i in range(len(upsample_kernel_sizes)):
            self.upsamples += [
                torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                                    torch.nn.ConvTranspose1d(channels // (2 ** i),
                                                             channels // (2 ** (i + 1)),
                                                             upsample_kernel_sizes[i],
                                                             upsample_scales[i],
                                                             padding=(upsample_kernel_sizes[i] - upsample_scales[i]) // 2, ), )]
            for j in range(len(resblock_kernel_sizes)):
                self.blocks += [ResidualBlock(kernel_size=resblock_kernel_sizes[j],
                                              channels=channels // (2 ** (i + 1)),
                                              dilations=resblock_dilations[j],
                                              bias=bias,
                                              use_additional_convs=use_additional_convs,
                                              nonlinear_activation=nonlinear_activation,
                                              nonlinear_activation_params=nonlinear_activation_params, )]
        self.output_conv = torch.nn.Sequential(
            torch.nn.LeakyReLU(),
            torch.nn.Conv1d(channels // (2 ** (i + 1)),
                            out_channels,
                            kernel_size,
                            1,
                            padding=(kernel_size - 1) // 2, ),
            torch.nn.Tanh(), )
        if use_weight_norm:
            self.apply_weight_norm()
        self.load_state_dict(torch.load(path_to_weights, map_location='cpu')["generator"])

    def forward(self, c, normalize_before=False):
        if normalize_before:
            c = (c - self.mean) / self.scale
        c = self.input_conv(c.unsqueeze(0))
        for i in range(self.num_upsamples):
            c = self.upsamples[i](c)
            cs = 0.0  # initialize
            for j in range(self.num_blocks):
                cs = cs + self.blocks[i * self.num_blocks + j](c)
            c = cs / self.num_blocks
        c = self.output_conv(c)
        return c.squeeze(0).squeeze(0)

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)
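A hedged usage sketch for the generator above. The weights path matches the Models folder added in this commit, but its availability at runtime is an assumption, so the calls are left commented.

import torch

# vocoder = HiFiGANGenerator(path_to_weights="Models/HiFiGAN_combined/best.pt")
# mel = torch.randn(80, 200)  # (in_channels=80 mel bins, 200 frames)
# wave = vocoder(mel)         # each frame is upsampled by 8 * 6 * 4 * 4 = 768 samples -> roughly 200 * 768 samples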
InferenceInterfaces/InferenceArchitectures/__init__.py
ADDED
File without changes
InferenceInterfaces/Meta_FastSpeech2.py
ADDED
@@ -0,0 +1,75 @@
import os

import librosa.display as lbd
import matplotlib.pyplot as plt
import soundfile
import torch

from InferenceInterfaces.InferenceArchitectures.InferenceFastSpeech2 import FastSpeech2
from InferenceInterfaces.InferenceArchitectures.InferenceHiFiGAN import HiFiGANGenerator
from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor


class Meta_FastSpeech2(torch.nn.Module):

    def __init__(self, device="cpu"):
        super().__init__()
        model_name = "Meta"
        language = "en"
        self.device = device
        self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True)
        checkpoint = torch.load(os.path.join("Models", f"FastSpeech2_{model_name}", "best.pt"), map_location='cpu')
        self.phone2mel = FastSpeech2(weights=checkpoint["model"]).to(torch.device(device))
        self.mel2wav = HiFiGANGenerator(path_to_weights=os.path.join("Models", "HiFiGAN_combined", "best.pt")).to(torch.device(device))
        self.default_utterance_embedding = checkpoint["default_emb"].to(self.device)
        self.phone2mel.eval()
        self.mel2wav.eval()
        self.lang_id = get_language_id(language)
        self.to(torch.device(device))

    def set_utterance_embedding(self, path_to_reference_audio):
        wave, sr = soundfile.read(path_to_reference_audio)
        self.default_utterance_embedding = ProsodicConditionExtractor(sr=sr).extract_condition_from_reference_wave(wave).to(self.device)

    def set_language(self, lang_id):
        """
        The id parameter actually refers to the shorthand. This has become ambiguous with the introduction of the actual language IDs
        """
        self.text2phone = ArticulatoryCombinedTextFrontend(language=lang_id, add_silence_to_end=True, silent=True)
        self.lang_id = get_language_id(lang_id).to(self.device)

    def forward(self, text, view=False, durations=None, pitch=None, energy=None):
        with torch.no_grad():
            phones = self.text2phone.string_to_tensor(text).to(torch.device(self.device))
            mel, durations, pitch, energy = self.phone2mel(phones,
                                                           return_duration_pitch_energy=True,
                                                           utterance_embedding=self.default_utterance_embedding,
                                                           durations=durations,
                                                           pitch=pitch,
                                                           energy=energy)
            mel = mel.transpose(0, 1)
            wave = self.mel2wav(mel)
            if view:
                from Utility.utils import cumsum_durations
                fig, ax = plt.subplots(nrows=2, ncols=1)
                ax[0].plot(wave.cpu().numpy())
                lbd.specshow(mel.cpu().numpy(),
                             ax=ax[1],
                             sr=16000,
                             cmap='GnBu',
                             y_axis='mel',
                             x_axis=None,
                             hop_length=256)
                ax[0].yaxis.set_visible(False)
                ax[1].yaxis.set_visible(False)
                duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
                ax[1].set_xticks(duration_splits, minor=True)
                ax[1].xaxis.grid(True, which='minor')
                ax[1].set_xticks(label_positions, minor=False)
                ax[1].set_xticklabels(self.text2phone.get_phone_string(text))
                ax[0].set_title(text)
                plt.subplots_adjust(left=0.05, bottom=0.1, right=0.95, top=.9, wspace=0.0, hspace=0.0)
                plt.show()
            return wave
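Since the commit message is "implement the cloning demo", here is a minimal usage sketch of the interface above. The reference-audio filename and the 16 kHz output rate are assumptions (the plotting code uses sr=16000), and the checkpoints under Models/ must exist.

import soundfile

from InferenceInterfaces.Meta_FastSpeech2 import Meta_FastSpeech2

tts = Meta_FastSpeech2(device="cpu")
tts.set_language("en")                                # language shorthand, e.g. "en" or "de"
tts.set_utterance_embedding("reference_speaker.wav")  # extract the prosodic condition of the voice to clone
wave = tts("Hello, this is a cloned voice.")          # Meta_FastSpeech2 is an nn.Module, so calling it runs forward()
soundfile.write("cloned.wav", wave.cpu().numpy(), 16000)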
InferenceInterfaces/__init__.py
ADDED
File without changes
Layers/Attention.py
ADDED
@@ -0,0 +1,324 @@
# Written by Shigeki Karita, 2019
# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# Adapted by Florian Lux, 2021

"""Multi-Head Attention layer definition."""

import math

import numpy
import torch
from torch import nn

from Utility.utils import make_non_pad_mask


class MultiHeadedAttention(nn.Module):
    """
    Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """
        Construct a MultiHeadedAttention object.
        """
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """
        Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(self, value, scores, mask):
        """
        Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k))  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """
        Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """
    Multi-Head Attention layer with relative position encoding.
    Details can be found in https://github.com/espnet/espnet/pull/2816.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
    """

    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        self.zero_triu = zero_triu
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """
        Compute relative positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
            time1 means the length of query vector.
        Returns:
            torch.Tensor: Output tensor.
        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1]  # only keep the positions from 0 to time2

        if self.zero_triu:
            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, mask):
        """
        Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, 2*time1-1, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, 2*time1-1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)


class GuidedAttentionLoss(torch.nn.Module):
    """
    Guided attention loss function module.

    This module calculates the guided attention loss described
    in `Efficiently Trainable Text-to-Speech System Based
    on Deep Convolutional Networks with Guided Attention`_,
    which forces the attention to be diagonal.

    .. _`Efficiently Trainable Text-to-Speech System
        Based on Deep Convolutional Networks with Guided Attention`:
        https://arxiv.org/abs/1710.08969
    """

    def __init__(self, sigma=0.4, alpha=1.0):
        """
        Initialize guided attention loss module.

        Args:
            sigma (float, optional): Standard deviation to control
                how close attention to a diagonal.
            alpha (float, optional): Scaling coefficient (lambda).
            reset_always (bool, optional): Whether to always reset masks.
        """
        super(GuidedAttentionLoss, self).__init__()
        self.sigma = sigma
        self.alpha = alpha
        self.guided_attn_masks = None
        self.masks = None

    def _reset_masks(self):
        self.guided_attn_masks = None
        self.masks = None

    def forward(self, att_ws, ilens, olens):
        """
        Calculate forward propagation.

        Args:
            att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in).
            ilens (LongTensor): Batch of input lengths (B,).
            olens (LongTensor): Batch of output lengths (B,).

        Returns:
            Tensor: Guided attention loss value.
        """
        self._reset_masks()
        self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to(att_ws.device)
        self.masks = self._make_masks(ilens, olens).to(att_ws.device)
        losses = self.guided_attn_masks * att_ws
        loss = torch.mean(losses.masked_select(self.masks))
        self._reset_masks()
        return self.alpha * loss

    def _make_guided_attention_masks(self, ilens, olens):
        n_batches = len(ilens)
        max_ilen = max(ilens)
        max_olen = max(olens)
        guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen), device=ilens.device)
        for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
            guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma)
        return guided_attn_masks

    @staticmethod
    def _make_guided_attention_mask(ilen, olen, sigma):
        """
        Make guided attention mask.
        """
        grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device).float(), torch.arange(ilen, device=ilen.device).float())
        return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2)))

    @staticmethod
    def _make_masks(ilens, olens):
        """
        Make masks indicating non-padded part.

        Args:
            ilens (LongTensor or List): Batch of lengths (B,).
            olens (LongTensor or List): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor indicating non-padded part.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
        """
        in_masks = make_non_pad_mask(ilens, device=ilens.device)  # (B, T_in)
        out_masks = make_non_pad_mask(olens, device=olens.device)  # (B, T_out)
        return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)  # (B, T_out, T_in)


class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
    """
    Guided attention loss function module for multi head attention.

    Args:
        sigma (float, optional): Standard deviation to control
            how close attention to a diagonal.
        alpha (float, optional): Scaling coefficient (lambda).
        reset_always (bool, optional): Whether to always reset masks.
    """

    def forward(self, att_ws, ilens, olens):
        """
        Calculate forward propagation.

        Args:
            att_ws (Tensor):
                Batch of multi head attention weights (B, H, T_max_out, T_max_in).
            ilens (LongTensor): Batch of input lengths (B,).
            olens (LongTensor): Batch of output lengths (B,).

        Returns:
            Tensor: Guided attention loss value.
        """
        if self.guided_attn_masks is None:
            self.guided_attn_masks = (self._make_guided_attention_masks(ilens, olens).to(att_ws.device).unsqueeze(1))
        if self.masks is None:
            self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
        losses = self.guided_attn_masks * att_ws
        loss = torch.mean(losses.masked_select(self.masks))
        if self.reset_always:
            self._reset_masks()

        return self.alpha * loss
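A toy check of the plain multi-headed attention block above; the sizes are chosen arbitrarily and it assumes the Layers package from this commit is on the import path.

import torch

from Layers.Attention import MultiHeadedAttention

mha = MultiHeadedAttention(n_head=4, n_feat=384, dropout_rate=0.1)
x = torch.randn(2, 50, 384)                    # (batch, time, features)
mask = torch.ones(2, 1, 50, dtype=torch.bool)  # no padding in this toy batch
out = mha(x, x, x, mask)                       # self-attention -> (2, 50, 384)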
Layers/Conformer.py
ADDED
@@ -0,0 +1,144 @@
"""
Taken from ESPNet
"""

import torch
import torch.nn.functional as F

from Layers.Attention import RelPositionMultiHeadedAttention
from Layers.Convolution import ConvolutionModule
from Layers.EncoderLayer import EncoderLayer
from Layers.LayerNorm import LayerNorm
from Layers.MultiLayeredConv1d import MultiLayeredConv1d
from Layers.MultiSequential import repeat
from Layers.PositionalEncoding import RelPositionalEncoding
from Layers.Swish import Swish


class Conformer(torch.nn.Module):
    """
    Conformer encoder module.

    Args:
        idim (int): Input dimension.
        attention_dim (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of decoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        pos_enc_layer_type (str): Conformer positional encoding layer type.
        selfattention_layer_type (str): Conformer attention layer type.
        activation_type (str): Conformer activation function type.
        use_cnn_module (bool): Whether to use convolution module.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.

    """

    def __init__(self, idim, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1,
                 attention_dropout_rate=0.0, input_layer="conv2d", normalize_before=True, concat_after=False, positionwise_conv_kernel_size=1,
                 macaron_style=False, use_cnn_module=False, cnn_module_kernel=31, zero_triu=False, utt_embed=None, connect_utt_emb_at_encoder_out=True,
                 spk_emb_bottleneck_size=128, lang_embs=None):
        super(Conformer, self).__init__()

        activation = Swish()
        self.conv_subsampling_factor = 1

        if isinstance(input_layer, torch.nn.Module):
            self.embed = input_layer
            self.pos_enc = RelPositionalEncoding(attention_dim, positional_dropout_rate)
        elif input_layer is None:
            self.embed = None
            self.pos_enc = torch.nn.Sequential(RelPositionalEncoding(attention_dim, positional_dropout_rate))
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before

        self.connect_utt_emb_at_encoder_out = connect_utt_emb_at_encoder_out
        if utt_embed is not None:
            self.hs_emb_projection = torch.nn.Linear(attention_dim + spk_emb_bottleneck_size, attention_dim)
            # embedding projection derived from https://arxiv.org/pdf/1705.08947.pdf
            self.embedding_projection = torch.nn.Sequential(torch.nn.Linear(utt_embed, spk_emb_bottleneck_size),
                                                            torch.nn.Softsign())
        if lang_embs is not None:
            self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=attention_dim)

        # self-attention module definition
        encoder_selfattn_layer = RelPositionMultiHeadedAttention
        encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu)

        # feed-forward module definition
        positionwise_layer = MultiLayeredConv1d
        positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate,)

        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (attention_dim, cnn_module_kernel, activation)

        self.encoders = repeat(num_blocks, lambda lnum: EncoderLayer(attention_dim, encoder_selfattn_layer(*encoder_selfattn_layer_args),
                                                                     positionwise_layer(*positionwise_layer_args),
                                                                     positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                                                                     convolution_layer(*convolution_layer_args) if use_cnn_module else None, dropout_rate,
                                                                     normalize_before, concat_after))
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

    def forward(self, xs, masks, utterance_embedding=None, lang_ids=None):
        """
        Encode input sequence.

        Args:
            utterance_embedding: embedding containing lots of conditioning signals
            step: indicator for when to start updating the embedding function
            xs (torch.Tensor): Input tensor (#batch, time, idim).
            masks (torch.Tensor): Mask tensor (#batch, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, attention_dim).
            torch.Tensor: Mask tensor (#batch, time).

        """

        if self.embed is not None:
            xs = self.embed(xs)

        if lang_ids is not None:
            lang_embs = self.language_embedding(lang_ids)
            xs = xs + lang_embs  # offset the phoneme distribution of a language

        if utterance_embedding is not None and not self.connect_utt_emb_at_encoder_out:
            xs = self._integrate_with_utt_embed(xs, utterance_embedding)

        xs = self.pos_enc(xs)

        xs, masks = self.encoders(xs, masks)
        if isinstance(xs, tuple):
            xs = xs[0]

        if self.normalize_before:
            xs = self.after_norm(xs)

        if utterance_embedding is not None and self.connect_utt_emb_at_encoder_out:
            xs = self._integrate_with_utt_embed(xs, utterance_embedding)

        return xs, masks

    def _integrate_with_utt_embed(self, hs, utt_embeddings):
        # project embedding into smaller space
        speaker_embeddings_projected = self.embedding_projection(utt_embeddings)
        # concat hidden states with spk embeds and then apply projection
        speaker_embeddings_expanded = F.normalize(speaker_embeddings_projected).unsqueeze(1).expand(-1, hs.size(1), -1)
        hs = self.hs_emb_projection(torch.cat([hs, speaker_embeddings_expanded], dim=-1))
        return hs
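A toy sketch of a decoder-style Conformer, configured the way InferenceFastSpeech2 builds its decoder (input_layer=None, so inputs must already be attention_dim-sized). The small sizes are illustrative assumptions, and it assumes the Layers package from this commit is importable.

import torch

from Layers.Conformer import Conformer

dec = Conformer(idim=0, attention_dim=64, attention_heads=4, linear_units=128, num_blocks=2,
                input_layer=None, macaron_style=True, use_cnn_module=True, cnn_module_kernel=31)
xs = torch.randn(2, 30, 64)  # (batch, time, attention_dim)
ys, _ = dec(xs, None)        # masks=None, as FastSpeech2 passes h_masks=None at inference; ys is (2, 30, 64)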
Layers/Convolution.py
ADDED
@@ -0,0 +1,55 @@
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
#                Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# Adapted by Florian Lux 2021


from torch import nn


class ConvolutionModule(nn.Module):
    """
    ConvolutionModule in Conformer model.

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers.

    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, )
        self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, )
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, )
        self.activation = activation

    def forward(self, x):
        """
        Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).

        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x)

        return x.transpose(1, 2)
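A quick shape check for the module above; note that the hard-coded GroupNorm with 32 groups means channels must be divisible by 32. The toy values here are assumptions.

import torch

from Layers.Convolution import ConvolutionModule

conv = ConvolutionModule(channels=64, kernel_size=31)
x = torch.randn(2, 50, 64)  # (batch, time, channels)
y = conv(x)                 # same shape out: (2, 50, 64)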
Layers/DurationPredictor.py
ADDED
@@ -0,0 +1,139 @@
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
# Adapted by Florian Lux 2021


import torch

from Layers.LayerNorm import LayerNorm


class DurationPredictor(torch.nn.Module):
    """
    Duration predictor module.

    This is a module of duration predictor described
    in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
    The duration predictor predicts a duration of each frame in log domain
    from the hidden embeddings of encoder.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    Note:
        The calculation domain of outputs is different
        between in `forward` and in `inference`. In `forward`,
        the outputs are calculated in log domain but in `inference`,
        those are calculated in linear domain.

    """

    def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0):
        """
        Initialize duration predictor module.

        Args:
            idim (int): Input dimension.
            n_layers (int, optional): Number of convolutional layers.
            n_chans (int, optional): Number of channels of convolutional layers.
            kernel_size (int, optional): Kernel size of convolutional layers.
            dropout_rate (float, optional): Dropout rate.
            offset (float, optional): Offset value to avoid nan in log domain.

        """
        super(DurationPredictor, self).__init__()
        self.offset = offset
        self.conv = torch.nn.ModuleList()
        for idx in range(n_layers):
            in_chans = idim if idx == 0 else n_chans
            self.conv += [torch.nn.Sequential(torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, ), torch.nn.ReLU(),
                                              LayerNorm(n_chans, dim=1), torch.nn.Dropout(dropout_rate), )]
        self.linear = torch.nn.Linear(n_chans, 1)

    def _forward(self, xs, x_masks=None, is_inference=False):
        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
        for f in self.conv:
            xs = f(xs)  # (B, C, Tmax)

        # NOTE: calculate in log domain
        xs = self.linear(xs.transpose(1, -1)).squeeze(-1)  # (B, Tmax)

        if is_inference:
            # NOTE: calculate in linear domain
            xs = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long()  # avoid negative value

        if x_masks is not None:
            xs = xs.masked_fill(x_masks, 0.0)

        return xs

    def forward(self, xs, x_masks=None):
        """
        Calculate forward propagation.

        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (ByteTensor, optional):
                Batch of masks indicating padded part (B, Tmax).

        Returns:
            Tensor: Batch of predicted durations in log domain (B, Tmax).

        """
        return self._forward(xs, x_masks, False)

    def inference(self, xs, x_masks=None):
        """
        Inference duration.

        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (ByteTensor, optional):
                Batch of masks indicating padded part (B, Tmax).

        Returns:
            LongTensor: Batch of predicted durations in linear domain (B, Tmax).

        """
        return self._forward(xs, x_masks, True)


class DurationPredictorLoss(torch.nn.Module):
    """
    Loss function module for duration predictor.

    The loss value is calculated in log domain to make it Gaussian.

    """

    def __init__(self, offset=1.0, reduction="mean"):
        """
        Args:
            offset (float, optional): Offset value to avoid nan in log domain.
            reduction (str): Reduction type in loss calculation.

        """
        super(DurationPredictorLoss, self).__init__()
        self.criterion = torch.nn.MSELoss(reduction=reduction)
        self.offset = offset

    def forward(self, outputs, targets):
        """
        Calculate forward propagation.

        Args:
            outputs (Tensor): Batch of prediction durations in log domain (B, T)
            targets (LongTensor): Batch of groundtruth durations in linear domain (B, T)

        Returns:
            Tensor: Mean squared error loss value.

        Note:
            `outputs` is in log domain but `targets` is in linear domain.

        """
        # NOTE: outputs is in log domain while targets in linear
        targets = torch.log(targets.float() + self.offset)
        loss = self.criterion(outputs, targets)

        return loss
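A small sketch of the log-vs-linear convention documented above: forward() yields log-durations to be trained against DurationPredictorLoss, while inference() yields rounded frame counts. The sizes are illustrative and it assumes the Layers package from this commit is importable.

import torch

from Layers.DurationPredictor import DurationPredictor, DurationPredictorLoss

dp = DurationPredictor(idim=384, n_layers=2, n_chans=256, kernel_size=3)
xs = torch.randn(2, 10, 384)          # (B, Tmax, idim) encoder states
log_durs = dp(xs)                     # (2, 10) in log domain
frame_counts = dp.inference(xs)       # (2, 10) LongTensor in linear domain
gold = torch.randint(1, 20, (2, 10))  # ground-truth durations in frames
loss = DurationPredictorLoss()(log_durs, gold)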
Layers/EncoderLayer.py
ADDED
@@ -0,0 +1,144 @@
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# Adapted by Florian Lux 2021


import torch
from torch import nn

from Layers.LayerNorm import LayerNorm


class EncoderLayer(nn.Module):
    """
    Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(self, size, self_attn, feed_forward, feed_forward_macaron, conv_module, dropout_rate, normalize_before=True, concat_after=False, ):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)

    def forward(self, x_input, mask, cache=None):
        """
        Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
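A minimal smoke test for the layer's plumbing (not part of the commit): since the attention and feed-forward modules are passed in as arguments, a hypothetical stand-in attention module (here called IdentityAttention) is enough to check shapes when no macaron branch and no convolution module are used.

import torch
from torch import nn
from Layers.EncoderLayer import EncoderLayer
from Layers.PositionwiseFeedForward import PositionwiseFeedForward

class IdentityAttention(nn.Module):
    # stand-in with the (query, key, value, mask) call signature expected above
    def forward(self, q, k, v, mask):
        return q

layer = EncoderLayer(size=64, self_attn=IdentityAttention(), feed_forward=PositionwiseFeedForward(64, 256, 0.1),
                     feed_forward_macaron=None, conv_module=None, dropout_rate=0.1)
x = torch.randn(2, 10, 64)       # (#batch, time, size)
out, mask = layer(x, mask=None)
print(out.shape)                 # torch.Size([2, 10, 64])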
Layers/LayerNorm.py
ADDED
@@ -0,0 +1,36 @@
# Written by Shigeki Karita, 2019
# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# Adapted by Florian Lux, 2021

import torch


class LayerNorm(torch.nn.LayerNorm):
    """
    Layer normalization module.

    Args:
        nout (int): Output dim size.
        dim (int): Dimension to be normalized.
    """

    def __init__(self, nout, dim=-1):
        """
        Construct a LayerNorm object.
        """
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """
        Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor.
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
Layers/LengthRegulator.py
ADDED
@@ -0,0 +1,62 @@
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
# Adapted by Florian Lux 2021

from abc import ABC

import torch

from Utility.utils import pad_list


class LengthRegulator(torch.nn.Module, ABC):
    """
    Length regulator module for feed-forward Transformer.

    This is a module of length regulator described in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
    The length regulator expands char or
    phoneme-level embedding features to frame-level by repeating each
    feature based on the corresponding predicted durations.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf
    """

    def __init__(self, pad_value=0.0):
        """
        Initialize length regulator module.

        Args:
            pad_value (float, optional): Value used for padding.
        """
        super(LengthRegulator, self).__init__()
        self.pad_value = pad_value

    def forward(self, xs, ds, alpha=1.0):
        """
        Calculate forward propagation.

        Args:
            xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D).
            ds (LongTensor): Batch of durations of each frame (B, T).
            alpha (float, optional): Alpha value to control speed of speech.

        Returns:
            Tensor: replicated input tensor based on durations (B, T*, D).
        """
        if alpha != 1.0:
            assert alpha > 0
            ds = torch.round(ds.float() * alpha).long()

        if ds.sum() == 0:
            ds[ds.sum(dim=1).eq(0)] = 1

        return pad_list([self._repeat_one_sequence(x, d) for x, d in zip(xs, ds)], self.pad_value)

    def _repeat_one_sequence(self, x, d):
        """
        Repeat each frame according to duration
        """
        return torch.repeat_interleave(x, d, dim=0)
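To make the expansion behaviour concrete, here is a small sketch (not part of the commit; it assumes the repo's Utility/utils.py with pad_list is importable, as the module above requires):

import torch
from Layers.LengthRegulator import LengthRegulator

lr = LengthRegulator()
xs = torch.arange(6, dtype=torch.float32).view(1, 3, 2)   # 3 phoneme embeddings of dimension 2
ds = torch.LongTensor([[2, 0, 3]])                         # predicted duration per phoneme
out = lr(xs, ds)
print(out.shape)   # torch.Size([1, 5, 2]): first phoneme twice, second dropped, third three times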
Layers/MultiLayeredConv1d.py
ADDED
@@ -0,0 +1,87 @@
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
# Adapted by Florian Lux 2021

"""
Layer modules for FFT block in FastSpeech (Feed-forward Transformer).
"""

import torch


class MultiLayeredConv1d(torch.nn.Module):
    """
    Multi-layered conv1d for Transformer block.

    This is a module of multi-layered conv1d designed
    to replace the positionwise feed-forward network
    in a Transformer block, which is introduced in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf
    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """
        Initialize MultiLayeredConv1d module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.
        """
        super(MultiLayeredConv1d, self).__init__()
        self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, )
        self.w_2 = torch.nn.Conv1d(hidden_chans, in_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, )
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """
        Calculate forward propagation.

        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).

        Returns:
            torch.Tensor: Batch of output tensors (B, T, in_chans).
        """
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)


class Conv1dLinear(torch.nn.Module):
    """
    Conv1D + Linear for Transformer block.

    A variant of MultiLayeredConv1d, which replaces the second conv layer with a linear layer.
    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """
        Initialize Conv1dLinear module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.
        """
        super(Conv1dLinear, self).__init__()
        self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, )
        self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """
        Calculate forward propagation.

        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).

        Returns:
            torch.Tensor: Batch of output tensors (B, T, in_chans).
        """
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x))
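Both variants are drop-in replacements for the position-wise feed-forward block, so they keep the channel dimension unchanged. A short shape check (not part of the commit):

import torch
from Layers.MultiLayeredConv1d import MultiLayeredConv1d, Conv1dLinear

x = torch.randn(2, 10, 8)                             # (B, T, in_chans)
print(MultiLayeredConv1d(8, 32, 3, 0.1)(x).shape)     # torch.Size([2, 10, 8])
print(Conv1dLinear(8, 32, 3, 0.1)(x).shape)           # torch.Size([2, 10, 8])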
Layers/MultiSequential.py
ADDED
@@ -0,0 +1,33 @@
# Written by Shigeki Karita, 2019
# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# Adapted by Florian Lux, 2021

import torch


class MultiSequential(torch.nn.Sequential):
    """
    Multi-input multi-output torch.nn.Sequential.
    """

    def forward(self, *args):
        """
        Repeat.
        """
        for m in self:
            args = m(*args)
        return args


def repeat(N, fn):
    """
    Repeat module N times.

    Args:
        N (int): Number of repetitions.
        fn (Callable): Function to generate module.

    Returns:
        MultiSequential: Repeated model instance.
    """
    return MultiSequential(*[fn(n) for n in range(N)])
Layers/PositionalEncoding.py
ADDED
@@ -0,0 +1,166 @@
"""
Taken from ESPNet
"""

import math

import torch


class PositionalEncoding(torch.nn.Module):
    """
    Positional encoding.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
        reverse (bool): Whether to reverse the input position.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """
        Construct a PositionalEncoding object.
        """
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))  # d_model is an int, so the placeholder tensor stays on the default device

    def extend_pe(self, x):
        """
        Reset the positional encodings.
        """
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x):
        """
        Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


class RelPositionalEncoding(torch.nn.Module):
    """
    Relative positional encoding module (new implementation).
    Details can be found in https://github.com/espnet/espnet/pull/2816.
    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """
        Construct a RelPositionalEncoding object.
        """
        super(RelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of the query vector and `j` is the
        # position of the key vector. We use positive relative positions when keys
        # are to the left (i>j) and negative relative positions otherwise (i<j).
        pe_positive = torch.zeros(x.size(1), self.d_model, device=x.device)
        pe_negative = torch.zeros(x.size(1), self.d_model, device=x.device)
        position = torch.arange(0, x.size(1), dtype=torch.float32, device=x.device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32, device=x.device) * -(math.log(10000.0) / self.d_model))
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(dtype=x.dtype)

    def forward(self, x):
        """
        Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, self.pe.size(1) // 2 - x.size(1) + 1: self.pe.size(1) // 2 + x.size(1), ]
        return self.dropout(x), self.dropout(pos_emb)


class ScaledPositionalEncoding(PositionalEncoding):
    """
    Scaled positional encoding module.

    See Sec. 3.2 https://arxiv.org/abs/1809.08895

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        self.alpha.data = torch.tensor(1.0)

    def forward(self, x):
        """
        Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)
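For the relative variant, the returned positional embedding covers positive and negative offsets, so its length is 2*T - 1. A quick shape check (not part of the commit):

import torch
from Layers.PositionalEncoding import RelPositionalEncoding

enc = RelPositionalEncoding(d_model=16, dropout_rate=0.0)
x = torch.randn(2, 7, 16)                 # (batch, time, d_model)
x_scaled, pos_emb = enc(x)
print(x_scaled.shape, pos_emb.shape)      # torch.Size([2, 7, 16]) torch.Size([1, 13, 16])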
Layers/PositionwiseFeedForward.py
ADDED
@@ -0,0 +1,26 @@
# Written by Shigeki Karita, 2019
# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# Adapted by Florian Lux, 2021


import torch


class PositionwiseFeedForward(torch.nn.Module):
    """
    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, x):
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
Layers/PostNet.py
ADDED
@@ -0,0 +1,74 @@
"""
Taken from ESPNet
"""

import torch


class PostNet(torch.nn.Module):
    """
    From Tacotron2

    Postnet module for Spectrogram prediction network.

    This is a module of Postnet in the Spectrogram prediction network,
    which is described in `Natural TTS Synthesis by
    Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The Postnet refines the predicted
    Mel-filterbank of the decoder,
    which helps to compensate for missing detail structure in the spectrogram.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
        https://arxiv.org/abs/1712.05884
    """

    def __init__(self, idim, odim, n_layers=5, n_chans=512, n_filts=5, dropout_rate=0.5, use_batch_norm=True):
        """
        Initialize postnet module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            n_layers (int, optional): The number of layers.
            n_filts (int, optional): The filter size.
            n_chans (int, optional): The number of filter channels.
            use_batch_norm (bool, optional): Whether to use batch normalization.
            dropout_rate (float, optional): Dropout rate.
        """
        super(PostNet, self).__init__()
        self.postnet = torch.nn.ModuleList()
        for layer in range(n_layers - 1):
            ichans = odim if layer == 0 else n_chans
            ochans = odim if layer == n_layers - 1 else n_chans
            if use_batch_norm:
                self.postnet += [torch.nn.Sequential(torch.nn.Conv1d(ichans, ochans, n_filts, stride=1, padding=(n_filts - 1) // 2, bias=False, ),
                                                     torch.nn.GroupNorm(num_groups=32, num_channels=ochans), torch.nn.Tanh(),
                                                     torch.nn.Dropout(dropout_rate), )]

            else:
                self.postnet += [
                    torch.nn.Sequential(torch.nn.Conv1d(ichans, ochans, n_filts, stride=1, padding=(n_filts - 1) // 2, bias=False, ), torch.nn.Tanh(),
                                        torch.nn.Dropout(dropout_rate), )]
        ichans = n_chans if n_layers != 1 else odim
        if use_batch_norm:
            self.postnet += [torch.nn.Sequential(torch.nn.Conv1d(ichans, odim, n_filts, stride=1, padding=(n_filts - 1) // 2, bias=False, ),
                                                 torch.nn.GroupNorm(num_groups=20, num_channels=odim),
                                                 torch.nn.Dropout(dropout_rate), )]

        else:
            self.postnet += [torch.nn.Sequential(torch.nn.Conv1d(ichans, odim, n_filts, stride=1, padding=(n_filts - 1) // 2, bias=False, ),
                                                 torch.nn.Dropout(dropout_rate), )]

    def forward(self, xs):
        """
        Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the sequences of padded input tensors (B, idim, Tmax).

        Returns:
            Tensor: Batch of padded output tensors (B, odim, Tmax).
        """
        for i in range(len(self.postnet)):
            xs = self.postnet[i](xs)
        return xs
Layers/ResidualBlock.py
ADDED
@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-

"""
References:
    - https://github.com/jik876/hifi-gan
    - https://github.com/kan-bayashi/ParallelWaveGAN
"""

import torch


class Conv1d(torch.nn.Conv1d):
    """
    Conv1d module with customized initialization.
    """

    def __init__(self, *args, **kwargs):
        super(Conv1d, self).__init__(*args, **kwargs)

    def reset_parameters(self):
        torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0.0)


class Conv1d1x1(Conv1d):
    """
    1x1 Conv1d with customized initialization.
    """

    def __init__(self, in_channels, out_channels, bias):
        super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias)


class HiFiGANResidualBlock(torch.nn.Module):
    """Residual block module in HiFiGAN."""

    def __init__(self,
                 kernel_size=3,
                 channels=512,
                 dilations=(1, 3, 5),
                 bias=True,
                 use_additional_convs=True,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.1}, ):
        """
        Initialize HiFiGANResidualBlock module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            channels (int): Number of channels for convolution layer.
            dilations (List[int]): List of dilation factors.
            use_additional_convs (bool): Whether to use additional convolution layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
        """
        super().__init__()
        self.use_additional_convs = use_additional_convs
        self.convs1 = torch.nn.ModuleList()
        if use_additional_convs:
            self.convs2 = torch.nn.ModuleList()
        assert kernel_size % 2 == 1, "Kernel size must be odd number."
        for dilation in dilations:
            self.convs1 += [torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                                                torch.nn.Conv1d(channels,
                                                                channels,
                                                                kernel_size,
                                                                1,
                                                                dilation=dilation,
                                                                bias=bias,
                                                                padding=(kernel_size - 1) // 2 * dilation, ), )]
            if use_additional_convs:
                self.convs2 += [torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                                                    torch.nn.Conv1d(channels,
                                                                    channels,
                                                                    kernel_size,
                                                                    1,
                                                                    dilation=1,
                                                                    bias=bias,
                                                                    padding=(kernel_size - 1) // 2, ), )]

    def forward(self, x):
        """
        Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, channels, T).
        """
        for idx in range(len(self.convs1)):
            xt = self.convs1[idx](x)
            if self.use_additional_convs:
                xt = self.convs2[idx](xt)
            x = xt + x
        return x
Layers/ResidualStack.py
ADDED
@@ -0,0 +1,51 @@
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
# Adapted by Florian Lux 2021


import torch


class ResidualStack(torch.nn.Module):

    def __init__(self, kernel_size=3, channels=32, dilation=1, bias=True, nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d", pad_params={}, ):
        """
        Initialize ResidualStack module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            channels (int): Number of channels of convolution layers.
            dilation (int): Dilation factor.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.
        """
        super(ResidualStack, self).__init__()

        # define residual stack part
        assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
        self.stack = torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                                         getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params),
                                         torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias),
                                         getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                                         torch.nn.Conv1d(channels, channels, 1, bias=bias), )

        # define extra layer for skip connection
        self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias)

    def forward(self, c):
        """
        Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, channels, T).
        """
        return self.stack(c) + self.skip_layer(c)
Layers/STFT.py
ADDED
@@ -0,0 +1,118 @@
"""
Taken from ESPNet
"""

import torch
from torch.functional import stft as torch_stft
from torch_complex.tensor import ComplexTensor

from Utility.utils import make_pad_mask


class STFT(torch.nn.Module):

    def __init__(self, n_fft=512, win_length=None, hop_length=128, window="hann", center=True, normalized=False,
                 onesided=True):
        super().__init__()
        self.n_fft = n_fft
        if win_length is None:
            self.win_length = n_fft
        else:
            self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        self.window = window

    def extra_repr(self):
        return (f"n_fft={self.n_fft}, "
                f"win_length={self.win_length}, "
                f"hop_length={self.hop_length}, "
                f"center={self.center}, "
                f"normalized={self.normalized}, "
                f"onesided={self.onesided}")

    def forward(self, input_wave, ilens=None):
        """
        STFT forward function.
        Args:
            input_wave: (Batch, Nsamples) or (Batch, Nsample, Channels)
            ilens: (Batch)
        Returns:
            output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
        """
        bs = input_wave.size(0)

        if input_wave.dim() == 3:
            multi_channel = True
            # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
            input_wave = input_wave.transpose(1, 2).reshape(-1, input_wave.size(1))
        else:
            multi_channel = False

        # output: (Batch, Freq, Frames, 2=real_imag)
        # or (Batch, Channel, Freq, Frames, 2=real_imag)
        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(self.win_length, dtype=input_wave.dtype, device=input_wave.device)
        else:
            window = None

        complex_output = torch_stft(input=input_wave,
                                    n_fft=self.n_fft,
                                    win_length=self.win_length,
                                    hop_length=self.hop_length,
                                    center=self.center,
                                    window=window,
                                    normalized=self.normalized,
                                    onesided=self.onesided,
                                    return_complex=True)
        output = torch.view_as_real(complex_output)
        # output: (Batch, Freq, Frames, 2=real_imag)
        # -> (Batch, Frames, Freq, 2=real_imag)
        output = output.transpose(1, 2)
        if multi_channel:
            # output: (Batch * Channel, Frames, Freq, 2=real_imag)
            # -> (Batch, Frame, Channel, Freq, 2=real_imag)
            output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(1, 2)

        if ilens is not None:
            if self.center:
                pad = self.win_length // 2
                ilens = ilens + 2 * pad

            olens = torch.div((ilens - self.win_length), self.hop_length, rounding_mode="trunc") + 1
            output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
        else:
            olens = None

        return output, olens

    def inverse(self, input, ilens=None):
        """
        Inverse STFT.
        Args:
            input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
            ilens: (batch,)
        Returns:
            wavs: (batch, samples)
            ilens: (batch,)
        """
        istft = torch.functional.istft

        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(self.win_length, dtype=input.dtype, device=input.device)
        else:
            window = None

        if isinstance(input, ComplexTensor):
            input = torch.stack([input.real, input.imag], dim=-1)
        assert input.shape[-1] == 2
        input = input.transpose(1, 2)

        wavs = istft(input, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=window, center=self.center,
                     normalized=self.normalized, onesided=self.onesided, length=ilens.max() if ilens is not None else ilens)

        return wavs, ilens
ADDED
@@ -0,0 +1,18 @@
|
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# Adapted by Florian Lux 2021

import torch


class Swish(torch.nn.Module):
    """
    Construct a Swish activation function for Conformer.
    """

    def forward(self, x):
        """
        Return Swish activation function.
        """
        return x * torch.sigmoid(x)
Layers/VariancePredictor.py
ADDED
@@ -0,0 +1,65 @@
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
# Adapted by Florian Lux 2021

from abc import ABC

import torch

from Layers.LayerNorm import LayerNorm


class VariancePredictor(torch.nn.Module, ABC):
    """
    Variance predictor module.

    This is a module of variance predictor described in `FastSpeech 2:
    Fast and High-Quality End-to-End Text to Speech`_.

    .. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`:
        https://arxiv.org/abs/2006.04558
    """

    def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, bias=True, dropout_rate=0.5, ):
        """
        Initialize variance predictor module.

        Args:
            idim (int): Input dimension.
            n_layers (int, optional): Number of convolutional layers.
            n_chans (int, optional): Number of channels of convolutional layers.
            kernel_size (int, optional): Kernel size of convolutional layers.
            dropout_rate (float, optional): Dropout rate.
        """
        super().__init__()
        self.conv = torch.nn.ModuleList()
        for idx in range(n_layers):
            in_chans = idim if idx == 0 else n_chans
            self.conv += [
                torch.nn.Sequential(torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias, ), torch.nn.ReLU(),
                                    LayerNorm(n_chans, dim=1), torch.nn.Dropout(dropout_rate), )]
        self.linear = torch.nn.Linear(n_chans, 1)

    def forward(self, xs, x_masks=None):
        """
        Calculate forward propagation.

        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (ByteTensor, optional):
                Batch of masks indicating padded part (B, Tmax).

        Returns:
            Tensor: Batch of predicted sequences (B, Tmax, 1).
        """
        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
        for f in self.conv:
            xs = f(xs)  # (B, C, Tmax)

        xs = self.linear(xs.transpose(1, 2))  # (B, Tmax, 1)

        if x_masks is not None:
            xs = xs.masked_fill(x_masks, 0.0)

        return xs
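The same predictor architecture is reused for pitch and energy in FastSpeech 2, always mapping a phoneme-level sequence to one scalar per position. A minimal shape check (not part of the commit):

import torch
from Layers.VariancePredictor import VariancePredictor

vp = VariancePredictor(idim=80)
xs = torch.randn(2, 12, 80)    # (B, Tmax, idim)
print(vp(xs).shape)            # torch.Size([2, 12, 1])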
Layers/__init__.py
ADDED
File without changes
Models/Aligner/__init__.py
ADDED
File without changes
Models/FastSpeech2_Meta/__init__.py
ADDED
File without changes
Models/HiFiGAN_combined/__init__.py
ADDED
File without changes
Preprocessing/ArticulatoryCombinedTextFrontend.py
ADDED
@@ -0,0 +1,323 @@
import re
import sys

import panphon
import phonemizer
import torch

from Preprocessing.papercup_features import generate_feature_table


class ArticulatoryCombinedTextFrontend:

    def __init__(self,
                 language,
                 use_word_boundaries=False,  # goes together well with
                 # parallel models and an aligner. Doesn't go together
                 # well with autoregressive models.
                 use_explicit_eos=True,
                 use_prosody=False,  # unfortunately the non-segmental
                 # nature of prosodic markers mixed with the sequential
                 # phonemes hurts the performance of end-to-end models a
                 # lot, even though one might think enriching the input
                 # with such information would help.
                 use_lexical_stress=False,
                 silent=True,
                 allow_unknown=False,
                 add_silence_to_end=True,
                 strip_silence=True):
        """
        Mostly preparing ID lookups
        """
        self.strip_silence = strip_silence
        self.use_word_boundaries = use_word_boundaries
        self.allow_unknown = allow_unknown
        self.use_explicit_eos = use_explicit_eos
        self.use_prosody = use_prosody
        self.use_stress = use_lexical_stress
        self.add_silence_to_end = add_silence_to_end
        self.feature_table = panphon.FeatureTable()

        if language == "en":
            self.g2p_lang = "en-us"
            self.expand_abbreviations = english_text_expansion
            if not silent:
                print("Created an English Text-Frontend")

        elif language == "de":
            self.g2p_lang = "de"
            self.expand_abbreviations = lambda x: x
            if not silent:
                print("Created a German Text-Frontend")

        elif language == "el":
            self.g2p_lang = "el"
            self.expand_abbreviations = lambda x: x
            if not silent:
                print("Created a Greek Text-Frontend")

        elif language == "es":
            self.g2p_lang = "es"
            self.expand_abbreviations = lambda x: x
            if not silent:
                print("Created a Spanish Text-Frontend")

        elif language == "fi":
            self.g2p_lang = "fi"
            self.expand_abbreviations = lambda x: x
            if not silent:
                print("Created a Finnish Text-Frontend")

        elif language == "ru":
            self.g2p_lang = "ru"
            self.expand_abbreviations = lambda x: x
            if not silent:
                print("Created a Russian Text-Frontend")

        elif language == "hu":
            self.g2p_lang = "hu"
            self.expand_abbreviations = lambda x: x
            if not silent:
                print("Created a Hungarian Text-Frontend")

        elif language == "nl":
            self.g2p_lang = "nl"
            self.expand_abbreviations = lambda x: x
            if not silent:
                print("Created a Dutch Text-Frontend")

        elif language == "fr":
            self.g2p_lang = "fr-fr"
            self.expand_abbreviations = lambda x: x
            if not silent:
                print("Created a French Text-Frontend")

        elif language == "it":
            self.g2p_lang = "it"
            self.expand_abbreviations = lambda x: x
            if not silent:
                print("Created an Italian Text-Frontend")

        elif language == "pt":
            self.g2p_lang = "pt"
            self.expand_abbreviations = lambda x: x
            if not silent:
                print("Created a Portuguese Text-Frontend")

        elif language == "pl":
            self.g2p_lang = "pl"
            self.expand_abbreviations = lambda x: x
            if not silent:
                print("Created a Polish Text-Frontend")

        # remember to also update get_language_id() when adding something here

        else:
            print("Language not supported yet")
            sys.exit()

        self.phone_to_vector_papercup = generate_feature_table()

        self.phone_to_vector = dict()
        for phone in self.phone_to_vector_papercup:
            panphon_features = self.feature_table.word_to_vector_list(phone, numeric=True)
            if panphon_features == []:
                panphon_features = [[0] * 24]
            papercup_features = self.phone_to_vector_papercup[phone]
            self.phone_to_vector[phone] = papercup_features + panphon_features[0]

        # this lookup must be updated manually, because the only other way
        # would be extracting them from a set, which can be non-deterministic
        self.phone_to_id = {
            '~': 0, '#': 1, '?': 2, '!': 3, '.': 4, 'ɜ': 5, 'ɫ': 6, 'ə': 7, 'ɚ': 8, 'a': 9,
            'ð': 10, 'ɛ': 11, 'ɪ': 12, 'ᵻ': 13, 'ŋ': 14, 'ɔ': 15, 'ɒ': 16, 'ɾ': 17, 'ʃ': 18, 'θ': 19,
            'ʊ': 20, 'ʌ': 21, 'ʒ': 22, 'æ': 23, 'b': 24, 'ʔ': 25, 'd': 26, 'e': 27, 'f': 28, 'g': 29,
            'h': 30, 'i': 31, 'j': 32, 'k': 33, 'l': 34, 'm': 35, 'n': 36, 'ɳ': 37, 'o': 38, 'p': 39,
            'ɡ': 40, 'ɹ': 41, 'r': 42, 's': 43, 't': 44, 'u': 45, 'v': 46, 'w': 47, 'x': 48, 'z': 49,
            'ʀ': 50, 'ø': 51, 'ç': 52, 'ɐ': 53, 'œ': 54, 'y': 55, 'ʏ': 56, 'ɑ': 57, 'c': 58, 'ɲ': 59,
            'ɣ': 60, 'ʎ': 61, 'β': 62, 'ʝ': 63, 'ɟ': 64, 'q': 65, 'ɕ': 66, 'ʲ': 67, 'ɭ': 68, 'ɵ': 69,
            'ʑ': 70, 'ʋ': 71, 'ʁ': 72, 'ɨ': 73, 'ʂ': 74, 'ɬ': 75,
            }  # for the states of the ctc loss and dijkstra/mas in the aligner

        self.id_to_phone = {v: k for k, v in self.phone_to_id.items()}

    def string_to_tensor(self, text, view=False, device="cpu", handle_missing=True, input_phonemes=False):
        """
        Fixes unicode errors, expands some abbreviations,
        turns graphemes into phonemes and then vectorizes
        the sequence as articulatory features
        """
        if input_phonemes:
            phones = text
        else:
            phones = self.get_phone_string(text=text, include_eos_symbol=True)
        if view:
            print("Phonemes: \n{}\n".format(phones))
        phones_vector = list()
        # turn into numeric vectors
        for char in phones:
            if handle_missing:
                try:
                    phones_vector.append(self.phone_to_vector[char])
                except KeyError:
                    print("unknown phoneme: {}".format(char))
            else:
                phones_vector.append(self.phone_to_vector[char])  # leave error handling to elsewhere

        return torch.Tensor(phones_vector, device=device)

    def get_phone_string(self, text, include_eos_symbol=True):
        # expand abbreviations
        utt = self.expand_abbreviations(text)
        # phonemize
        phones = phonemizer.phonemize(utt,
                                      language_switch='remove-flags',
                                      backend="espeak",
                                      language=self.g2p_lang,
                                      preserve_punctuation=True,
                                      strip=True,
                                      punctuation_marks=';:,.!?¡¿—…"«»“”~/',
                                      with_stress=self.use_stress).replace(";", ",").replace("/", " ").replace("—", "") \
            .replace(":", ",").replace('"', ",").replace("-", ",").replace("...", ",").replace("-", ",").replace("\n", " ") \
            .replace("\t", " ").replace("¡", "").replace("¿", "").replace(",", "~").replace(" ̃", "").replace('̩', "").replace("̃", "").replace("̪", "")
        # less than 1 wide characters hidden here
        phones = re.sub("~+", "~", phones)
        if not self.use_prosody:
            # retain ~ as heuristic pause marker, even though all other symbols are removed with this option.
            # also retain . ? and ! since they can be indicators for the stop token
            phones = phones.replace("ˌ", "").replace("ː", "").replace("ˑ", "") \
                .replace("˘", "").replace("|", "").replace("‖", "")
        if not self.use_word_boundaries:
            phones = phones.replace(" ", "")
        else:
            phones = re.sub(r"\s+", " ", phones)
            phones = re.sub(" ", "~", phones)
        if self.strip_silence:
            phones = phones.lstrip("~").rstrip("~")
        if self.add_silence_to_end:
            phones += "~"  # adding a silence to the end produces more natural sounding prosody
        if include_eos_symbol:
            phones += "#"

        phones = "~" + phones
        phones = re.sub("~+", "~", phones)

        return phones


def english_text_expansion(text):
    """
    Apply a small part of the tacotron style text cleaning pipeline, suitable for e.g. LJSpeech.
    See https://github.com/keithito/tacotron/
    Careful: Only apply to english datasets. Different languages need different cleaners.
    """
    _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in
                      [('Mrs.', 'misess'), ('Mr.', 'mister'), ('Dr.', 'doctor'), ('St.', 'saint'), ('Co.', 'company'), ('Jr.', 'junior'), ('Maj.', 'major'),
                       ('Gen.', 'general'), ('Drs.', 'doctors'), ('Rev.', 'reverend'), ('Lt.', 'lieutenant'), ('Hon.', 'honorable'), ('Sgt.', 'sergeant'),
                       ('Capt.', 'captain'), ('Esq.', 'esquire'), ('Ltd.', 'limited'), ('Col.', 'colonel'), ('Ft.', 'fort')]]
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


def get_language_id(language):
    if language == "en":
        return torch.LongTensor([0])
    elif language == "de":
        return torch.LongTensor([1])
    elif language == "el":
        return torch.LongTensor([2])
    elif language == "es":
        return torch.LongTensor([3])
    elif language == "fi":
        return torch.LongTensor([4])
    elif language == "ru":
        return torch.LongTensor([5])
    elif language == "hu":
        return torch.LongTensor([6])
    elif language == "nl":
        return torch.LongTensor([7])
    elif language == "fr":
        return torch.LongTensor([8])
    elif language == "pt":
        return torch.LongTensor([9])
    elif language == "pl":
        return torch.LongTensor([10])
    elif language == "it":
        return torch.LongTensor([11])


if __name__ == '__main__':
    # test an English utterance
    tfr_en = ArticulatoryCombinedTextFrontend(language="en")
    print(tfr_en.string_to_tensor("This is a complex sentence, it even has a pause! But can it do this? Nice.", view=True))

    # test a German utterance
    tfr_de = ArticulatoryCombinedTextFrontend(language="de")
    print(tfr_de.string_to_tensor("Alles klar, jetzt testen wir einen deutschen Satz. Ich hoffe es gibt nicht mehr viele unspezifizierte Phoneme.", view=True))
Preprocessing/AudioPreprocessor.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import librosa
|
2 |
+
import librosa.core as lb
|
3 |
+
import librosa.display as lbd
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy
|
6 |
+
import numpy as np
|
7 |
+
import pyloudnorm as pyln
|
8 |
+
import torch
|
9 |
+
from torchaudio.transforms import Resample
|
10 |
+
|
11 |
+
|
12 |
+
class AudioPreprocessor:
|
13 |
+
|
14 |
+
def __init__(self, input_sr, output_sr=None, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=False, device="cpu"):
|
15 |
+
"""
|
16 |
+
The parameters are by default set up to do well
|
17 |
+
on a 16kHz signal. A different sampling rate may
|
18 |
+
require different hop_length and n_fft (e.g.
|
19 |
+
doubling frequency --> doubling hop_length and
|
20 |
+
doubling n_fft)
|
21 |
+
"""
|
22 |
+
self.cut_silence = cut_silence
|
23 |
+
self.device = device
|
24 |
+
self.sr = input_sr
|
25 |
+
self.new_sr = output_sr
|
26 |
+
self.hop_length = hop_length
|
27 |
+
self.n_fft = n_fft
|
28 |
+
self.mel_buckets = melspec_buckets
|
29 |
+
self.meter = pyln.Meter(input_sr)
|
30 |
+
self.final_sr = input_sr
|
31 |
+
if cut_silence:
|
32 |
+
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround
|
33 |
+
# careful: assumes 16kHz or 8kHz audio
|
34 |
+
self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
|
35 |
+
model='silero_vad',
|
36 |
+
force_reload=False,
|
37 |
+
onnx=False,
|
38 |
+
verbose=False)
|
39 |
+
(self.get_speech_timestamps,
|
40 |
+
self.save_audio,
|
41 |
+
self.read_audio,
|
42 |
+
self.VADIterator,
|
43 |
+
self.collect_chunks) = utils
|
44 |
+
self.silero_model = self.silero_model.to(self.device)
|
45 |
+
if output_sr is not None and output_sr != input_sr:
|
46 |
+
self.resample = Resample(orig_freq=input_sr, new_freq=output_sr).to(self.device)
|
47 |
+
self.final_sr = output_sr
|
48 |
+
else:
|
49 |
+
self.resample = lambda x: x
|
50 |
+
|
51 |
+
def cut_silence_from_audio(self, audio):
|
52 |
+
"""
|
53 |
+
https://github.com/snakers4/silero-vad
|
54 |
+
"""
|
55 |
+
return self.collect_chunks(self.get_speech_timestamps(audio, self.silero_model, sampling_rate=self.final_sr), audio)
|
56 |
+
|
57 |
+
def to_mono(self, x):
|
58 |
+
"""
|
59 |
+
make sure we deal with a 1D array
|
60 |
+
"""
|
61 |
+
if len(x.shape) == 2:
|
62 |
+
return lb.to_mono(numpy.transpose(x))
|
63 |
+
else:
|
64 |
+
return x
|
65 |
+
|
66 |
+
def normalize_loudness(self, audio):
|
67 |
+
"""
|
68 |
+
normalize the amplitudes according to
|
69 |
+
their decibels, so this should turn any
|
70 |
+
signal with different magnitudes into
|
71 |
+
the same magnitude by analysing loudness
|
72 |
+
"""
|
73 |
+
loudness = self.meter.integrated_loudness(audio)
|
74 |
+
loud_normed = pyln.normalize.loudness(audio, loudness, -30.0)
|
75 |
+
peak = numpy.amax(numpy.abs(loud_normed))
|
76 |
+
peak_normed = numpy.divide(loud_normed, peak)
|
77 |
+
return peak_normed
|
78 |
+
|
79 |
+
def logmelfilterbank(self, audio, sampling_rate, fmin=40, fmax=8000, eps=1e-10):
|
80 |
+
"""
|
81 |
+
Compute log-Mel filterbank
|
82 |
+
|
83 |
+
one day this could be replaced by torchaudio's internal log10(melspec(audio)), but
|
84 |
+
for some reason it gives slightly different results, so in order not to break backwards
|
85 |
+
compatibility, this is kept for now. If there is ever a reason to completely re-train
|
86 |
+
all models, this would be a good opportunity to make the switch.
|
87 |
+
"""
|
88 |
+
if isinstance(audio, torch.Tensor):
|
89 |
+
audio = audio.numpy()
|
90 |
+
# get amplitude spectrogram
|
91 |
+
x_stft = librosa.stft(audio, n_fft=self.n_fft, hop_length=self.hop_length, win_length=None, window="hann", pad_mode="reflect")
|
92 |
+
spc = np.abs(x_stft).T
|
93 |
+
# get mel basis
|
94 |
+
fmin = 0 if fmin is None else fmin
|
95 |
+
fmax = sampling_rate / 2 if fmax is None else fmax
|
96 |
+
mel_basis = librosa.filters.mel(sampling_rate, self.n_fft, self.mel_buckets, fmin, fmax)
|
97 |
+
# apply log and return
|
98 |
+
return torch.Tensor(np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))).transpose(0, 1)
|
99 |
+
|
100 |
+
def normalize_audio(self, audio):
|
101 |
+
"""
|
102 |
+
one function to apply them all in an
|
103 |
+
order that makes sense.
|
104 |
+
"""
|
105 |
+
audio = self.to_mono(audio)
|
106 |
+
audio = self.normalize_loudness(audio)
|
107 |
+
audio = torch.Tensor(audio).to(self.device)
|
108 |
+
audio = self.resample(audio)
|
109 |
+
if self.cut_silence:
|
110 |
+
audio = self.cut_silence_from_audio(audio)
|
111 |
+
return audio.to("cpu")
|
112 |
+
|
113 |
+
def visualize_cleaning(self, unclean_audio):
|
114 |
+
"""
|
115 |
+
displays Mel Spectrogram of unclean audio
|
116 |
+
and then displays Mel Spectrogram of the
|
117 |
+
cleaned version.
|
118 |
+
"""
|
119 |
+
fig, ax = plt.subplots(nrows=2, ncols=1)
|
120 |
+
unclean_audio_mono = self.to_mono(unclean_audio)
|
121 |
+
unclean_spec = self.audio_to_mel_spec_tensor(unclean_audio_mono, normalize=False).numpy()
|
122 |
+
clean_spec = self.audio_to_mel_spec_tensor(unclean_audio_mono, normalize=True).numpy()
|
123 |
+
lbd.specshow(unclean_spec, sr=self.sr, cmap='GnBu', y_axis='mel', ax=ax[0], x_axis='time')
|
124 |
+
ax[0].set(title='Uncleaned Audio')
|
125 |
+
ax[0].label_outer()
|
126 |
+
if self.new_sr is not None:
|
127 |
+
lbd.specshow(clean_spec, sr=self.new_sr, cmap='GnBu', y_axis='mel', ax=ax[1], x_axis='time')
|
128 |
+
else:
|
129 |
+
lbd.specshow(clean_spec, sr=self.sr, cmap='GnBu', y_axis='mel', ax=ax[1], x_axis='time')
|
130 |
+
ax[1].set(title='Cleaned Audio')
|
131 |
+
ax[1].label_outer()
|
132 |
+
plt.show()
|
133 |
+
|
134 |
+
def audio_to_wave_tensor(self, audio, normalize=True):
|
135 |
+
if normalize:
|
136 |
+
return self.normalize_audio(audio)
|
137 |
+
else:
|
138 |
+
if isinstance(audio, torch.Tensor):
|
139 |
+
return audio
|
140 |
+
else:
|
141 |
+
return torch.Tensor(audio)
|
142 |
+
|
143 |
+
def audio_to_mel_spec_tensor(self, audio, normalize=True, explicit_sampling_rate=None):
|
144 |
+
"""
|
145 |
+
explicit_sampling_rate is for when
|
146 |
+
normalization has already been applied
|
147 |
+
elsewhere and included resampling, since
|
148 |
+
the current sampling rate of the incoming
|
149 |
+
audio cannot be detected automatically
|
150 |
+
"""
|
151 |
+
if explicit_sampling_rate is None:
|
152 |
+
if normalize:
|
153 |
+
audio = self.normalize_audio(audio)
|
154 |
+
return self.logmelfilterbank(audio=audio, sampling_rate=self.final_sr)
|
155 |
+
return self.logmelfilterbank(audio=audio, sampling_rate=self.sr)
|
156 |
+
if normalize:
|
157 |
+
audio = self.normalize_audio(audio)
|
158 |
+
return self.logmelfilterbank(audio=audio, sampling_rate=explicit_sampling_rate)
|
159 |
+
|
160 |
+
|
161 |
+
if __name__ == '__main__':
|
162 |
+
import soundfile
|
163 |
+
|
164 |
+
wav, sr = soundfile.read("../audios/test.wav")
|
165 |
+
ap = AudioPreprocessor(input_sr=sr, output_sr=16000)
|
166 |
+
ap.visualize_cleaning(wav)
|
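Since the preprocessor above chains several steps (mono conversion, loudness normalization, resampling, optional silence cutting, log-Mel extraction), a minimal usage sketch may help; the file path is a placeholder and the keyword values simply mirror the defaults used elsewhere in this commit.

import soundfile

from Preprocessing.AudioPreprocessor import AudioPreprocessor

wave, sr = soundfile.read("audios/test.wav")  # placeholder path
ap = AudioPreprocessor(input_sr=sr, output_sr=16000, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=False)

# normalization (mono, loudness, resampling) happens inside this call
mel = ap.audio_to_mel_spec_tensor(wave, normalize=True)  # shape: (80, frames)

# or normalize once and reuse the 16 kHz wave, passing the sampling rate explicitly
norm_wave = ap.audio_to_wave_tensor(wave, normalize=True)
mel_again = ap.audio_to_mel_spec_tensor(norm_wave, normalize=False, explicit_sampling_rate=16000)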
Preprocessing/ProsodicConditionExtractor.py
ADDED
@@ -0,0 +1,40 @@
1 |
+
import soundfile as sf
|
2 |
+
import torch
|
3 |
+
import torch.multiprocessing
|
4 |
+
import torch.multiprocessing
|
5 |
+
from numpy import trim_zeros
|
6 |
+
from speechbrain.pretrained import EncoderClassifier
|
7 |
+
|
8 |
+
from Preprocessing.AudioPreprocessor import AudioPreprocessor
|
9 |
+
|
10 |
+
|
11 |
+
class ProsodicConditionExtractor:
|
12 |
+
|
13 |
+
def __init__(self, sr, device=torch.device("cpu")):
|
14 |
+
self.ap = AudioPreprocessor(input_sr=sr, output_sr=16000, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=False)
|
15 |
+
# https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb
|
16 |
+
self.speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
|
17 |
+
run_opts={"device": str(device)},
|
18 |
+
savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_ecapa")
|
19 |
+
# https://huggingface.co/speechbrain/spkrec-xvect-voxceleb
|
20 |
+
self.speaker_embedding_func_xvector = EncoderClassifier.from_hparams(source="speechbrain/spkrec-xvect-voxceleb",
|
21 |
+
run_opts={"device": str(device)},
|
22 |
+
savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_xvector")
|
23 |
+
|
24 |
+
def extract_condition_from_reference_wave(self, wave, already_normalized=False):
|
25 |
+
if already_normalized:
|
26 |
+
norm_wave = wave
|
27 |
+
else:
|
28 |
+
norm_wave = self.ap.audio_to_wave_tensor(normalize=True, audio=wave)
|
29 |
+
norm_wave = torch.tensor(trim_zeros(norm_wave.numpy()))
|
30 |
+
spk_emb_ecapa = self.speaker_embedding_func_ecapa.encode_batch(wavs=norm_wave.unsqueeze(0)).squeeze()
|
31 |
+
spk_emb_xvector = self.speaker_embedding_func_xvector.encode_batch(wavs=norm_wave.unsqueeze(0)).squeeze()
|
32 |
+
combined_utt_condition = torch.cat([spk_emb_ecapa.cpu(),
|
33 |
+
spk_emb_xvector.cpu()], dim=0)
|
34 |
+
return combined_utt_condition
|
35 |
+
|
36 |
+
|
37 |
+
if __name__ == '__main__':
|
38 |
+
wave, sr = sf.read("../audios/1.wav")
|
39 |
+
ext = ProsodicConditionExtractor(sr=sr)
|
40 |
+
print(ext.extract_condition_from_reference_wave(wave=wave).shape)
|
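A rough sketch of what the extractor yields, assuming the usual sizes of the two speechbrain embeddings (192 for ECAPA, 512 for the x-vector; the 512 figure is not stated in this commit and is only an assumption here):

import soundfile as sf

from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor

wave, sr = sf.read("audios/reference.wav")  # any reference recording of the target voice
ext = ProsodicConditionExtractor(sr=sr)
condition = ext.extract_condition_from_reference_wave(wave)
# the ECAPA and x-vector embeddings are concatenated, so this is expected to be
# a flat vector of roughly 192 + 512 = 704 entries (assumed sizes, see above)
print(condition.shape)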
Preprocessing/__init__.py
ADDED
File without changes
|
Preprocessing/papercup_features.py
ADDED
@@ -0,0 +1,637 @@
1 |
+
# Derived from an open-source resource provided by Papercup Technologies Limited
|
2 |
+
# Resource-Author: Marlene Staib
|
3 |
+
# Modified by Florian Lux, 2021
|
4 |
+
|
5 |
+
def generate_feature_lookup():
|
6 |
+
return {
|
7 |
+
'~': {'symbol_type': 'silence'},
|
8 |
+
'#': {'symbol_type': 'end of sentence'},
|
9 |
+
'?': {'symbol_type': 'questionmark'},
|
10 |
+
'!': {'symbol_type': 'exclamationmark'},
|
11 |
+
'.': {'symbol_type': 'fullstop'},
|
12 |
+
'ɜ': {
|
13 |
+
'symbol_type' : 'phoneme',
|
14 |
+
'vowel_consonant' : 'vowel',
|
15 |
+
'VUV' : 'voiced',
|
16 |
+
'vowel_frontness' : 'central',
|
17 |
+
'vowel_openness' : 'open-mid',
|
18 |
+
'vowel_roundedness': 'unrounded',
|
19 |
+
},
|
20 |
+
'ɫ': {
|
21 |
+
'symbol_type' : 'phoneme',
|
22 |
+
'vowel_consonant' : 'consonant',
|
23 |
+
'VUV' : 'voiced',
|
24 |
+
'consonant_place' : 'alveolar',
|
25 |
+
'consonant_manner': 'lateral-approximant',
|
26 |
+
},
|
27 |
+
'ə': {
|
28 |
+
'symbol_type' : 'phoneme',
|
29 |
+
'vowel_consonant' : 'vowel',
|
30 |
+
'VUV' : 'voiced',
|
31 |
+
'vowel_frontness' : 'central',
|
32 |
+
'vowel_openness' : 'mid',
|
33 |
+
'vowel_roundedness': 'unrounded',
|
34 |
+
},
|
35 |
+
'ɚ': {
|
36 |
+
'symbol_type' : 'phoneme',
|
37 |
+
'vowel_consonant' : 'vowel',
|
38 |
+
'VUV' : 'voiced',
|
39 |
+
'vowel_frontness' : 'central',
|
40 |
+
'vowel_openness' : 'mid',
|
41 |
+
'vowel_roundedness': 'unrounded',
|
42 |
+
},
|
43 |
+
'a': {
|
44 |
+
'symbol_type' : 'phoneme',
|
45 |
+
'vowel_consonant' : 'vowel',
|
46 |
+
'VUV' : 'voiced',
|
47 |
+
'vowel_frontness' : 'front',
|
48 |
+
'vowel_openness' : 'open',
|
49 |
+
'vowel_roundedness': 'unrounded',
|
50 |
+
},
|
51 |
+
'ð': {
|
52 |
+
'symbol_type' : 'phoneme',
|
53 |
+
'vowel_consonant' : 'consonant',
|
54 |
+
'VUV' : 'voiced',
|
55 |
+
'consonant_place' : 'dental',
|
56 |
+
'consonant_manner': 'fricative'
|
57 |
+
},
|
58 |
+
'ɛ': {
|
59 |
+
'symbol_type' : 'phoneme',
|
60 |
+
'vowel_consonant' : 'vowel',
|
61 |
+
'VUV' : 'voiced',
|
62 |
+
'vowel_frontness' : 'front',
|
63 |
+
'vowel_openness' : 'open-mid',
|
64 |
+
'vowel_roundedness': 'unrounded',
|
65 |
+
},
|
66 |
+
'ɪ': {
|
67 |
+
'symbol_type' : 'phoneme',
|
68 |
+
'vowel_consonant' : 'vowel',
|
69 |
+
'VUV' : 'voiced',
|
70 |
+
'vowel_frontness' : 'front_central',
|
71 |
+
'vowel_openness' : 'close_close-mid',
|
72 |
+
'vowel_roundedness': 'unrounded',
|
73 |
+
},
|
74 |
+
'ᵻ': {
|
75 |
+
'symbol_type' : 'phoneme',
|
76 |
+
'vowel_consonant' : 'vowel',
|
77 |
+
'VUV' : 'voiced',
|
78 |
+
'vowel_frontness' : 'central',
|
79 |
+
'vowel_openness' : 'close',
|
80 |
+
'vowel_roundedness': 'unrounded',
|
81 |
+
},
|
82 |
+
'ŋ': {
|
83 |
+
'symbol_type' : 'phoneme',
|
84 |
+
'vowel_consonant' : 'consonant',
|
85 |
+
'VUV' : 'voiced',
|
86 |
+
'consonant_place' : 'velar',
|
87 |
+
'consonant_manner': 'nasal'
|
88 |
+
},
|
89 |
+
'ɔ': {
|
90 |
+
'symbol_type' : 'phoneme',
|
91 |
+
'vowel_consonant' : 'vowel',
|
92 |
+
'VUV' : 'voiced',
|
93 |
+
'vowel_frontness' : 'back',
|
94 |
+
'vowel_openness' : 'open-mid',
|
95 |
+
'vowel_roundedness': 'rounded',
|
96 |
+
},
|
97 |
+
'ɒ': {
|
98 |
+
'symbol_type' : 'phoneme',
|
99 |
+
'vowel_consonant' : 'vowel',
|
100 |
+
'VUV' : 'voiced',
|
101 |
+
'vowel_frontness' : 'back',
|
102 |
+
'vowel_openness' : 'open',
|
103 |
+
'vowel_roundedness': 'rounded',
|
104 |
+
},
|
105 |
+
'ɾ': {
|
106 |
+
'symbol_type' : 'phoneme',
|
107 |
+
'vowel_consonant' : 'consonant',
|
108 |
+
'VUV' : 'voiced',
|
109 |
+
'consonant_place' : 'alveolar',
|
110 |
+
'consonant_manner': 'tap'
|
111 |
+
},
|
112 |
+
'ʃ': {
|
113 |
+
'symbol_type' : 'phoneme',
|
114 |
+
'vowel_consonant' : 'consonant',
|
115 |
+
'VUV' : 'unvoiced',
|
116 |
+
'consonant_place' : 'postalveolar',
|
117 |
+
'consonant_manner': 'fricative'
|
118 |
+
},
|
119 |
+
'θ': {
|
120 |
+
'symbol_type' : 'phoneme',
|
121 |
+
'vowel_consonant' : 'consonant',
|
122 |
+
'VUV' : 'unvoiced',
|
123 |
+
'consonant_place' : 'dental',
|
124 |
+
'consonant_manner': 'fricative'
|
125 |
+
},
|
126 |
+
'ʊ': {
|
127 |
+
'symbol_type' : 'phoneme',
|
128 |
+
'vowel_consonant' : 'vowel',
|
129 |
+
'VUV' : 'voiced',
|
130 |
+
'vowel_frontness' : 'central_back',
|
131 |
+
'vowel_openness' : 'close_close-mid',
|
132 |
+
'vowel_roundedness': 'unrounded'
|
133 |
+
},
|
134 |
+
'ʌ': {
|
135 |
+
'symbol_type' : 'phoneme',
|
136 |
+
'vowel_consonant' : 'vowel',
|
137 |
+
'VUV' : 'voiced',
|
138 |
+
'vowel_frontness' : 'back',
|
139 |
+
'vowel_openness' : 'open-mid',
|
140 |
+
'vowel_roundedness': 'unrounded'
|
141 |
+
},
|
142 |
+
'ʒ': {
|
143 |
+
'symbol_type' : 'phoneme',
|
144 |
+
'vowel_consonant' : 'consonant',
|
145 |
+
'VUV' : 'voiced',
|
146 |
+
'consonant_place' : 'postalveolar',
|
147 |
+
'consonant_manner': 'fricative'
|
148 |
+
},
|
149 |
+
'æ': {
|
150 |
+
'symbol_type' : 'phoneme',
|
151 |
+
'vowel_consonant' : 'vowel',
|
152 |
+
'VUV' : 'voiced',
|
153 |
+
'vowel_frontness' : 'front',
|
154 |
+
'vowel_openness' : 'open-mid_open',
|
155 |
+
'vowel_roundedness': 'unrounded'
|
156 |
+
},
|
157 |
+
'b': {
|
158 |
+
'symbol_type' : 'phoneme',
|
159 |
+
'vowel_consonant' : 'consonant',
|
160 |
+
'VUV' : 'voiced',
|
161 |
+
'consonant_place' : 'bilabial',
|
162 |
+
'consonant_manner': 'stop'
|
163 |
+
},
|
164 |
+
'ʔ': {
|
165 |
+
'symbol_type' : 'phoneme',
|
166 |
+
'vowel_consonant' : 'consonant',
|
167 |
+
'VUV' : 'unvoiced',
|
168 |
+
'consonant_place' : 'glottal',
|
169 |
+
'consonant_manner': 'stop'
|
170 |
+
},
|
171 |
+
'd': {
|
172 |
+
'symbol_type' : 'phoneme',
|
173 |
+
'vowel_consonant' : 'consonant',
|
174 |
+
'VUV' : 'voiced',
|
175 |
+
'consonant_place' : 'alveolar',
|
176 |
+
'consonant_manner': 'stop'
|
177 |
+
},
|
178 |
+
'e': {
|
179 |
+
'symbol_type' : 'phoneme',
|
180 |
+
'vowel_consonant' : 'vowel',
|
181 |
+
'VUV' : 'voiced',
|
182 |
+
'vowel_frontness' : 'front',
|
183 |
+
'vowel_openness' : 'close-mid',
|
184 |
+
'vowel_roundedness': 'unrounded'
|
185 |
+
},
|
186 |
+
'f': {
|
187 |
+
'symbol_type' : 'phoneme',
|
188 |
+
'vowel_consonant' : 'consonant',
|
189 |
+
'VUV' : 'unvoiced',
|
190 |
+
'consonant_place' : 'labiodental',
|
191 |
+
'consonant_manner': 'fricative'
|
192 |
+
},
|
193 |
+
'g': {
|
194 |
+
'symbol_type' : 'phoneme',
|
195 |
+
'vowel_consonant' : 'consonant',
|
196 |
+
'VUV' : 'voiced',
|
197 |
+
'consonant_place' : 'velar',
|
198 |
+
'consonant_manner': 'stop'
|
199 |
+
},
|
200 |
+
'h': {
|
201 |
+
'symbol_type' : 'phoneme',
|
202 |
+
'vowel_consonant' : 'consonant',
|
203 |
+
'VUV' : 'unvoiced',
|
204 |
+
'consonant_place' : 'glottal',
|
205 |
+
'consonant_manner': 'fricative'
|
206 |
+
},
|
207 |
+
'i': {
|
208 |
+
'symbol_type' : 'phoneme',
|
209 |
+
'vowel_consonant' : 'vowel',
|
210 |
+
'VUV' : 'voiced',
|
211 |
+
'vowel_frontness' : 'front',
|
212 |
+
'vowel_openness' : 'close',
|
213 |
+
'vowel_roundedness': 'unrounded'
|
214 |
+
},
|
215 |
+
'j': {
|
216 |
+
'symbol_type' : 'phoneme',
|
217 |
+
'vowel_consonant' : 'consonant',
|
218 |
+
'VUV' : 'voiced',
|
219 |
+
'consonant_place' : 'palatal',
|
220 |
+
'consonant_manner': 'approximant'
|
221 |
+
},
|
222 |
+
'k': {
|
223 |
+
'symbol_type' : 'phoneme',
|
224 |
+
'vowel_consonant' : 'consonant',
|
225 |
+
'VUV' : 'unvoiced',
|
226 |
+
'consonant_place' : 'velar',
|
227 |
+
'consonant_manner': 'stop'
|
228 |
+
},
|
229 |
+
'l': {
|
230 |
+
'symbol_type' : 'phoneme',
|
231 |
+
'vowel_consonant' : 'consonant',
|
232 |
+
'VUV' : 'voiced',
|
233 |
+
'consonant_place' : 'alveolar',
|
234 |
+
'consonant_manner': 'lateral-approximant'
|
235 |
+
},
|
236 |
+
'm': {
|
237 |
+
'symbol_type' : 'phoneme',
|
238 |
+
'vowel_consonant' : 'consonant',
|
239 |
+
'VUV' : 'voiced',
|
240 |
+
'consonant_place' : 'bilabial',
|
241 |
+
'consonant_manner': 'nasal'
|
242 |
+
},
|
243 |
+
'n': {
|
244 |
+
'symbol_type' : 'phoneme',
|
245 |
+
'vowel_consonant' : 'consonant',
|
246 |
+
'VUV' : 'voiced',
|
247 |
+
'consonant_place' : 'alveolar',
|
248 |
+
'consonant_manner': 'nasal'
|
249 |
+
},
|
250 |
+
'ɳ': {
|
251 |
+
'symbol_type' : 'phoneme',
|
252 |
+
'vowel_consonant' : 'consonant',
|
253 |
+
'VUV' : 'voiced',
|
254 |
+
'consonant_place' : 'palatal',
|
255 |
+
'consonant_manner': 'nasal'
|
256 |
+
},
|
257 |
+
'o': {
|
258 |
+
'symbol_type' : 'phoneme',
|
259 |
+
'vowel_consonant' : 'vowel',
|
260 |
+
'VUV' : 'voiced',
|
261 |
+
'vowel_frontness' : 'back',
|
262 |
+
'vowel_openness' : 'close-mid',
|
263 |
+
'vowel_roundedness': 'rounded'
|
264 |
+
},
|
265 |
+
'p': {
|
266 |
+
'symbol_type' : 'phoneme',
|
267 |
+
'vowel_consonant' : 'consonant',
|
268 |
+
'VUV' : 'unvoiced',
|
269 |
+
'consonant_place' : 'bilabial',
|
270 |
+
'consonant_manner': 'stop'
|
271 |
+
},
|
272 |
+
'ɡ': {
|
273 |
+
'symbol_type' : 'phoneme',
|
274 |
+
'vowel_consonant' : 'consonant',
|
275 |
+
'VUV' : 'voiced',
|
276 |
+
'consonant_place' : 'velar',
|
277 |
+
'consonant_manner': 'stop'
|
278 |
+
},
|
279 |
+
'ɹ': {
|
280 |
+
'symbol_type' : 'phoneme',
|
281 |
+
'vowel_consonant' : 'consonant',
|
282 |
+
'VUV' : 'voiced',
|
283 |
+
'consonant_place' : 'alveolar',
|
284 |
+
'consonant_manner': 'approximant'
|
285 |
+
},
|
286 |
+
'r': {
|
287 |
+
'symbol_type' : 'phoneme',
|
288 |
+
'vowel_consonant' : 'consonant',
|
289 |
+
'VUV' : 'voiced',
|
290 |
+
'consonant_place' : 'alveolar',
|
291 |
+
'consonant_manner': 'trill'
|
292 |
+
},
|
293 |
+
's': {
|
294 |
+
'symbol_type' : 'phoneme',
|
295 |
+
'vowel_consonant' : 'consonant',
|
296 |
+
'VUV' : 'unvoiced',
|
297 |
+
'consonant_place' : 'alveolar',
|
298 |
+
'consonant_manner': 'fricative'
|
299 |
+
},
|
300 |
+
't': {
|
301 |
+
'symbol_type' : 'phoneme',
|
302 |
+
'vowel_consonant' : 'consonant',
|
303 |
+
'VUV' : 'unvoiced',
|
304 |
+
'consonant_place' : 'alveolar',
|
305 |
+
'consonant_manner': 'stop'
|
306 |
+
},
|
307 |
+
'u': {
|
308 |
+
'symbol_type' : 'phoneme',
|
309 |
+
'vowel_consonant' : 'vowel',
|
310 |
+
'VUV' : 'voiced',
|
311 |
+
'vowel_frontness' : 'back',
|
312 |
+
'vowel_openness' : 'close',
|
313 |
+
'vowel_roundedness': 'rounded',
|
314 |
+
},
|
315 |
+
'v': {
|
316 |
+
'symbol_type' : 'phoneme',
|
317 |
+
'vowel_consonant' : 'consonant',
|
318 |
+
'VUV' : 'voiced',
|
319 |
+
'consonant_place' : 'labiodental',
|
320 |
+
'consonant_manner': 'fricative'
|
321 |
+
},
|
322 |
+
'w': {
|
323 |
+
'symbol_type' : 'phoneme',
|
324 |
+
'vowel_consonant' : 'consonant',
|
325 |
+
'VUV' : 'voiced',
|
326 |
+
'consonant_place' : 'labial-velar',
|
327 |
+
'consonant_manner': 'approximant'
|
328 |
+
},
|
329 |
+
'x': {
|
330 |
+
'symbol_type' : 'phoneme',
|
331 |
+
'vowel_consonant' : 'consonant',
|
332 |
+
'VUV' : 'unvoiced',
|
333 |
+
'consonant_place' : 'velar',
|
334 |
+
'consonant_manner': 'fricative'
|
335 |
+
},
|
336 |
+
'z': {
|
337 |
+
'symbol_type' : 'phoneme',
|
338 |
+
'vowel_consonant' : 'consonant',
|
339 |
+
'VUV' : 'voiced',
|
340 |
+
'consonant_place' : 'alveolar',
|
341 |
+
'consonant_manner': 'fricative'
|
342 |
+
},
|
343 |
+
'ʀ': {
|
344 |
+
'symbol_type' : 'phoneme',
|
345 |
+
'vowel_consonant' : 'consonant',
|
346 |
+
'VUV' : 'voiced',
|
347 |
+
'consonant_place' : 'uvular',
|
348 |
+
'consonant_manner': 'trill'
|
349 |
+
},
|
350 |
+
'ø': {
|
351 |
+
'symbol_type' : 'phoneme',
|
352 |
+
'vowel_consonant' : 'vowel',
|
353 |
+
'VUV' : 'voiced',
|
354 |
+
'vowel_frontness' : 'front',
|
355 |
+
'vowel_openness' : 'close-mid',
|
356 |
+
'vowel_roundedness': 'rounded'
|
357 |
+
},
|
358 |
+
'ç': {
|
359 |
+
'symbol_type' : 'phoneme',
|
360 |
+
'vowel_consonant' : 'consonant',
|
361 |
+
'VUV' : 'unvoiced',
|
362 |
+
'consonant_place' : 'palatal',
|
363 |
+
'consonant_manner': 'fricative'
|
364 |
+
},
|
365 |
+
'ɐ': {
|
366 |
+
'symbol_type' : 'phoneme',
|
367 |
+
'vowel_consonant' : 'vowel',
|
368 |
+
'VUV' : 'voiced',
|
369 |
+
'vowel_frontness' : 'central',
|
370 |
+
'vowel_openness' : 'open',
|
371 |
+
'vowel_roundedness': 'unrounded'
|
372 |
+
},
|
373 |
+
'œ': {
|
374 |
+
'symbol_type' : 'phoneme',
|
375 |
+
'vowel_consonant' : 'vowel',
|
376 |
+
'VUV' : 'voiced',
|
377 |
+
'vowel_frontness' : 'front',
|
378 |
+
'vowel_openness' : 'open-mid',
|
379 |
+
'vowel_roundedness': 'rounded'
|
380 |
+
},
|
381 |
+
'y': {
|
382 |
+
'symbol_type' : 'phoneme',
|
383 |
+
'vowel_consonant' : 'vowel',
|
384 |
+
'VUV' : 'voiced',
|
385 |
+
'vowel_frontness' : 'front',
|
386 |
+
'vowel_openness' : 'close',
|
387 |
+
'vowel_roundedness': 'rounded'
|
388 |
+
},
|
389 |
+
'ʏ': {
|
390 |
+
'symbol_type' : 'phoneme',
|
391 |
+
'vowel_consonant' : 'vowel',
|
392 |
+
'VUV' : 'voiced',
|
393 |
+
'vowel_frontness' : 'front_central',
|
394 |
+
'vowel_openness' : 'close_close-mid',
|
395 |
+
'vowel_roundedness': 'rounded'
|
396 |
+
},
|
397 |
+
'ɑ': {
|
398 |
+
'symbol_type' : 'phoneme',
|
399 |
+
'vowel_consonant' : 'vowel',
|
400 |
+
'VUV' : 'voiced',
|
401 |
+
'vowel_frontness' : 'back',
|
402 |
+
'vowel_openness' : 'open',
|
403 |
+
'vowel_roundedness': 'unrounded'
|
404 |
+
},
|
405 |
+
'c': {
|
406 |
+
'symbol_type' : 'phoneme',
|
407 |
+
'vowel_consonant' : 'consonant',
|
408 |
+
'VUV' : 'unvoiced',
|
409 |
+
'consonant_place' : 'palatal',
|
410 |
+
'consonant_manner': 'stop'
|
411 |
+
},
|
412 |
+
'ɲ': {
|
413 |
+
'symbol_type' : 'phoneme',
|
414 |
+
'vowel_consonant' : 'consonant',
|
415 |
+
'VUV' : 'voiced',
|
416 |
+
'consonant_place' : 'palatal',
|
417 |
+
'consonant_manner': 'nasal'
|
418 |
+
},
|
419 |
+
'ɣ': {
|
420 |
+
'symbol_type' : 'phoneme',
|
421 |
+
'vowel_consonant' : 'consonant',
|
422 |
+
'VUV' : 'voiced',
|
423 |
+
'consonant_place' : 'velar',
|
424 |
+
'consonant_manner': 'fricative'
|
425 |
+
},
|
426 |
+
'ʎ': {
|
427 |
+
'symbol_type' : 'phoneme',
|
428 |
+
'vowel_consonant' : 'consonant',
|
429 |
+
'VUV' : 'voiced',
|
430 |
+
'consonant_place' : 'palatal',
|
431 |
+
'consonant_manner': 'lateral-approximant'
|
432 |
+
},
|
433 |
+
'β': {
|
434 |
+
'symbol_type' : 'phoneme',
|
435 |
+
'vowel_consonant' : 'consonant',
|
436 |
+
'VUV' : 'voiced',
|
437 |
+
'consonant_place' : 'bilabial',
|
438 |
+
'consonant_manner': 'fricative'
|
439 |
+
},
|
440 |
+
'ʝ': {
|
441 |
+
'symbol_type' : 'phoneme',
|
442 |
+
'vowel_consonant' : 'consonant',
|
443 |
+
'VUV' : 'voiced',
|
444 |
+
'consonant_place' : 'palatal',
|
445 |
+
'consonant_manner': 'fricative'
|
446 |
+
},
|
447 |
+
'ɟ': {
|
448 |
+
'symbol_type' : 'phoneme',
|
449 |
+
'vowel_consonant' : 'consonant',
|
450 |
+
'VUV' : 'voiced',
|
451 |
+
'consonant_place' : 'palatal',
|
452 |
+
'consonant_manner': 'stop'
|
453 |
+
},
|
454 |
+
'q': {
|
455 |
+
'symbol_type' : 'phoneme',
|
456 |
+
'vowel_consonant' : 'consonant',
|
457 |
+
'VUV' : 'unvoiced',
|
458 |
+
'consonant_place' : 'uvular',
|
459 |
+
'consonant_manner': 'stop'
|
460 |
+
},
|
461 |
+
'ɕ': {
|
462 |
+
'symbol_type' : 'phoneme',
|
463 |
+
'vowel_consonant' : 'consonant',
|
464 |
+
'VUV' : 'unvoiced',
|
465 |
+
'consonant_place' : 'alveolopalatal',
|
466 |
+
'consonant_manner': 'fricative'
|
467 |
+
},
|
468 |
+
'ʲ': {
|
469 |
+
'symbol_type' : 'phoneme',
|
470 |
+
'vowel_consonant' : 'consonant',
|
471 |
+
'VUV' : 'voiced',
|
472 |
+
'consonant_place' : 'palatal',
|
473 |
+
'consonant_manner': 'approximant'
|
474 |
+
},
|
475 |
+
'ɭ': {
|
476 |
+
'symbol_type' : 'phoneme',
|
477 |
+
'vowel_consonant' : 'consonant',
|
478 |
+
'VUV' : 'voiced',
|
479 |
+
'consonant_place' : 'palatal', # should be retroflex, but palatal should be close enough
|
480 |
+
'consonant_manner': 'lateral-approximant'
|
481 |
+
},
|
482 |
+
'ɵ': {
|
483 |
+
'symbol_type' : 'phoneme',
|
484 |
+
'vowel_consonant' : 'vowel',
|
485 |
+
'VUV' : 'voiced',
|
486 |
+
'vowel_frontness' : 'central',
|
487 |
+
'vowel_openness' : 'open-mid',
|
488 |
+
'vowel_roundedness': 'rounded'
|
489 |
+
},
|
490 |
+
'ʑ': {
|
491 |
+
'symbol_type' : 'phoneme',
|
492 |
+
'vowel_consonant' : 'consonant',
|
493 |
+
'VUV' : 'voiced',
|
494 |
+
'consonant_place' : 'alveolopalatal',
|
495 |
+
'consonant_manner': 'fricative'
|
496 |
+
},
|
497 |
+
'ʋ': {
|
498 |
+
'symbol_type' : 'phoneme',
|
499 |
+
'vowel_consonant' : 'consonant',
|
500 |
+
'VUV' : 'voiced',
|
501 |
+
'consonant_place' : 'labiodental',
|
502 |
+
'consonant_manner': 'approximant'
|
503 |
+
},
|
504 |
+
'ʁ': {
|
505 |
+
'symbol_type' : 'phoneme',
|
506 |
+
'vowel_consonant' : 'consonant',
|
507 |
+
'VUV' : 'voiced',
|
508 |
+
'consonant_place' : 'uvular',
|
509 |
+
'consonant_manner': 'fricative'
|
510 |
+
},
|
511 |
+
'ɨ': {
|
512 |
+
'symbol_type' : 'phoneme',
|
513 |
+
'vowel_consonant' : 'vowel',
|
514 |
+
'VUV' : 'voiced',
|
515 |
+
'vowel_frontness' : 'central',
|
516 |
+
'vowel_openness' : 'close',
|
517 |
+
'vowel_roundedness': 'unrounded'
|
518 |
+
},
|
519 |
+
'ʂ': {
|
520 |
+
'symbol_type' : 'phoneme',
|
521 |
+
'vowel_consonant' : 'consonant',
|
522 |
+
'VUV' : 'unvoiced',
|
523 |
+
'consonant_place' : 'palatal', # should be retroflex, but palatal should be close enough
|
524 |
+
'consonant_manner': 'fricative'
|
525 |
+
},
|
526 |
+
'ɬ': {
|
527 |
+
'symbol_type' : 'phoneme',
|
528 |
+
'vowel_consonant' : 'consonant',
|
529 |
+
'VUV' : 'unvoiced',
|
530 |
+
'consonant_place' : 'alveolar', # should be noted it's also lateral, but should be close enough
|
531 |
+
'consonant_manner': 'fricative'
|
532 |
+
},
|
533 |
+
} # REMEMBER to also add the phonemes added here to the ID lookup table in the TextFrontend as the new highest ID
|
534 |
+
|
535 |
+
|
536 |
+
def generate_feature_table():
|
537 |
+
ipa_to_phonemefeats = generate_feature_lookup()
|
538 |
+
|
539 |
+
feat_types = set()
|
540 |
+
for ipa in ipa_to_phonemefeats:
|
541 |
+
if len(ipa) == 1:
|
542 |
+
[feat_types.add(feat) for feat in ipa_to_phonemefeats[ipa].keys()]
|
543 |
+
|
544 |
+
feat_to_val_set = dict()
|
545 |
+
for feat in feat_types:
|
546 |
+
feat_to_val_set[feat] = set()
|
547 |
+
for ipa in ipa_to_phonemefeats:
|
548 |
+
if len(ipa) == 1:
|
549 |
+
for feat in ipa_to_phonemefeats[ipa]:
|
550 |
+
feat_to_val_set[feat].add(ipa_to_phonemefeats[ipa][feat])
|
551 |
+
|
552 |
+
# print(feat_to_val_set)
|
553 |
+
|
554 |
+
value_list = set()
|
555 |
+
for val_set in [feat_to_val_set[feat] for feat in feat_to_val_set]:
|
556 |
+
for value in val_set:
|
557 |
+
value_list.add(value)
|
558 |
+
# print("{")
|
559 |
+
# for index, value in enumerate(list(value_list)):
|
560 |
+
# print('"{}":{},'.format(value,index))
|
561 |
+
# print("}")
|
562 |
+
|
563 |
+
value_to_index = {
|
564 |
+
"dental" : 0,
|
565 |
+
"postalveolar" : 1,
|
566 |
+
"mid" : 2,
|
567 |
+
"close-mid" : 3,
|
568 |
+
"vowel" : 4,
|
569 |
+
"silence" : 5,
|
570 |
+
"consonant" : 6,
|
571 |
+
"close" : 7,
|
572 |
+
"velar" : 8,
|
573 |
+
"stop" : 9,
|
574 |
+
"palatal" : 10,
|
575 |
+
"nasal" : 11,
|
576 |
+
"glottal" : 12,
|
577 |
+
"central" : 13,
|
578 |
+
"back" : 14,
|
579 |
+
"approximant" : 15,
|
580 |
+
"uvular" : 16,
|
581 |
+
"open-mid" : 17,
|
582 |
+
"front_central" : 18,
|
583 |
+
"front" : 19,
|
584 |
+
"end of sentence" : 20,
|
585 |
+
"labiodental" : 21,
|
586 |
+
"close_close-mid" : 22,
|
587 |
+
"labial-velar" : 23,
|
588 |
+
"unvoiced" : 24,
|
589 |
+
"central_back" : 25,
|
590 |
+
"trill" : 26,
|
591 |
+
"rounded" : 27,
|
592 |
+
"open-mid_open" : 28,
|
593 |
+
"tap" : 29,
|
594 |
+
"alveolar" : 30,
|
595 |
+
"bilabial" : 31,
|
596 |
+
"phoneme" : 32,
|
597 |
+
"open" : 33,
|
598 |
+
"fricative" : 34,
|
599 |
+
"unrounded" : 35,
|
600 |
+
"lateral-approximant": 36,
|
601 |
+
"voiced" : 37,
|
602 |
+
"questionmark" : 38,
|
603 |
+
"exclamationmark" : 39,
|
604 |
+
"fullstop" : 40,
|
605 |
+
"alveolopalatal" : 41
|
606 |
+
}
|
607 |
+
|
608 |
+
phone_to_vector = dict()
|
609 |
+
for ipa in ipa_to_phonemefeats:
|
610 |
+
if len(ipa) == 1:
|
611 |
+
phone_to_vector[ipa] = [0] * sum([len(values) for values in [feat_to_val_set[feat] for feat in feat_to_val_set]])
|
612 |
+
for feat in ipa_to_phonemefeats[ipa]:
|
613 |
+
if ipa_to_phonemefeats[ipa][feat] in value_to_index:
|
614 |
+
phone_to_vector[ipa][value_to_index[ipa_to_phonemefeats[ipa][feat]]] = 1
|
615 |
+
|
616 |
+
for feat in feat_to_val_set:
|
617 |
+
for value in feat_to_val_set[feat]:
|
618 |
+
if value not in value_to_index:
|
619 |
+
print(f"Unknown feature value in featureset! {value}")
|
620 |
+
|
621 |
+
# print(f"{sum([len(values) for values in [feat_to_val_set[feat] for feat in feat_to_val_set]])} should be 42")
|
622 |
+
|
623 |
+
return phone_to_vector
|
624 |
+
|
625 |
+
|
626 |
+
def generate_phone_to_id_lookup():
|
627 |
+
ipa_to_phonemefeats = generate_feature_lookup()
|
628 |
+
count = 0
|
629 |
+
phone_to_id = dict()
|
630 |
+
for key in sorted(list(ipa_to_phonemefeats)): # careful: non-deterministic
|
631 |
+
phone_to_id[key] = count
|
632 |
+
count += 1
|
633 |
+
return phone_to_id
|
634 |
+
|
635 |
+
|
636 |
+
if __name__ == '__main__':
|
637 |
+
print(generate_phone_to_id_lookup())
|
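The two lookup generators above can be queried directly; a small sketch using only names defined in this file:

from Preprocessing.papercup_features import generate_feature_table, generate_phone_to_id_lookup

phone_to_vector = generate_feature_table()
phone_to_id = generate_phone_to_id_lookup()

# each single-character IPA symbol maps to a binary vector with one slot per feature value
# (42 slots according to the sanity-check comment in generate_feature_table); for 'a' the
# slots for phoneme, vowel, voiced, front, open and unrounded are set
print(len(phone_to_vector["a"]), sum(phone_to_vector["a"]))  # 42 6
print(phone_to_id["a"])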
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
title: SpeechCloning
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
app_file: app.py
|
8 |
pinned: false
|
|
|
1 |
---
|
2 |
title: SpeechCloning
|
3 |
+
emoji: 🦜
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: red
|
6 |
sdk: gradio
|
7 |
app_file: app.py
|
8 |
pinned: false
|
TrainingInterfaces/Text_to_Spectrogram/AutoAligner/Aligner.py
ADDED
@@ -0,0 +1,287 @@
1 |
+
"""
|
2 |
+
taken and adapted from https://github.com/as-ideas/DeepForcedAligner
|
3 |
+
"""
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.multiprocessing
|
8 |
+
import torch.nn as nn
|
9 |
+
from scipy.sparse import coo_matrix
|
10 |
+
from scipy.sparse.csgraph import dijkstra
|
11 |
+
from torch.nn import CTCLoss
|
12 |
+
from torch.nn.utils.rnn import pack_padded_sequence
|
13 |
+
from torch.nn.utils.rnn import pad_packed_sequence
|
14 |
+
|
15 |
+
from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
|
16 |
+
|
17 |
+
|
18 |
+
class BatchNormConv(nn.Module):
|
19 |
+
|
20 |
+
def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
|
21 |
+
super().__init__()
|
22 |
+
self.conv = nn.Conv1d(
|
23 |
+
in_channels, out_channels, kernel_size,
|
24 |
+
stride=1, padding=kernel_size // 2, bias=False)
|
25 |
+
self.bnorm = nn.BatchNorm1d(out_channels)
|
26 |
+
self.relu = nn.ReLU()
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
x = x.transpose(1, 2)
|
30 |
+
x = self.conv(x)
|
31 |
+
x = self.relu(x)
|
32 |
+
x = self.bnorm(x)
|
33 |
+
x = x.transpose(1, 2)
|
34 |
+
return x
|
35 |
+
|
36 |
+
|
37 |
+
class Aligner(torch.nn.Module):
|
38 |
+
|
39 |
+
def __init__(self,
|
40 |
+
n_mels=80,
|
41 |
+
num_symbols=145,
|
42 |
+
lstm_dim=512,
|
43 |
+
conv_dim=512):
|
44 |
+
super().__init__()
|
45 |
+
self.convs = nn.ModuleList([
|
46 |
+
BatchNormConv(n_mels, conv_dim, 3),
|
47 |
+
nn.Dropout(p=0.5),
|
48 |
+
BatchNormConv(conv_dim, conv_dim, 3),
|
49 |
+
nn.Dropout(p=0.5),
|
50 |
+
BatchNormConv(conv_dim, conv_dim, 3),
|
51 |
+
nn.Dropout(p=0.5),
|
52 |
+
BatchNormConv(conv_dim, conv_dim, 3),
|
53 |
+
nn.Dropout(p=0.5),
|
54 |
+
BatchNormConv(conv_dim, conv_dim, 3),
|
55 |
+
nn.Dropout(p=0.5),
|
56 |
+
])
|
57 |
+
self.rnn = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True)
|
58 |
+
self.proj = torch.nn.Linear(2 * lstm_dim, num_symbols)
|
59 |
+
self.tf = ArticulatoryCombinedTextFrontend(language="en")
|
60 |
+
self.ctc_loss = CTCLoss(blank=144, zero_infinity=True)
|
61 |
+
self.vector_to_id = dict()
|
62 |
+
for phone in self.tf.phone_to_vector:
|
63 |
+
self.vector_to_id[tuple(self.tf.phone_to_vector[phone])] = self.tf.phone_to_id[phone]
|
64 |
+
|
65 |
+
def forward(self, x, lens=None):
|
66 |
+
for conv in self.convs:
|
67 |
+
x = conv(x)
|
68 |
+
|
69 |
+
if lens is not None:
|
70 |
+
x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
|
71 |
+
x, _ = self.rnn(x)
|
72 |
+
if lens is not None:
|
73 |
+
x, _ = pad_packed_sequence(x, batch_first=True)
|
74 |
+
|
75 |
+
x = self.proj(x)
|
76 |
+
|
77 |
+
return x
|
78 |
+
|
79 |
+
@torch.no_grad()
|
80 |
+
def label_speech(self, speech):
|
81 |
+
# theoretically possible, but doesn't work well at all. Would probably require a beamsearch
|
82 |
+
probabilities_of_phones_over_frames = self(speech.unsqueeze(0)).squeeze()[:, :73]
|
83 |
+
smoothed_phone_probs_over_frames = list()
|
84 |
+
for index, _ in enumerate(probabilities_of_phones_over_frames):
|
85 |
+
access_safe_prev_index = max(0, index - 1)
|
86 |
+
access_safe_next_index = min(index + 1, len(probabilities_of_phones_over_frames) - 1)
|
87 |
+
smoothed_probs = (probabilities_of_phones_over_frames[access_safe_prev_index] +
|
88 |
+
probabilities_of_phones_over_frames[access_safe_next_index] +
|
89 |
+
probabilities_of_phones_over_frames[index]) / 3
|
90 |
+
smoothed_phone_probs_over_frames.append(smoothed_probs.unsqueeze(0))
|
91 |
+
print(torch.cat(smoothed_phone_probs_over_frames))
|
92 |
+
_, phone_ids_over_frames = torch.max(torch.cat(smoothed_phone_probs_over_frames), dim=1)
|
93 |
+
phone_ids = torch.unique_consecutive(phone_ids_over_frames)
|
94 |
+
phones = list()
|
95 |
+
for id_of_phone in phone_ids:
|
96 |
+
phones.append(self.tf.id_to_phone[int(id_of_phone)])
|
97 |
+
return "".join(phones)
|
98 |
+
|
99 |
+
@torch.inference_mode()
|
100 |
+
def inference(self, mel, tokens, save_img_for_debug=None, train=False, pathfinding="MAS", return_ctc=False):
|
101 |
+
if not train:
|
102 |
+
tokens_indexed = list() # first we need to convert the articulatory vectors to IDs, so we can apply dijkstra or viterbi
|
103 |
+
for vector in tokens:
|
104 |
+
tokens_indexed.append(self.vector_to_id[tuple(vector.cpu().detach().numpy().tolist())])
|
105 |
+
tokens = np.asarray(tokens_indexed)
|
106 |
+
else:
|
107 |
+
tokens = tokens.cpu().detach().numpy()
|
108 |
+
|
109 |
+
pred = self(mel.unsqueeze(0))
|
110 |
+
if return_ctc:
|
111 |
+
ctc_loss = self.ctc_loss(pred.transpose(0, 1).log_softmax(2), torch.LongTensor(tokens), torch.LongTensor([len(pred[0])]),
|
112 |
+
torch.LongTensor([len(tokens)])).item()
|
113 |
+
pred = pred.squeeze().cpu().detach().numpy()
|
114 |
+
pred_max = pred[:, tokens]
|
115 |
+
path_probs = 1. - pred_max
|
116 |
+
adj_matrix = to_adj_matrix(path_probs)
|
117 |
+
|
118 |
+
if pathfinding == "MAS":
|
119 |
+
|
120 |
+
alignment_matrix = binarize_alignment(pred_max)
|
121 |
+
|
122 |
+
if save_img_for_debug is not None:
|
123 |
+
phones = list()
|
124 |
+
for index in tokens:
|
125 |
+
for phone in self.tf.phone_to_id:
|
126 |
+
if self.tf.phone_to_id[phone] == index:
|
127 |
+
phones.append(phone)
|
128 |
+
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 4))
|
129 |
+
|
130 |
+
ax.imshow(alignment_matrix, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis')
|
131 |
+
|
132 |
+
ax.set_ylabel("Mel-Frames")
|
133 |
+
|
134 |
+
ax.set_xticks(range(len(pred_max[0])))
|
135 |
+
ax.set_xticklabels(labels=phones)
|
136 |
+
|
137 |
+
ax.set_title("MAS Path")
|
138 |
+
|
139 |
+
plt.tight_layout()
|
140 |
+
fig.savefig(save_img_for_debug)
|
141 |
+
fig.clf()
|
142 |
+
plt.close()
|
143 |
+
|
144 |
+
if return_ctc:
|
145 |
+
return alignment_matrix, ctc_loss
|
146 |
+
return alignment_matrix
|
147 |
+
|
148 |
+
elif pathfinding == "dijkstra":
|
149 |
+
|
150 |
+
dist_matrix, predecessors, *_ = dijkstra(csgraph=adj_matrix,
|
151 |
+
directed=True,
|
152 |
+
indices=0,
|
153 |
+
return_predecessors=True)
|
154 |
+
path = []
|
155 |
+
pr_index = predecessors[-1]
|
156 |
+
while pr_index != 0:
|
157 |
+
path.append(pr_index)
|
158 |
+
pr_index = predecessors[pr_index]
|
159 |
+
path.reverse()
|
160 |
+
|
161 |
+
# append first and last node
|
162 |
+
path = [0] + path + [dist_matrix.size - 1]
|
163 |
+
cols = path_probs.shape[1]
|
164 |
+
mel_text = {}
|
165 |
+
|
166 |
+
# collect indices (mel, text) along the path
|
167 |
+
for node_index in path:
|
168 |
+
i, j = from_node_index(node_index, cols)
|
169 |
+
mel_text[i] = j
|
170 |
+
|
171 |
+
path_plot = np.zeros_like(pred_max)
|
172 |
+
for i in mel_text:
|
173 |
+
path_plot[i][mel_text[i]] = 1.0
|
174 |
+
|
175 |
+
if save_img_for_debug is not None:
|
176 |
+
|
177 |
+
phones = list()
|
178 |
+
for index in tokens:
|
179 |
+
for phone in self.tf.phone_to_id:
|
180 |
+
if self.tf.phone_to_id[phone] == index:
|
181 |
+
phones.append(phone)
|
182 |
+
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(10, 9))
|
183 |
+
|
184 |
+
ax[0].imshow(pred_max, interpolation='nearest', aspect='auto', origin="lower")
|
185 |
+
ax[1].imshow(path_plot, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis')
|
186 |
+
|
187 |
+
ax[0].set_ylabel("Mel-Frames")
|
188 |
+
ax[1].set_ylabel("Mel-Frames")
|
189 |
+
|
190 |
+
ax[0].set_xticks(range(len(pred_max[0])))
|
191 |
+
ax[0].set_xticklabels(labels=phones)
|
192 |
+
|
193 |
+
ax[1].set_xticks(range(len(pred_max[0])))
|
194 |
+
ax[1].set_xticklabels(labels=phones)
|
195 |
+
|
196 |
+
ax[0].set_title("Path Probabilities")
|
197 |
+
ax[1].set_title("Dijkstra Path")
|
198 |
+
|
199 |
+
plt.tight_layout()
|
200 |
+
fig.savefig(save_img_for_debug)
|
201 |
+
fig.clf()
|
202 |
+
plt.close()
|
203 |
+
|
204 |
+
if return_ctc:
|
205 |
+
return path_plot, ctc_loss
|
206 |
+
return path_plot
|
207 |
+
|
208 |
+
|
209 |
+
def binarize_alignment(alignment_prob):
|
210 |
+
"""
|
211 |
+
# Implementation by:
|
212 |
+
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/alignment.py
|
213 |
+
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/attn_loss_function.py
|
214 |
+
|
215 |
+
Binarizes alignment with MAS.
|
216 |
+
"""
|
217 |
+
# assumes mel x text
|
218 |
+
opt = np.zeros_like(alignment_prob)
|
219 |
+
alignment_prob = alignment_prob + (np.abs(alignment_prob).max() + 1.0) # make all numbers positive and add an offset to avoid log of 0 later
|
220 |
+
alignment_prob = alignment_prob * (1.0 / alignment_prob.max()) # normalize to (0, 1]
|
221 |
+
attn_map = np.log(alignment_prob)
|
222 |
+
attn_map[0, 1:] = -np.inf
|
223 |
+
log_p = np.zeros_like(attn_map)
|
224 |
+
log_p[0, :] = attn_map[0, :]
|
225 |
+
prev_ind = np.zeros_like(attn_map, dtype=np.int64)
|
226 |
+
for i in range(1, attn_map.shape[0]):
|
227 |
+
for j in range(attn_map.shape[1]): # for each text dim
|
228 |
+
prev_log = log_p[i - 1, j]
|
229 |
+
prev_j = j
|
230 |
+
if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
|
231 |
+
prev_log = log_p[i - 1, j - 1]
|
232 |
+
prev_j = j - 1
|
233 |
+
log_p[i, j] = attn_map[i, j] + prev_log
|
234 |
+
prev_ind[i, j] = prev_j
|
235 |
+
# now backtrack
|
236 |
+
curr_text_idx = attn_map.shape[1] - 1
|
237 |
+
for i in range(attn_map.shape[0] - 1, -1, -1):
|
238 |
+
opt[i, curr_text_idx] = 1
|
239 |
+
curr_text_idx = prev_ind[i, curr_text_idx]
|
240 |
+
opt[0, curr_text_idx] = 1
|
241 |
+
return opt
|
242 |
+
|
243 |
+
|
244 |
+
def to_node_index(i, j, cols):
|
245 |
+
return cols * i + j
|
246 |
+
|
247 |
+
|
248 |
+
def from_node_index(node_index, cols):
|
249 |
+
return node_index // cols, node_index % cols
|
250 |
+
|
251 |
+
|
252 |
+
def to_adj_matrix(mat):
|
253 |
+
rows = mat.shape[0]
|
254 |
+
cols = mat.shape[1]
|
255 |
+
|
256 |
+
row_ind = []
|
257 |
+
col_ind = []
|
258 |
+
data = []
|
259 |
+
|
260 |
+
for i in range(rows):
|
261 |
+
for j in range(cols):
|
262 |
+
|
263 |
+
node = to_node_index(i, j, cols)
|
264 |
+
|
265 |
+
if j < cols - 1:
|
266 |
+
right_node = to_node_index(i, j + 1, cols)
|
267 |
+
weight_right = mat[i, j + 1]
|
268 |
+
row_ind.append(node)
|
269 |
+
col_ind.append(right_node)
|
270 |
+
data.append(weight_right)
|
271 |
+
|
272 |
+
if i < rows - 1 and j < cols:
|
273 |
+
bottom_node = to_node_index(i + 1, j, cols)
|
274 |
+
weight_bottom = mat[i + 1, j]
|
275 |
+
row_ind.append(node)
|
276 |
+
col_ind.append(bottom_node)
|
277 |
+
data.append(weight_bottom)
|
278 |
+
|
279 |
+
if i < rows - 1 and j < cols - 1:
|
280 |
+
bottom_right_node = to_node_index(i + 1, j + 1, cols)
|
281 |
+
weight_bottom_right = mat[i + 1, j + 1]
|
282 |
+
row_ind.append(node)
|
283 |
+
col_ind.append(bottom_right_node)
|
284 |
+
data.append(weight_bottom_right)
|
285 |
+
|
286 |
+
adj_mat = coo_matrix((data, (row_ind, col_ind)), shape=(rows * cols, rows * cols))
|
287 |
+
return adj_mat.tocsr()
|
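The dijkstra branch of Aligner.inference is easiest to understand on a toy cost matrix; the sketch below exercises only the module-level helpers above, with made-up values (four mel frames, three tokens, low cost meaning a frame fits a token well).

import numpy as np
from scipy.sparse.csgraph import dijkstra

from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import from_node_index, to_adj_matrix

path_probs = np.array([[0.1, 0.9, 0.9],
                       [0.2, 0.1, 0.9],
                       [0.9, 0.2, 0.3],
                       [0.9, 0.9, 0.1]])
adj_matrix = to_adj_matrix(path_probs)
dist_matrix, predecessors = dijkstra(csgraph=adj_matrix, directed=True, indices=0, return_predecessors=True)

# walk the predecessors back from the last node to recover the monotonic frame-to-token path
path = []
node = adj_matrix.shape[0] - 1
while node != 0:
    path.append(from_node_index(node, path_probs.shape[1]))
    node = predecessors[node]
path.append((0, 0))
print(list(reversed(path)))  # pairs of (mel frame, token): (0,0), (1,1), (2,1), (3,2)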
TrainingInterfaces/Text_to_Spectrogram/AutoAligner/AlignerDataset.py
ADDED
@@ -0,0 +1,211 @@
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import warnings
|
4 |
+
|
5 |
+
import soundfile as sf
|
6 |
+
import torch
|
7 |
+
from numpy import trim_zeros
|
8 |
+
from speechbrain.pretrained import EncoderClassifier
|
9 |
+
from torch.multiprocessing import Manager
|
10 |
+
from torch.multiprocessing import Process
|
11 |
+
from torch.multiprocessing import set_start_method
|
12 |
+
from torch.utils.data import Dataset
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
|
16 |
+
from Preprocessing.AudioPreprocessor import AudioPreprocessor
|
17 |
+
|
18 |
+
|
19 |
+
class AlignerDataset(Dataset):
|
20 |
+
|
21 |
+
def __init__(self,
|
22 |
+
path_to_transcript_dict,
|
23 |
+
cache_dir,
|
24 |
+
lang,
|
25 |
+
loading_processes=30, # careful with the amount of processes if you use silence removal, only as many processes as you have cores
|
26 |
+
min_len_in_seconds=1,
|
27 |
+
max_len_in_seconds=20,
|
28 |
+
cut_silences=False,
|
29 |
+
rebuild_cache=False,
|
30 |
+
verbose=False,
|
31 |
+
device="cpu"):
|
32 |
+
os.makedirs(cache_dir, exist_ok=True)
|
33 |
+
if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
|
34 |
+
if (device == "cuda" or device == torch.device("cuda")) and cut_silences:
|
35 |
+
try:
|
36 |
+
set_start_method('spawn') # in order to be able to make use of cuda in multiprocessing
|
37 |
+
except RuntimeError:
|
38 |
+
pass
|
39 |
+
elif cut_silences:
|
40 |
+
torch.set_num_threads(1)
|
41 |
+
if cut_silences:
|
42 |
+
torch.hub.load(repo_or_dir='snakers4/silero-vad',
|
43 |
+
model='silero_vad',
|
44 |
+
force_reload=False,
|
45 |
+
onnx=False,
|
46 |
+
verbose=False) # download and cache for it to be loaded and used later
|
47 |
+
torch.set_grad_enabled(True)
|
48 |
+
resource_manager = Manager()
|
49 |
+
self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
|
50 |
+
key_list = list(self.path_to_transcript_dict.keys())
|
51 |
+
with open(os.path.join(cache_dir, "files_used.txt"), encoding='utf8', mode="w") as files_used_note:
|
52 |
+
files_used_note.write(str(key_list))
|
53 |
+
random.shuffle(key_list)
|
54 |
+
# build cache
|
55 |
+
print("... building dataset cache ...")
|
56 |
+
self.datapoints = resource_manager.list()
|
57 |
+
# make processes
|
58 |
+
key_splits = list()
|
59 |
+
process_list = list()
|
60 |
+
for i in range(loading_processes):
|
61 |
+
key_splits.append(key_list[i * len(key_list) // loading_processes:(i + 1) * len(key_list) // loading_processes])
|
62 |
+
for key_split in key_splits:
|
63 |
+
process_list.append(
|
64 |
+
Process(target=self.cache_builder_process,
|
65 |
+
args=(key_split,
|
66 |
+
lang,
|
67 |
+
min_len_in_seconds,
|
68 |
+
max_len_in_seconds,
|
69 |
+
cut_silences,
|
70 |
+
verbose,
|
71 |
+
device),
|
72 |
+
daemon=True))
|
73 |
+
process_list[-1].start()
|
74 |
+
for process in process_list:
|
75 |
+
process.join()
|
76 |
+
self.datapoints = list(self.datapoints)
|
77 |
+
tensored_datapoints = list()
|
78 |
+
# we had to turn all of the tensors to numpy arrays to avoid shared memory
|
79 |
+
# issues. Now that the multi-processing is over, we can convert them back
|
80 |
+
# to tensors to save on conversions in the future.
|
81 |
+
print("Converting into convenient format...")
|
82 |
+
norm_waves = list()
|
83 |
+
for datapoint in tqdm(self.datapoints):
|
84 |
+
tensored_datapoints.append([torch.Tensor(datapoint[0]),
|
85 |
+
torch.LongTensor(datapoint[1]),
|
86 |
+
torch.Tensor(datapoint[2]),
|
87 |
+
torch.LongTensor(datapoint[3])])
|
88 |
+
norm_waves.append(torch.Tensor(datapoint[-1]))
|
89 |
+
|
90 |
+
self.datapoints = tensored_datapoints
|
91 |
+
|
92 |
+
pop_indexes = list()
|
93 |
+
for index, el in enumerate(self.datapoints):
|
94 |
+
try:
|
95 |
+
if len(el[0][0]) != 66:
|
96 |
+
pop_indexes.append(index)
|
97 |
+
except TypeError:
|
98 |
+
pop_indexes.append(index)
|
99 |
+
for pop_index in sorted(pop_indexes, reverse=True):
|
100 |
+
print(f"There seems to be a problem in the transcriptions. Deleting datapoint {pop_index}.")
|
101 |
+
self.datapoints.pop(pop_index)
|
102 |
+
|
103 |
+
# add speaker embeddings
|
104 |
+
self.speaker_embeddings = list()
|
105 |
+
speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
|
106 |
+
run_opts={"device": str(device)},
|
107 |
+
savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_ecapa")
|
108 |
+
with torch.no_grad():
|
109 |
+
for wave in tqdm(norm_waves):
|
110 |
+
self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())
|
111 |
+
|
112 |
+
# save to cache
|
113 |
+
torch.save((self.datapoints, norm_waves, self.speaker_embeddings), os.path.join(cache_dir, "aligner_train_cache.pt"))
|
114 |
+
else:
|
115 |
+
# just load the datapoints from cache
|
116 |
+
self.datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
|
117 |
+
if len(self.datapoints) == 2:
|
118 |
+
# speaker embeddings are still missing, have to add them here
|
119 |
+
wave_datapoints = self.datapoints[1]
|
120 |
+
self.datapoints = self.datapoints[0]
|
121 |
+
self.speaker_embeddings = list()
|
122 |
+
speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
|
123 |
+
run_opts={"device": str(device)},
|
124 |
+
savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_ecapa")
|
125 |
+
with torch.no_grad():
|
126 |
+
for wave in tqdm(wave_datapoints):
|
127 |
+
self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())
|
128 |
+
torch.save((self.datapoints, wave_datapoints, self.speaker_embeddings), os.path.join(cache_dir, "aligner_train_cache.pt"))
|
129 |
+
else:
|
130 |
+
self.speaker_embeddings = self.datapoints[2]
|
131 |
+
self.datapoints = self.datapoints[0]
|
132 |
+
|
133 |
+
self.tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=True)
|
134 |
+
print(f"Prepared an Aligner dataset with {len(self.datapoints)} datapoints in {cache_dir}.")
|
135 |
+
|
136 |
+
def cache_builder_process(self,
|
137 |
+
path_list,
|
138 |
+
lang,
|
139 |
+
min_len,
|
140 |
+
max_len,
|
141 |
+
cut_silences,
|
142 |
+
verbose,
|
143 |
+
device):
|
144 |
+
process_internal_dataset_chunk = list()
|
145 |
+
tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=False)
|
146 |
+
_, sr = sf.read(path_list[0])
|
147 |
+
ap = AudioPreprocessor(input_sr=sr, output_sr=16000, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=cut_silences, device=device)
|
148 |
+
|
149 |
+
for path in tqdm(path_list):
|
150 |
+
if self.path_to_transcript_dict[path].strip() == "":
|
151 |
+
continue
|
152 |
+
|
153 |
+
wave, sr = sf.read(path)
|
154 |
+
dur_in_seconds = len(wave) / sr
|
155 |
+
if not (min_len <= dur_in_seconds <= max_len):
|
156 |
+
if verbose:
|
157 |
+
print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
|
158 |
+
continue
|
159 |
+
try:
|
160 |
+
with warnings.catch_warnings():
|
161 |
+
warnings.simplefilter("ignore") # otherwise we get tons of warnings about an RNN not being in contiguous chunks
|
162 |
+
norm_wave = ap.audio_to_wave_tensor(normalize=True, audio=wave)
|
163 |
+
except ValueError:
|
164 |
+
continue
|
165 |
+
dur_in_seconds = len(norm_wave) / 16000
|
166 |
+
if not (min_len <= dur_in_seconds <= max_len):
|
167 |
+
if verbose:
|
168 |
+
print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
|
169 |
+
continue
|
170 |
+
norm_wave = torch.tensor(trim_zeros(norm_wave.numpy()))
|
171 |
+
# raw audio preprocessing is done
|
172 |
+
transcript = self.path_to_transcript_dict[path]
|
173 |
+
try:
|
174 |
+
cached_text = tf.string_to_tensor(transcript, handle_missing=False).squeeze(0).cpu().numpy()
|
175 |
+
except KeyError:
|
176 |
+
tf.string_to_tensor(transcript, handle_missing=True).squeeze(0).cpu().numpy()
|
177 |
+
continue # we skip sentences with unknown symbols
|
178 |
+
try:
|
179 |
+
if len(cached_text[0]) != 66:
|
180 |
+
print(f"There seems to be a problem with the following transcription: {transcript}")
|
181 |
+
continue
|
182 |
+
except TypeError:
|
183 |
+
print(f"There seems to be a problem with the following transcription: {transcript}")
|
184 |
+
continue
|
185 |
+
cached_text_len = torch.LongTensor([len(cached_text)]).numpy()
|
186 |
+
cached_speech = ap.audio_to_mel_spec_tensor(audio=norm_wave, normalize=False, explicit_sampling_rate=16000).transpose(0, 1).cpu().numpy()
|
187 |
+
cached_speech_len = torch.LongTensor([len(cached_speech)]).numpy()
|
188 |
+
process_internal_dataset_chunk.append([cached_text,
|
189 |
+
cached_text_len,
|
190 |
+
cached_speech,
|
191 |
+
cached_speech_len,
|
192 |
+
norm_wave.cpu().detach().numpy()])
|
193 |
+
self.datapoints += process_internal_dataset_chunk
|
194 |
+
|
195 |
+
def __getitem__(self, index):
|
196 |
+
text_vector = self.datapoints[index][0]
|
197 |
+
tokens = list()
|
198 |
+
for vector in text_vector:
|
199 |
+
for phone in self.tf.phone_to_vector:
|
200 |
+
if vector.numpy().tolist() == self.tf.phone_to_vector[phone]:
|
201 |
+
tokens.append(self.tf.phone_to_id[phone])
|
202 |
+
# this is terribly inefficient, but it's good enough for testing for now.
|
203 |
+
tokens = torch.LongTensor(tokens)
|
204 |
+
return tokens, \
|
205 |
+
self.datapoints[index][1], \
|
206 |
+
self.datapoints[index][2], \
|
207 |
+
self.datapoints[index][3], \
|
208 |
+
self.speaker_embeddings[index]
|
209 |
+
|
210 |
+
def __len__(self):
|
211 |
+
return len(self.datapoints)
|
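A sketch of how this dataset is presumably meant to be instantiated; the corpus dict and cache directory are made up, everything else follows the constructor above.

from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.AlignerDataset import AlignerDataset

path_to_transcript = {"/data/speaker1/0001.wav": "Hello world.",       # made-up corpus
                      "/data/speaker1/0002.wav": "This is a test."}

train_set = AlignerDataset(path_to_transcript,
                           cache_dir="Corpora/aligner_demo",  # aligner_train_cache.pt is written here
                           lang="en",
                           loading_processes=2,
                           cut_silences=False,
                           device="cpu")

# each datapoint is (token ids, text length, mel spectrogram, mel length, ECAPA speaker embedding)
tokens, text_len, mel, mel_len, speaker_embedding = train_set[0]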
TrainingInterfaces/Text_to_Spectrogram/AutoAligner/TinyTTS.py
ADDED
@@ -0,0 +1,36 @@
1 |
+
import torch
|
2 |
+
import torch.multiprocessing
|
3 |
+
from torch.nn.utils.rnn import pack_padded_sequence
|
4 |
+
from torch.nn.utils.rnn import pad_packed_sequence
|
5 |
+
|
6 |
+
from Utility.utils import make_non_pad_mask
|
7 |
+
|
8 |
+
|
9 |
+
class TinyTTS(torch.nn.Module):
|
10 |
+
|
11 |
+
def __init__(self,
|
12 |
+
n_mels=80,
|
13 |
+
num_symbols=145,
|
14 |
+
speaker_embedding_dim=192,
|
15 |
+
lstm_dim=512):
|
16 |
+
super().__init__()
|
17 |
+
self.in_proj = torch.nn.Linear(num_symbols + speaker_embedding_dim, lstm_dim)
|
18 |
+
self.rnn1 = torch.nn.LSTM(lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
|
19 |
+
self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
|
20 |
+
self.out_proj = torch.nn.Linear(2 * lstm_dim, n_mels)
|
21 |
+
self.l1_criterion = torch.nn.L1Loss(reduction="none")
|
22 |
+
self.l2_criterion = torch.nn.MSELoss(reduction="none")
|
23 |
+
|
24 |
+
def forward(self, x, lens, ys):
|
25 |
+
x = self.in_proj(x)
|
26 |
+
x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
|
27 |
+
x, _ = self.rnn1(x)
|
28 |
+
x, _ = self.rnn2(x)
|
29 |
+
x, _ = pad_packed_sequence(x, batch_first=True)
|
30 |
+
x = self.out_proj(x)
|
31 |
+
out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(ys.device)
|
32 |
+
out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
|
33 |
+
out_weights /= ys.size(0) * ys.size(2)
|
34 |
+
l1_loss = self.l1_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum()
|
35 |
+
l2_loss = self.l2_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum()
|
36 |
+
return l1_loss + l2_loss
|
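Shape-wise, the reconstruction objective above can be sketched with random tensors (batch of two, 100 padded frames); only the default dimensions of TinyTTS are assumed.

import torch

from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.TinyTTS import TinyTTS

tiny_tts = TinyTTS()  # defaults: 145 symbol classes, 192-dim speaker embedding, 80 mel bins

asr_posteriors = torch.randn(2, 100, 145)      # stand-in for the Aligner output
speaker_embeddings = torch.randn(2, 100, 192)  # speaker embedding repeated over the frames
mel_targets = torch.randn(2, 100, 80)
frame_lens = torch.LongTensor([100, 80])

loss = tiny_tts(x=torch.cat([asr_posteriors, speaker_embeddings], dim=-1),
                lens=frame_lens,
                ys=mel_targets)
print(loss)  # scalar tensor combining the masked L1 and L2 terms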
TrainingInterfaces/Text_to_Spectrogram/AutoAligner/__init__.py
ADDED
File without changes
|
TrainingInterfaces/Text_to_Spectrogram/AutoAligner/autoaligner_train_loop.py
ADDED
@@ -0,0 +1,145 @@
1 |
+
import os
|
2 |
+
import time
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.multiprocessing
|
6 |
+
from torch.nn.utils.rnn import pad_sequence
|
7 |
+
from torch.optim import RAdam
|
8 |
+
from torch.utils.data.dataloader import DataLoader
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
|
12 |
+
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.TinyTTS import TinyTTS
|
13 |
+
|
14 |
+
|
15 |
+
def collate_and_pad(batch):
|
16 |
+
# text, text_len, speech, speech_len
|
17 |
+
return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True),
|
18 |
+
torch.stack([datapoint[1] for datapoint in batch]).squeeze(1),
|
19 |
+
pad_sequence([datapoint[2] for datapoint in batch], batch_first=True),
|
20 |
+
torch.stack([datapoint[3] for datapoint in batch]).squeeze(1),
|
21 |
+
torch.stack([datapoint[4] for datapoint in batch]).squeeze())
|
22 |
+
|
23 |
+
|
24 |
+
def train_loop(train_dataset,
|
25 |
+
device,
|
26 |
+
save_directory,
|
27 |
+
batch_size,
|
28 |
+
steps,
|
29 |
+
path_to_checkpoint=None,
|
30 |
+
fine_tune=False,
|
31 |
+
resume=False,
|
32 |
+
debug_img_path=None,
|
33 |
+
use_reconstruction=True):
|
34 |
+
"""
|
35 |
+
Args:
|
36 |
+
resume: whether to resume from the most recent checkpoint
|
37 |
+
steps: How many steps to train
|
38 |
+
path_to_checkpoint: reloads a checkpoint to continue training from there
|
39 |
+
fine_tune: whether to load everything from a checkpoint, or only the model parameters
|
40 |
+
train_dataset: Pytorch Dataset Object for train data
|
41 |
+
device: Device to put the loaded tensors on
|
42 |
+
save_directory: Where to save the checkpoints
|
43 |
+
batch_size: How many elements should be loaded at once
|
44 |
+
"""
|
45 |
+
os.makedirs(save_directory, exist_ok=True)
|
46 |
+
train_loader = DataLoader(batch_size=batch_size,
|
47 |
+
dataset=train_dataset,
|
48 |
+
drop_last=True,
|
49 |
+
num_workers=8,
|
50 |
+
pin_memory=False,
|
51 |
+
shuffle=True,
|
52 |
+
prefetch_factor=16,
|
53 |
+
collate_fn=collate_and_pad,
|
54 |
+
persistent_workers=True)
|
55 |
+
|
56 |
+
asr_model = Aligner().to(device)
|
57 |
+
optim_asr = RAdam(asr_model.parameters(), lr=0.0001)
|
58 |
+
|
59 |
+
tiny_tts = TinyTTS().to(device)
|
60 |
+
optim_tts = RAdam(tiny_tts.parameters(), lr=0.0001)
|
61 |
+
|
62 |
+
step_counter = 0
|
63 |
+
if resume:
|
64 |
+
previous_checkpoint = os.path.join(save_directory, "aligner.pt")
|
65 |
+
path_to_checkpoint = previous_checkpoint
|
66 |
+
fine_tune = False
|
67 |
+
|
68 |
+
if path_to_checkpoint is not None:
|
69 |
+
check_dict = torch.load(os.path.join(path_to_checkpoint), map_location=device)
|
70 |
+
asr_model.load_state_dict(check_dict["asr_model"])
|
71 |
+
tiny_tts.load_state_dict(check_dict["tts_model"])
|
72 |
+
if not fine_tune:
|
73 |
+
optim_asr.load_state_dict(check_dict["optimizer"])
|
74 |
+
optim_tts.load_state_dict(check_dict["tts_optimizer"])
|
75 |
+
step_counter = check_dict["step_counter"]
|
76 |
+
if step_counter > steps:
|
77 |
+
print("Desired steps already reached in loaded checkpoint.")
|
78 |
+
return
|
79 |
+
start_time = time.time()
|
80 |
+
|
81 |
+
while True:
|
82 |
+
loss_sum = list()
|
83 |
+
|
84 |
+
asr_model.train()
|
85 |
+
tiny_tts.train()
|
86 |
+
for batch in tqdm(train_loader):
|
87 |
+
tokens = batch[0].to(device)
|
88 |
+
tokens_len = batch[1].to(device)
|
89 |
+
mel = batch[2].to(device)
|
90 |
+
mel_len = batch[3].to(device)
|
91 |
+
speaker_embeddings = batch[4].to(device)
|
92 |
+
|
93 |
+
pred = asr_model(mel, mel_len)
|
94 |
+
|
95 |
+
ctc_loss = asr_model.ctc_loss(pred.transpose(0, 1).log_softmax(2),
|
96 |
+
tokens,
|
97 |
+
mel_len,
|
98 |
+
tokens_len)
|
99 |
+
|
100 |
+
if use_reconstruction:
|
101 |
+
speaker_embeddings_expanded = torch.nn.functional.normalize(speaker_embeddings).unsqueeze(1).expand(-1, pred.size(1), -1)
|
102 |
+
tts_lambda = min([5, step_counter / 2000]) # super simple schedule
|
103 |
+
reconstruction_loss = tiny_tts(x=torch.cat([pred, speaker_embeddings_expanded], dim=-1),
|
104 |
+
# combine ASR prediction with speaker embeddings to allow for reconstruction loss on multiple speakers
|
105 |
+
lens=mel_len,
|
106 |
+
ys=mel) * tts_lambda # reconstruction loss to make the states more distinct
|
107 |
+
loss = ctc_loss + reconstruction_loss
|
108 |
+
else:
|
109 |
+
loss = ctc_loss
|
110 |
+
|
111 |
+
optim_asr.zero_grad()
|
112 |
+
if use_reconstruction:
|
113 |
+
optim_tts.zero_grad()
|
114 |
+
loss.backward()
|
115 |
+
torch.nn.utils.clip_grad_norm_(asr_model.parameters(), 1.0)
|
116 |
+
if use_reconstruction:
|
117 |
+
torch.nn.utils.clip_grad_norm_(tiny_tts.parameters(), 1.0)
|
118 |
+
optim_asr.step()
|
119 |
+
if use_reconstruction:
|
120 |
+
optim_tts.step()
|
121 |
+
|
122 |
+
step_counter += 1
|
123 |
+
|
124 |
+
loss_sum.append(loss.item())
|
125 |
+
|
126 |
+
asr_model.eval()
|
127 |
+
loss_this_epoch = sum(loss_sum) / len(loss_sum)
|
128 |
+
torch.save({
|
129 |
+
"asr_model" : asr_model.state_dict(),
|
130 |
+
"optimizer" : optim_asr.state_dict(),
|
131 |
+
"tts_model" : tiny_tts.state_dict(),
|
132 |
+
"tts_optimizer": optim_tts.state_dict(),
|
133 |
+
"step_counter" : step_counter,
|
134 |
+
},
|
135 |
+
os.path.join(save_directory, "aligner.pt"))
|
136 |
+
print("Total Loss: {}".format(round(loss_this_epoch, 3)))
|
137 |
+
print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
|
138 |
+
print("Steps: {}".format(step_counter))
|
139 |
+
if debug_img_path is not None:
|
140 |
+
asr_model.inference(mel=mel[0][:mel_len[0]],
|
141 |
+
tokens=tokens[0][:tokens_len[0]],
|
142 |
+
save_img_for_debug=debug_img_path + f"/{step_counter}.png",
|
143 |
+
train=True) # for testing
|
144 |
+
if step_counter > steps:
|
145 |
+
return
|
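For orientation, this is roughly how the aligner training loop above might be driven. The transcript dictionary and the corpus paths are placeholders, and the dataset constructor arguments simply mirror the call made in FastSpeechDatasetLanguageID.py further below; treat it as a sketch rather than an official recipe.

import torch

from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.AlignerDataset import AlignerDataset
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.autoaligner_train_loop import train_loop

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # maps audio file paths to their transcriptions (placeholder entry)
    path_to_transcript_dict = {"/data/audio_0001.wav": "Hello world."}

    train_set = AlignerDataset(path_to_transcript_dict=path_to_transcript_dict,
                               cache_dir="Corpora/my_corpus",
                               lang="en",
                               loading_processes=8,
                               min_len_in_seconds=1,
                               max_len_in_seconds=20,
                               cut_silences=False,
                               rebuild_cache=False,
                               device=device)

    # trains the CTC aligner together with the TinyTTS reconstruction helper
    train_loop(train_dataset=train_set,
               device=device,
               save_directory="Models/Aligner",
               batch_size=32,
               steps=500000,
               use_reconstruction=True)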
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/DurationCalculator.py
ADDED
@@ -0,0 +1,31 @@
1 |
+
# Copyright 2020 Nagoya University (Tomoki Hayashi)
|
2 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
# Adapted by Florian Lux 2021
|
4 |
+
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
class DurationCalculator(torch.nn.Module):
|
11 |
+
|
12 |
+
def __init__(self, reduction_factor):
|
13 |
+
self.reduction_factor = reduction_factor
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
@torch.no_grad()
|
17 |
+
def forward(self, att_ws, vis=None):
|
18 |
+
"""
|
19 |
+
Convert alignment matrix to durations.
|
20 |
+
"""
|
21 |
+
if vis is not None:
|
22 |
+
plt.figure(figsize=(8, 4))
|
23 |
+
plt.imshow(att_ws.cpu().numpy(), interpolation='nearest', aspect='auto', origin="lower")
|
24 |
+
plt.xlabel("Inputs")
|
25 |
+
plt.ylabel("Outputs")
|
26 |
+
plt.tight_layout()
|
27 |
+
plt.savefig(vis)
|
28 |
+
plt.close()
|
29 |
+
# calculate duration from 2d alignment matrix
|
30 |
+
durations = torch.stack([att_ws.argmax(-1).eq(i).sum() for i in range(att_ws.shape[1])])
|
31 |
+
return durations.view(-1) * self.reduction_factor
|
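To make the argmax-counting in the forward pass concrete, here is a toy example that is not part of the training code: six spectrogram frames attending to three phonemes yield two frames per phoneme.

import torch

from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.DurationCalculator import DurationCalculator

# toy alignment matrix: 6 spectrogram frames (rows) attending to 3 phonemes (columns)
att_ws = torch.tensor([[0.9, 0.1, 0.0],
                       [0.8, 0.2, 0.0],
                       [0.1, 0.7, 0.2],
                       [0.0, 0.6, 0.4],
                       [0.0, 0.2, 0.8],
                       [0.0, 0.1, 0.9]])

dc = DurationCalculator(reduction_factor=1)
print(dc(att_ws))  # tensor([2, 2, 2]) -> each phoneme covers two frames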
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/EnergyCalculator.py
ADDED
@@ -0,0 +1,86 @@
1 |
+
# Copyright 2020 Nagoya University (Tomoki Hayashi)
|
2 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
# Adapted by Florian Lux 2021
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from Layers.STFT import STFT
|
9 |
+
from Utility.utils import pad_list
|
10 |
+
|
11 |
+
|
12 |
+
class EnergyCalculator(torch.nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, fs=16000, n_fft=1024, win_length=None, hop_length=256, window="hann", center=True,
|
15 |
+
normalized=False, onesided=True, use_token_averaged_energy=True, reduction_factor=1):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
self.fs = fs
|
19 |
+
self.n_fft = n_fft
|
20 |
+
self.hop_length = hop_length
|
21 |
+
self.win_length = win_length
|
22 |
+
self.window = window
|
23 |
+
self.use_token_averaged_energy = use_token_averaged_energy
|
24 |
+
if use_token_averaged_energy:
|
25 |
+
assert reduction_factor >= 1
|
26 |
+
self.reduction_factor = reduction_factor
|
27 |
+
|
28 |
+
self.stft = STFT(n_fft=n_fft, win_length=win_length, hop_length=hop_length, window=window, center=center, normalized=normalized, onesided=onesided)
|
29 |
+
|
30 |
+
def output_size(self):
|
31 |
+
return 1
|
32 |
+
|
33 |
+
def get_parameters(self):
|
34 |
+
return dict(fs=self.fs, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, win_length=self.win_length, center=self.stft.center,
|
35 |
+
normalized=self.stft.normalized, use_token_averaged_energy=self.use_token_averaged_energy, reduction_factor=self.reduction_factor)
|
36 |
+
|
37 |
+
def forward(self, input_waves, input_waves_lengths=None, feats_lengths=None, durations=None,
|
38 |
+
durations_lengths=None, norm_by_average=True):
|
39 |
+
# If not provided, we assume that the inputs have the same length
|
40 |
+
if input_waves_lengths is None:
|
41 |
+
input_waves_lengths = (input_waves.new_ones(input_waves.shape[0], dtype=torch.long) * input_waves.shape[1])
|
42 |
+
|
43 |
+
# Domain-conversion: e.g. Stft: time -> time-freq
|
44 |
+
input_stft, energy_lengths = self.stft(input_waves, input_waves_lengths)
|
45 |
+
|
46 |
+
assert input_stft.dim() >= 4, input_stft.shape
|
47 |
+
assert input_stft.shape[-1] == 2, input_stft.shape
|
48 |
+
|
49 |
+
# input_stft: (..., F, 2) -> (..., F)
|
50 |
+
input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
|
51 |
+
# sum over frequency (B, N, F) -> (B, N)
|
52 |
+
energy = torch.sqrt(torch.clamp(input_power.sum(dim=2), min=1.0e-10))
|
53 |
+
|
54 |
+
# (Optional): Adjust length to match with the mel-spectrogram
|
55 |
+
if feats_lengths is not None:
|
56 |
+
energy = [self._adjust_num_frames(e[:el].view(-1), fl) for e, el, fl in zip(energy, energy_lengths, feats_lengths)]
|
57 |
+
energy_lengths = feats_lengths
|
58 |
+
|
59 |
+
# (Optional): Average by duration to calculate token-wise energy
|
60 |
+
if self.use_token_averaged_energy:
|
61 |
+
energy = [self._average_by_duration(e[:el].view(-1), d) for e, el, d in zip(energy, energy_lengths, durations)]
|
62 |
+
energy_lengths = durations_lengths
|
63 |
+
|
64 |
+
# Padding
|
65 |
+
if isinstance(energy, list):
|
66 |
+
energy = pad_list(energy, 0.0)
|
67 |
+
|
68 |
+
# Return with the shape (B, T, 1)
|
69 |
+
if norm_by_average:
|
70 |
+
average = energy[0][energy[0] != 0.0].mean()
|
71 |
+
energy = energy / average
|
72 |
+
return energy.unsqueeze(-1), energy_lengths
|
73 |
+
|
74 |
+
def _average_by_duration(self, x, d):
|
75 |
+
assert 0 <= len(x) - d.sum() < self.reduction_factor
|
76 |
+
d_cumsum = F.pad(d.cumsum(dim=0), (1, 0))
|
77 |
+
x_avg = [x[start:end].mean() if len(x[start:end]) != 0 else x.new_tensor(0.0) for start, end in zip(d_cumsum[:-1], d_cumsum[1:])]
|
78 |
+
return torch.stack(x_avg)
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def _adjust_num_frames(x, num_frames):
|
82 |
+
if num_frames > len(x):
|
83 |
+
x = F.pad(x, (0, num_frames - len(x)))
|
84 |
+
elif num_frames < len(x):
|
85 |
+
x = x[:num_frames]
|
86 |
+
return x
|
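Conceptually, the per-frame energy computed above is the L2 norm of the STFT magnitudes in each frame. The following torch-only sketch reproduces that quantity with torch.stft directly, so it bypasses the Layers.STFT wrapper and the padding and averaging options; it is only meant to illustrate the formula.

import torch

wave = torch.randn(16000)  # one second of (random) audio at 16 kHz

spec = torch.stft(wave, n_fft=1024, hop_length=256, win_length=1024,
                  window=torch.hann_window(1024), center=True, return_complex=True)

# energy per frame: square root of the summed power over all frequency bins
energy = torch.sqrt(torch.clamp((spec.abs() ** 2).sum(dim=0), min=1.0e-10))
print(energy.shape)  # (number_of_frames,)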
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeech2.py
ADDED
@@ -0,0 +1,379 @@
1 |
+
"""
|
2 |
+
Taken from ESPNet
|
3 |
+
"""
|
4 |
+
|
5 |
+
from abc import ABC
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from Layers.Conformer import Conformer
|
10 |
+
from Layers.DurationPredictor import DurationPredictor
|
11 |
+
from Layers.LengthRegulator import LengthRegulator
|
12 |
+
from Layers.PostNet import PostNet
|
13 |
+
from Layers.VariancePredictor import VariancePredictor
|
14 |
+
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.FastSpeech2Loss import FastSpeech2Loss
|
15 |
+
from Utility.SoftDTW.sdtw_cuda_loss import SoftDTW
|
16 |
+
from Utility.utils import initialize
|
17 |
+
from Utility.utils import make_non_pad_mask
|
18 |
+
from Utility.utils import make_pad_mask
|
19 |
+
|
20 |
+
|
21 |
+
class FastSpeech2(torch.nn.Module, ABC):
|
22 |
+
"""
|
23 |
+
FastSpeech 2 module.
|
24 |
+
|
25 |
+
This is a module of FastSpeech 2 described in FastSpeech 2: Fast and
|
26 |
+
High-Quality End-to-End Text to Speech. Instead of quantized pitch and
|
27 |
+
energy, we use token-averaged value introduced in FastPitch: Parallel
|
28 |
+
Text-to-speech with Pitch Prediction. The encoder and decoder are Conformers
|
29 |
+
instead of regular Transformers.
|
30 |
+
|
31 |
+
https://arxiv.org/abs/2006.04558
|
32 |
+
https://arxiv.org/abs/2006.06873
|
33 |
+
https://arxiv.org/pdf/2005.08100
|
34 |
+
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self,
|
38 |
+
# network structure related
|
39 |
+
idim=66,
|
40 |
+
odim=80,
|
41 |
+
adim=384,
|
42 |
+
aheads=4,
|
43 |
+
elayers=6,
|
44 |
+
eunits=1536,
|
45 |
+
dlayers=6,
|
46 |
+
dunits=1536,
|
47 |
+
postnet_layers=5,
|
48 |
+
postnet_chans=256,
|
49 |
+
postnet_filts=5,
|
50 |
+
positionwise_layer_type="conv1d",
|
51 |
+
positionwise_conv_kernel_size=1,
|
52 |
+
use_scaled_pos_enc=True,
|
53 |
+
use_batch_norm=True,
|
54 |
+
encoder_normalize_before=True,
|
55 |
+
decoder_normalize_before=True,
|
56 |
+
encoder_concat_after=False,
|
57 |
+
decoder_concat_after=False,
|
58 |
+
reduction_factor=1,
|
59 |
+
# encoder / decoder
|
60 |
+
use_macaron_style_in_conformer=True,
|
61 |
+
use_cnn_in_conformer=True,
|
62 |
+
conformer_enc_kernel_size=7,
|
63 |
+
conformer_dec_kernel_size=31,
|
64 |
+
# duration predictor
|
65 |
+
duration_predictor_layers=2,
|
66 |
+
duration_predictor_chans=256,
|
67 |
+
duration_predictor_kernel_size=3,
|
68 |
+
# energy predictor
|
69 |
+
energy_predictor_layers=2,
|
70 |
+
energy_predictor_chans=256,
|
71 |
+
energy_predictor_kernel_size=3,
|
72 |
+
energy_predictor_dropout=0.5,
|
73 |
+
energy_embed_kernel_size=1,
|
74 |
+
energy_embed_dropout=0.0,
|
75 |
+
stop_gradient_from_energy_predictor=False,
|
76 |
+
# pitch predictor
|
77 |
+
pitch_predictor_layers=5,
|
78 |
+
pitch_predictor_chans=256,
|
79 |
+
pitch_predictor_kernel_size=5,
|
80 |
+
pitch_predictor_dropout=0.5,
|
81 |
+
pitch_embed_kernel_size=1,
|
82 |
+
pitch_embed_dropout=0.0,
|
83 |
+
stop_gradient_from_pitch_predictor=True,
|
84 |
+
# training related
|
85 |
+
transformer_enc_dropout_rate=0.2,
|
86 |
+
transformer_enc_positional_dropout_rate=0.2,
|
87 |
+
transformer_enc_attn_dropout_rate=0.2,
|
88 |
+
transformer_dec_dropout_rate=0.2,
|
89 |
+
transformer_dec_positional_dropout_rate=0.2,
|
90 |
+
transformer_dec_attn_dropout_rate=0.2,
|
91 |
+
duration_predictor_dropout_rate=0.2,
|
92 |
+
postnet_dropout_rate=0.5,
|
93 |
+
init_type="xavier_uniform",
|
94 |
+
init_enc_alpha=1.0,
|
95 |
+
init_dec_alpha=1.0,
|
96 |
+
use_masking=False,
|
97 |
+
use_weighted_masking=True,
|
98 |
+
# additional features
|
99 |
+
use_dtw_loss=False,
|
100 |
+
utt_embed_dim=704,
|
101 |
+
connect_utt_emb_at_encoder_out=True,
|
102 |
+
lang_embs=100):
|
103 |
+
super().__init__()
|
104 |
+
|
105 |
+
# store hyperparameters
|
106 |
+
self.idim = idim
|
107 |
+
self.odim = odim
|
108 |
+
self.use_dtw_loss = use_dtw_loss
|
109 |
+
self.eos = 1
|
110 |
+
self.reduction_factor = reduction_factor
|
111 |
+
self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
|
112 |
+
self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
|
113 |
+
self.use_scaled_pos_enc = use_scaled_pos_enc
|
114 |
+
self.multilingual_model = lang_embs is not None
|
115 |
+
self.multispeaker_model = utt_embed_dim is not None
|
116 |
+
|
117 |
+
# define encoder
|
118 |
+
embed = torch.nn.Sequential(torch.nn.Linear(idim, 100),
|
119 |
+
torch.nn.Tanh(),
|
120 |
+
torch.nn.Linear(100, adim))
|
121 |
+
self.encoder = Conformer(idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers,
|
122 |
+
input_layer=embed, dropout_rate=transformer_enc_dropout_rate,
|
123 |
+
positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate,
|
124 |
+
normalize_before=encoder_normalize_before, concat_after=encoder_concat_after,
|
125 |
+
positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer,
|
126 |
+
use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_enc_kernel_size, zero_triu=False,
|
127 |
+
utt_embed=utt_embed_dim, connect_utt_emb_at_encoder_out=connect_utt_emb_at_encoder_out, lang_embs=lang_embs)
|
128 |
+
|
129 |
+
# define duration predictor
|
130 |
+
self.duration_predictor = DurationPredictor(idim=adim, n_layers=duration_predictor_layers, n_chans=duration_predictor_chans,
|
131 |
+
kernel_size=duration_predictor_kernel_size, dropout_rate=duration_predictor_dropout_rate, )
|
132 |
+
|
133 |
+
# define pitch predictor
|
134 |
+
self.pitch_predictor = VariancePredictor(idim=adim, n_layers=pitch_predictor_layers, n_chans=pitch_predictor_chans,
|
135 |
+
kernel_size=pitch_predictor_kernel_size, dropout_rate=pitch_predictor_dropout)
|
136 |
+
# continuous pitch + FastPitch style avg
|
137 |
+
self.pitch_embed = torch.nn.Sequential(
|
138 |
+
torch.nn.Conv1d(in_channels=1, out_channels=adim, kernel_size=pitch_embed_kernel_size, padding=(pitch_embed_kernel_size - 1) // 2),
|
139 |
+
torch.nn.Dropout(pitch_embed_dropout))
|
140 |
+
|
141 |
+
# define energy predictor
|
142 |
+
self.energy_predictor = VariancePredictor(idim=adim, n_layers=energy_predictor_layers, n_chans=energy_predictor_chans,
|
143 |
+
kernel_size=energy_predictor_kernel_size, dropout_rate=energy_predictor_dropout)
|
144 |
+
# continuous energy + FastPitch style avg
|
145 |
+
self.energy_embed = torch.nn.Sequential(
|
146 |
+
torch.nn.Conv1d(in_channels=1, out_channels=adim, kernel_size=energy_embed_kernel_size, padding=(energy_embed_kernel_size - 1) // 2),
|
147 |
+
torch.nn.Dropout(energy_embed_dropout))
|
148 |
+
|
149 |
+
# define length regulator
|
150 |
+
self.length_regulator = LengthRegulator()
|
151 |
+
|
152 |
+
self.decoder = Conformer(idim=0, attention_dim=adim, attention_heads=aheads, linear_units=dunits, num_blocks=dlayers, input_layer=None,
|
153 |
+
dropout_rate=transformer_dec_dropout_rate, positional_dropout_rate=transformer_dec_positional_dropout_rate,
|
154 |
+
attention_dropout_rate=transformer_dec_attn_dropout_rate, normalize_before=decoder_normalize_before,
|
155 |
+
concat_after=decoder_concat_after, positionwise_conv_kernel_size=positionwise_conv_kernel_size,
|
156 |
+
macaron_style=use_macaron_style_in_conformer, use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_dec_kernel_size)
|
157 |
+
|
158 |
+
# define final projection
|
159 |
+
self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
|
160 |
+
|
161 |
+
# define postnet
|
162 |
+
self.postnet = PostNet(idim=idim, odim=odim, n_layers=postnet_layers, n_chans=postnet_chans, n_filts=postnet_filts, use_batch_norm=use_batch_norm,
|
163 |
+
dropout_rate=postnet_dropout_rate)
|
164 |
+
|
165 |
+
# initialize parameters
|
166 |
+
self._reset_parameters(init_type=init_type, init_enc_alpha=init_enc_alpha, init_dec_alpha=init_dec_alpha)
|
167 |
+
|
168 |
+
# define criterions
|
169 |
+
self.criterion = FastSpeech2Loss(use_masking=use_masking, use_weighted_masking=use_weighted_masking)
|
170 |
+
self.dtw_criterion = SoftDTW(use_cuda=True, gamma=0.1)
|
171 |
+
|
172 |
+
def forward(self,
|
173 |
+
text_tensors,
|
174 |
+
text_lengths,
|
175 |
+
gold_speech,
|
176 |
+
speech_lengths,
|
177 |
+
gold_durations,
|
178 |
+
gold_pitch,
|
179 |
+
gold_energy,
|
180 |
+
utterance_embedding,
|
181 |
+
return_mels=False,
|
182 |
+
lang_ids=None):
|
183 |
+
"""
|
184 |
+
Calculate forward propagation.
|
185 |
+
|
186 |
+
Args:
|
187 |
+
return_mels: whether to return the predicted spectrogram
|
188 |
+
text_tensors (LongTensor): Batch of padded text vectors (B, Tmax).
|
189 |
+
text_lengths (LongTensor): Batch of lengths of each input (B,).
|
190 |
+
gold_speech (Tensor): Batch of padded target features (B, Lmax, odim).
|
191 |
+
speech_lengths (LongTensor): Batch of the lengths of each target (B,).
|
192 |
+
gold_durations (LongTensor): Batch of padded durations (B, Tmax + 1).
|
193 |
+
gold_pitch (Tensor): Batch of padded token-averaged pitch (B, Tmax + 1, 1).
|
194 |
+
gold_energy (Tensor): Batch of padded token-averaged energy (B, Tmax + 1, 1).
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
Tensor: Loss scalar value.
|
198 |
+
Dict: Statistics to be monitored.
|
199 |
+
Tensor: Weight value.
|
200 |
+
"""
|
201 |
+
# Texts include EOS token from the teacher model already in this version
|
202 |
+
|
203 |
+
# forward propagation
|
204 |
+
before_outs, after_outs, d_outs, p_outs, e_outs = self._forward(text_tensors, text_lengths, gold_speech, speech_lengths,
|
205 |
+
gold_durations, gold_pitch, gold_energy, utterance_embedding=utterance_embedding,
|
206 |
+
is_inference=False, lang_ids=lang_ids)
|
207 |
+
|
208 |
+
# modify mod part of groundtruth (speaking pace)
|
209 |
+
if self.reduction_factor > 1:
|
210 |
+
speech_lengths = speech_lengths.new([olen - olen % self.reduction_factor for olen in speech_lengths])
|
211 |
+
|
212 |
+
# calculate loss
|
213 |
+
l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(after_outs=after_outs, before_outs=before_outs, d_outs=d_outs, p_outs=p_outs,
|
214 |
+
e_outs=e_outs, ys=gold_speech, ds=gold_durations, ps=gold_pitch, es=gold_energy,
|
215 |
+
ilens=text_lengths, olens=speech_lengths)
|
216 |
+
loss = l1_loss + duration_loss + pitch_loss + energy_loss
|
217 |
+
|
218 |
+
if self.use_dtw_loss:
|
219 |
+
# print("Regular Loss: {}".format(loss))
|
220 |
+
dtw_loss = self.dtw_criterion(after_outs, gold_speech).mean() / 2000.0 # division to balance orders of magnitude
|
221 |
+
# print("DTW Loss: {}".format(dtw_loss))
|
222 |
+
loss = loss + dtw_loss
|
223 |
+
|
224 |
+
if return_mels:
|
225 |
+
return loss, after_outs
|
226 |
+
return loss
|
227 |
+
|
228 |
+
def _forward(self, text_tensors, text_lens, gold_speech=None, speech_lens=None,
|
229 |
+
gold_durations=None, gold_pitch=None, gold_energy=None,
|
230 |
+
is_inference=False, alpha=1.0, utterance_embedding=None, lang_ids=None):
|
231 |
+
|
232 |
+
if not self.multilingual_model:
|
233 |
+
lang_ids = None
|
234 |
+
|
235 |
+
if not self.multispeaker_model:
|
236 |
+
utterance_embedding = None
|
237 |
+
|
238 |
+
# forward encoder
|
239 |
+
text_masks = self._source_mask(text_lens)
|
240 |
+
|
241 |
+
encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids) # (B, Tmax, adim)
|
242 |
+
|
243 |
+
# forward duration predictor and variance predictors
|
244 |
+
d_masks = make_pad_mask(text_lens, device=text_lens.device)
|
245 |
+
|
246 |
+
if self.stop_gradient_from_pitch_predictor:
|
247 |
+
pitch_predictions = self.pitch_predictor(encoded_texts.detach(), d_masks.unsqueeze(-1))
|
248 |
+
else:
|
249 |
+
pitch_predictions = self.pitch_predictor(encoded_texts, d_masks.unsqueeze(-1))
|
250 |
+
|
251 |
+
if self.stop_gradient_from_energy_predictor:
|
252 |
+
energy_predictions = self.energy_predictor(encoded_texts.detach(), d_masks.unsqueeze(-1))
|
253 |
+
else:
|
254 |
+
energy_predictions = self.energy_predictor(encoded_texts, d_masks.unsqueeze(-1))
|
255 |
+
|
256 |
+
if is_inference:
|
257 |
+
d_outs = self.duration_predictor.inference(encoded_texts, d_masks) # (B, Tmax)
|
258 |
+
# use prediction in inference
|
259 |
+
p_embs = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
|
260 |
+
e_embs = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
|
261 |
+
encoded_texts = encoded_texts + e_embs + p_embs
|
262 |
+
encoded_texts = self.length_regulator(encoded_texts, d_outs, alpha) # (B, Lmax, adim)
|
263 |
+
else:
|
264 |
+
d_outs = self.duration_predictor(encoded_texts, d_masks)
|
265 |
+
|
266 |
+
# use groundtruth in training
|
267 |
+
p_embs = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2)
|
268 |
+
e_embs = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2)
|
269 |
+
encoded_texts = encoded_texts + e_embs + p_embs
|
270 |
+
encoded_texts = self.length_regulator(encoded_texts, gold_durations) # (B, Lmax, adim)
|
271 |
+
|
272 |
+
# forward decoder
|
273 |
+
if speech_lens is not None and not is_inference:
|
274 |
+
if self.reduction_factor > 1:
|
275 |
+
olens_in = speech_lens.new([olen // self.reduction_factor for olen in speech_lens])
|
276 |
+
else:
|
277 |
+
olens_in = speech_lens
|
278 |
+
h_masks = self._source_mask(olens_in)
|
279 |
+
else:
|
280 |
+
h_masks = None
|
281 |
+
zs, _ = self.decoder(encoded_texts, h_masks) # (B, Lmax, adim)
|
282 |
+
before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # (B, Lmax, odim)
|
283 |
+
|
284 |
+
# postnet -> (B, Lmax//r * r, odim)
|
285 |
+
after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
|
286 |
+
|
287 |
+
return before_outs, after_outs, d_outs, pitch_predictions, energy_predictions
|
288 |
+
|
289 |
+
def batch_inference(self, texts, text_lens, utt_emb):
|
290 |
+
_, after_outs, d_outs, _, _ = self._forward(texts,
|
291 |
+
text_lens,
|
292 |
+
None,
|
293 |
+
is_inference=True,
|
294 |
+
alpha=1.0)
|
295 |
+
return after_outs, d_outs
|
296 |
+
|
297 |
+
def inference(self,
|
298 |
+
text,
|
299 |
+
speech=None,
|
300 |
+
durations=None,
|
301 |
+
pitch=None,
|
302 |
+
energy=None,
|
303 |
+
alpha=1.0,
|
304 |
+
use_teacher_forcing=False,
|
305 |
+
utterance_embedding=None,
|
306 |
+
return_duration_pitch_energy=False,
|
307 |
+
lang_id=None):
|
308 |
+
"""
|
309 |
+
Generate the sequence of features given the sequences of characters.
|
310 |
+
|
311 |
+
Args:
|
312 |
+
text (LongTensor): Input sequence of characters (T,).
|
313 |
+
speech (Tensor, optional): Feature sequence to extract style (N, idim).
|
314 |
+
durations (LongTensor, optional): Groundtruth of duration (T + 1,).
|
315 |
+
pitch (Tensor, optional): Groundtruth of token-averaged pitch (T + 1, 1).
|
316 |
+
energy (Tensor, optional): Groundtruth of token-averaged energy (T + 1, 1).
|
317 |
+
alpha (float, optional): Alpha to control the speed.
|
318 |
+
use_teacher_forcing (bool, optional): Whether to use teacher forcing.
|
319 |
+
If true, groundtruth of duration, pitch and energy will be used.
|
320 |
+
return_duration_pitch_energy: whether to also return the predicted durations, pitch, and energy for nicer plotting
|
321 |
+
|
322 |
+
Returns:
|
323 |
+
Tensor: Output sequence of features (L, odim).
|
324 |
+
|
325 |
+
"""
|
326 |
+
self.eval()
|
327 |
+
x, y = text, speech
|
328 |
+
d, p, e = durations, pitch, energy
|
329 |
+
|
330 |
+
# setup batch axis
|
331 |
+
ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
|
332 |
+
xs, ys = x.unsqueeze(0), None
|
333 |
+
if y is not None:
|
334 |
+
ys = y.unsqueeze(0)
|
335 |
+
if lang_id is not None:
|
336 |
+
lang_id = lang_id.unsqueeze(0)
|
337 |
+
|
338 |
+
if use_teacher_forcing:
|
339 |
+
# use groundtruth of duration, pitch, and energy
|
340 |
+
ds, ps, es = d.unsqueeze(0), p.unsqueeze(0), e.unsqueeze(0)
|
341 |
+
before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(xs,
|
342 |
+
ilens,
|
343 |
+
ys,
|
344 |
+
gold_durations=ds,
|
345 |
+
gold_pitch=ps,
|
346 |
+
gold_energy=es,
|
347 |
+
utterance_embedding=utterance_embedding.unsqueeze(0),
|
348 |
+
lang_ids=lang_id) # (1, L, odim)
|
349 |
+
else:
|
350 |
+
before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(xs,
|
351 |
+
ilens,
|
352 |
+
ys,
|
353 |
+
is_inference=True,
|
354 |
+
alpha=alpha,
|
355 |
+
utterance_embedding=utterance_embedding.unsqueeze(0),
|
356 |
+
lang_ids=lang_id) # (1, L, odim)
|
357 |
+
self.train()
|
358 |
+
if return_duration_pitch_energy:
|
359 |
+
return after_outs[0], d_outs[0], pitch_predictions[0], energy_predictions[0]
|
360 |
+
return after_outs[0]
|
361 |
+
|
362 |
+
def _source_mask(self, ilens):
|
363 |
+
"""
|
364 |
+
Make masks for self-attention.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
ilens (LongTensor): Batch of lengths (B,).
|
368 |
+
|
369 |
+
Returns:
|
370 |
+
Tensor: Mask tensor for self-attention.
|
371 |
+
|
372 |
+
"""
|
373 |
+
x_masks = make_non_pad_mask(ilens, device=ilens.device)
|
374 |
+
return x_masks.unsqueeze(-2)
|
375 |
+
|
376 |
+
def _reset_parameters(self, init_type, init_enc_alpha, init_dec_alpha):
|
377 |
+
# initialize parameters
|
378 |
+
if init_type != "pytorch":
|
379 |
+
initialize(self, init_type)
|
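The central non-autoregressive step in _forward is the length regulation between encoder and decoder. Stripped of masking and batching, it is a per-token repeat of the encoder states according to the predicted durations; the sketch below uses torch.repeat_interleave as a stand-in for the LengthRegulator layer.

import torch

# three encoded phonemes with 4-dimensional hidden states
encoded_texts = torch.tensor([[1., 1., 1., 1.],
                              [2., 2., 2., 2.],
                              [3., 3., 3., 3.]])
durations = torch.tensor([2, 1, 3])  # frames per phoneme, e.g. from the duration predictor

# length regulation: every hidden state is repeated as often as its predicted duration
upsampled = torch.repeat_interleave(encoded_texts, durations, dim=0)
print(upsampled.shape)  # torch.Size([6, 4]) -> six spectrogram frames for the decoder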
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeech2Loss.py
ADDED
@@ -0,0 +1,96 @@
1 |
+
"""
|
2 |
+
Taken from ESPNet
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from Layers.DurationPredictor import DurationPredictorLoss
|
8 |
+
from Utility.utils import make_non_pad_mask
|
9 |
+
|
10 |
+
|
11 |
+
class FastSpeech2Loss(torch.nn.Module):
|
12 |
+
|
13 |
+
def __init__(self, use_masking=True, use_weighted_masking=False):
|
14 |
+
"""
|
15 |
+
use_masking (bool):
|
16 |
+
Whether to apply masking for padded part in loss calculation.
|
17 |
+
use_weighted_masking (bool):
|
18 |
+
Whether to apply weighted masking in loss calculation.
|
19 |
+
"""
|
20 |
+
super().__init__()
|
21 |
+
|
22 |
+
assert (use_masking != use_weighted_masking) or not use_masking
|
23 |
+
self.use_masking = use_masking
|
24 |
+
self.use_weighted_masking = use_weighted_masking
|
25 |
+
|
26 |
+
# define criterions
|
27 |
+
reduction = "none" if self.use_weighted_masking else "mean"
|
28 |
+
self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
|
29 |
+
self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
|
30 |
+
self.duration_criterion = DurationPredictorLoss(reduction=reduction)
|
31 |
+
|
32 |
+
def forward(self, after_outs, before_outs, d_outs, p_outs, e_outs, ys,
|
33 |
+
ds, ps, es, ilens, olens, ):
|
34 |
+
"""
|
35 |
+
Args:
|
36 |
+
after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
|
37 |
+
before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
|
38 |
+
d_outs (LongTensor): Batch of outputs of duration predictor (B, Tmax).
|
39 |
+
p_outs (Tensor): Batch of outputs of pitch predictor (B, Tmax, 1).
|
40 |
+
e_outs (Tensor): Batch of outputs of energy predictor (B, Tmax, 1).
|
41 |
+
ys (Tensor): Batch of target features (B, Lmax, odim).
|
42 |
+
ds (LongTensor): Batch of durations (B, Tmax).
|
43 |
+
ps (Tensor): Batch of target token-averaged pitch (B, Tmax, 1).
|
44 |
+
es (Tensor): Batch of target token-averaged energy (B, Tmax, 1).
|
45 |
+
ilens (LongTensor): Batch of the lengths of each input (B,).
|
46 |
+
olens (LongTensor): Batch of the lengths of each target (B,).
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
Tensor: L1 loss value.
|
50 |
+
Tensor: Duration predictor loss value.
|
51 |
+
Tensor: Pitch predictor loss value.
|
52 |
+
Tensor: Energy predictor loss value.
|
53 |
+
|
54 |
+
"""
|
55 |
+
# apply mask to remove padded part
|
56 |
+
if self.use_masking:
|
57 |
+
out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
|
58 |
+
before_outs = before_outs.masked_select(out_masks)
|
59 |
+
if after_outs is not None:
|
60 |
+
after_outs = after_outs.masked_select(out_masks)
|
61 |
+
ys = ys.masked_select(out_masks)
|
62 |
+
duration_masks = make_non_pad_mask(ilens).to(ys.device)
|
63 |
+
d_outs = d_outs.masked_select(duration_masks)
|
64 |
+
ds = ds.masked_select(duration_masks)
|
65 |
+
pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1).to(ys.device)
|
66 |
+
p_outs = p_outs.masked_select(pitch_masks)
|
67 |
+
e_outs = e_outs.masked_select(pitch_masks)
|
68 |
+
ps = ps.masked_select(pitch_masks)
|
69 |
+
es = es.masked_select(pitch_masks)
|
70 |
+
|
71 |
+
# calculate loss
|
72 |
+
l1_loss = self.l1_criterion(before_outs, ys)
|
73 |
+
if after_outs is not None:
|
74 |
+
l1_loss += self.l1_criterion(after_outs, ys)
|
75 |
+
duration_loss = self.duration_criterion(d_outs, ds)
|
76 |
+
pitch_loss = self.mse_criterion(p_outs, ps)
|
77 |
+
energy_loss = self.mse_criterion(e_outs, es)
|
78 |
+
|
79 |
+
# make weighted mask and apply it
|
80 |
+
if self.use_weighted_masking:
|
81 |
+
out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
|
82 |
+
out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
|
83 |
+
out_weights /= ys.size(0) * ys.size(2)
|
84 |
+
duration_masks = make_non_pad_mask(ilens).to(ys.device)
|
85 |
+
duration_weights = (duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float())
|
86 |
+
duration_weights /= ds.size(0)
|
87 |
+
|
88 |
+
# apply weight
|
89 |
+
l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
|
90 |
+
duration_loss = (duration_loss.mul(duration_weights).masked_select(duration_masks).sum())
|
91 |
+
pitch_masks = duration_masks.unsqueeze(-1)
|
92 |
+
pitch_weights = duration_weights.unsqueeze(-1)
|
93 |
+
pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_masks).sum()
|
94 |
+
energy_loss = (energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum())
|
95 |
+
|
96 |
+
return l1_loss, duration_loss, pitch_loss, energy_loss
|
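The weighted-masking branch above normalizes each sequence by its own length so that short and long utterances contribute equally to the batch loss. The following standalone sketch reproduces the weighting for the L1 term only, with make_non_pad_mask replaced by an explicit mask built from the lengths.

import torch

B, Lmax, odim = 2, 4, 3
olens = torch.tensor([4, 2])  # true lengths of the two spectrograms

before_outs = torch.randn(B, Lmax, odim)
ys = torch.randn(B, Lmax, odim)

# non-padding mask with shape (B, Lmax, 1)
out_masks = (torch.arange(Lmax)[None, :] < olens[:, None]).unsqueeze(-1)

# each sequence is weighted by 1 / its own length, then by 1 / (B * odim)
out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
out_weights /= ys.size(0) * ys.size(2)

l1 = torch.nn.L1Loss(reduction="none")(before_outs, ys)
l1_loss = l1.mul(out_weights).masked_select(out_masks).sum()
print(l1_loss)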
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/FastSpeechDatasetLanguageID.py
ADDED
@@ -0,0 +1,217 @@
1 |
+
import os
|
2 |
+
import statistics
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
|
9 |
+
from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor
|
10 |
+
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
|
11 |
+
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.AlignerDataset import AlignerDataset
|
12 |
+
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.DurationCalculator import DurationCalculator
|
13 |
+
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.EnergyCalculator import EnergyCalculator
|
14 |
+
from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.PitchCalculator import Dio
|
15 |
+
|
16 |
+
|
17 |
+
class FastSpeechDataset(Dataset):
|
18 |
+
|
19 |
+
def __init__(self,
|
20 |
+
path_to_transcript_dict,
|
21 |
+
acoustic_checkpoint_path,
|
22 |
+
cache_dir,
|
23 |
+
lang,
|
24 |
+
loading_processes=40,
|
25 |
+
min_len_in_seconds=1,
|
26 |
+
max_len_in_seconds=20,
|
27 |
+
cut_silence=False,
|
28 |
+
reduction_factor=1,
|
29 |
+
device=torch.device("cpu"),
|
30 |
+
rebuild_cache=False,
|
31 |
+
ctc_selection=True,
|
32 |
+
save_imgs=False):
|
33 |
+
self.cache_dir = cache_dir
|
34 |
+
os.makedirs(cache_dir, exist_ok=True)
|
35 |
+
if not os.path.exists(os.path.join(cache_dir, "fast_train_cache.pt")) or rebuild_cache:
|
36 |
+
if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
|
37 |
+
AlignerDataset(path_to_transcript_dict=path_to_transcript_dict,
|
38 |
+
cache_dir=cache_dir,
|
39 |
+
lang=lang,
|
40 |
+
loading_processes=loading_processes,
|
41 |
+
min_len_in_seconds=min_len_in_seconds,
|
42 |
+
max_len_in_seconds=max_len_in_seconds,
|
43 |
+
cut_silences=cut_silence,
|
44 |
+
rebuild_cache=rebuild_cache,
|
45 |
+
device=device)
|
46 |
+
datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
|
47 |
+
# we use the aligner dataset as basis and augment it to contain the additional information we need for fastspeech.
|
48 |
+
if not isinstance(datapoints, tuple): # check for backwards compatibility
|
49 |
+
print(f"It seems like the Aligner dataset in {cache_dir} is not a tuple. Regenerating it, since we need the preprocessed waves.")
|
50 |
+
AlignerDataset(path_to_transcript_dict=path_to_transcript_dict,
|
51 |
+
cache_dir=cache_dir,
|
52 |
+
lang=lang,
|
53 |
+
loading_processes=loading_processes,
|
54 |
+
min_len_in_seconds=min_len_in_seconds,
|
55 |
+
max_len_in_seconds=max_len_in_seconds,
|
56 |
+
cut_silences=cut_silence,
|
57 |
+
rebuild_cache=True)
|
58 |
+
datapoints = torch.load(os.path.join(cache_dir, "aligner_train_cache.pt"), map_location='cpu')
|
59 |
+
dataset = datapoints[0]
|
60 |
+
norm_waves = datapoints[1]
|
61 |
+
|
62 |
+
# build cache
|
63 |
+
print("... building dataset cache ...")
|
64 |
+
self.datapoints = list()
|
65 |
+
self.ctc_losses = list()
|
66 |
+
|
67 |
+
acoustic_model = Aligner()
|
68 |
+
acoustic_model.load_state_dict(torch.load(acoustic_checkpoint_path, map_location='cpu')["asr_model"])
|
69 |
+
|
70 |
+
# ==========================================
|
71 |
+
# actual creation of datapoints starts here
|
72 |
+
# ==========================================
|
73 |
+
|
74 |
+
acoustic_model = acoustic_model.to(device)
|
75 |
+
dio = Dio(reduction_factor=reduction_factor, fs=16000)
|
76 |
+
energy_calc = EnergyCalculator(reduction_factor=reduction_factor, fs=16000)
|
77 |
+
dc = DurationCalculator(reduction_factor=reduction_factor)
|
78 |
+
vis_dir = os.path.join(cache_dir, "duration_vis")
|
79 |
+
os.makedirs(vis_dir, exist_ok=True)
|
80 |
+
pros_cond_ext = ProsodicConditionExtractor(sr=16000, device=device)
|
81 |
+
|
82 |
+
for index in tqdm(range(len(dataset))):
|
83 |
+
norm_wave = norm_waves[index]
|
84 |
+
norm_wave_length = torch.LongTensor([len(norm_wave)])
|
85 |
+
|
86 |
+
if len(norm_wave) / 16000 < min_len_in_seconds and ctc_selection:
|
87 |
+
continue
|
88 |
+
|
89 |
+
text = dataset[index][0]
|
90 |
+
melspec = dataset[index][2]
|
91 |
+
melspec_length = dataset[index][3]
|
92 |
+
|
93 |
+
alignment_path, ctc_loss = acoustic_model.inference(mel=melspec.to(device),
|
94 |
+
tokens=text.to(device),
|
95 |
+
save_img_for_debug=os.path.join(vis_dir, f"{index}.png") if save_imgs else None,
|
96 |
+
return_ctc=True)
|
97 |
+
|
98 |
+
cached_duration = dc(torch.LongTensor(alignment_path), vis=None).cpu()
|
99 |
+
|
100 |
+
last_vec = None
|
101 |
+
for phoneme_index, vec in enumerate(text):
|
102 |
+
if last_vec is not None:
|
103 |
+
if last_vec.numpy().tolist() == vec.numpy().tolist():
|
104 |
+
# we found a case of repeating phonemes!
|
105 |
+
# now we must repair their durations by giving the first one 3/5 of their sum and the second one 2/5 (i.e. the rest)
|
106 |
+
dur_1 = cached_duration[phoneme_index - 1]
|
107 |
+
dur_2 = cached_duration[phoneme_index]
|
108 |
+
total_dur = dur_1 + dur_2
|
109 |
+
new_dur_1 = int((total_dur / 5) * 3)
|
110 |
+
new_dur_2 = total_dur - new_dur_1
|
111 |
+
cached_duration[phoneme_index - 1] = new_dur_1
|
112 |
+
cached_duration[phoneme_index] = new_dur_2
|
113 |
+
last_vec = vec
|
114 |
+
|
115 |
+
cached_energy = energy_calc(input_waves=norm_wave.unsqueeze(0),
|
116 |
+
input_waves_lengths=norm_wave_length,
|
117 |
+
feats_lengths=melspec_length,
|
118 |
+
durations=cached_duration.unsqueeze(0),
|
119 |
+
durations_lengths=torch.LongTensor([len(cached_duration)]))[0].squeeze(0).cpu()
|
120 |
+
|
121 |
+
cached_pitch = dio(input_waves=norm_wave.unsqueeze(0),
|
122 |
+
input_waves_lengths=norm_wave_length,
|
123 |
+
feats_lengths=melspec_length,
|
124 |
+
durations=cached_duration.unsqueeze(0),
|
125 |
+
durations_lengths=torch.LongTensor([len(cached_duration)]))[0].squeeze(0).cpu()
|
126 |
+
|
127 |
+
try:
|
128 |
+
prosodic_condition = pros_cond_ext.extract_condition_from_reference_wave(norm_wave, already_normalized=True).cpu()
|
129 |
+
except RuntimeError:
|
130 |
+
# if there is an audio without any voiced segments whatsoever we have to skip it.
|
131 |
+
continue
|
132 |
+
|
133 |
+
self.datapoints.append([dataset[index][0],
|
134 |
+
dataset[index][1],
|
135 |
+
dataset[index][2],
|
136 |
+
dataset[index][3],
|
137 |
+
cached_duration.cpu(),
|
138 |
+
cached_energy,
|
139 |
+
cached_pitch,
|
140 |
+
prosodic_condition])
|
141 |
+
self.ctc_losses.append(ctc_loss)
|
142 |
+
|
143 |
+
# =============================
|
144 |
+
# done with datapoint creation
|
145 |
+
# =============================
|
146 |
+
|
147 |
+
if ctc_selection:
|
148 |
+
# now we can filter out some bad datapoints based on the CTC scores we collected
|
149 |
+
mean_ctc = sum(self.ctc_losses) / len(self.ctc_losses)
|
150 |
+
std_dev = statistics.stdev(self.ctc_losses)
|
151 |
+
threshold = mean_ctc + std_dev
|
152 |
+
for index in range(len(self.ctc_losses), 0, -1):
|
153 |
+
if self.ctc_losses[index - 1] > threshold:
|
154 |
+
self.datapoints.pop(index - 1)
|
155 |
+
print(
|
156 |
+
f"Removing datapoint {index - 1}, because the CTC loss is one standard deviation higher than the mean. \n ctc: {round(self.ctc_losses[index - 1], 4)} vs. mean: {round(mean_ctc, 4)}")
|
157 |
+
|
158 |
+
# save to cache
|
159 |
+
if len(self.datapoints) > 0:
|
160 |
+
torch.save(self.datapoints, os.path.join(cache_dir, "fast_train_cache.pt"))
|
161 |
+
else:
|
162 |
+
import sys
|
163 |
+
print("No datapoints were prepared! Exiting...")
|
164 |
+
sys.exit()
|
165 |
+
else:
|
166 |
+
# just load the datapoints from cache
|
167 |
+
self.datapoints = torch.load(os.path.join(cache_dir, "fast_train_cache.pt"), map_location='cpu')
|
168 |
+
|
169 |
+
self.cache_dir = cache_dir
|
170 |
+
self.language_id = get_language_id(lang)
|
171 |
+
print(f"Prepared a FastSpeech dataset with {len(self.datapoints)} datapoints in {cache_dir}.")
|
172 |
+
|
173 |
+
def __getitem__(self, index):
|
174 |
+
return self.datapoints[index][0], \
|
175 |
+
self.datapoints[index][1], \
|
176 |
+
self.datapoints[index][2], \
|
177 |
+
self.datapoints[index][3], \
|
178 |
+
self.datapoints[index][4], \
|
179 |
+
self.datapoints[index][5], \
|
180 |
+
self.datapoints[index][6], \
|
181 |
+
self.datapoints[index][7], \
|
182 |
+
self.language_id
|
183 |
+
|
184 |
+
def __len__(self):
|
185 |
+
return len(self.datapoints)
|
186 |
+
|
187 |
+
def remove_samples(self, list_of_samples_to_remove):
|
188 |
+
for remove_id in sorted(list_of_samples_to_remove, reverse=True):
|
189 |
+
self.datapoints.pop(remove_id)
|
190 |
+
torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
|
191 |
+
print("Dataset updated!")
|
192 |
+
|
193 |
+
def fix_repeating_phones(self):
|
194 |
+
"""
|
195 |
+
The Viterbi decoding of the durations cannot
|
196 |
+
handle repetitions. This is now solved heuristically,
|
197 |
+
but if you have a cache from before March 2022,
|
198 |
+
use this method to postprocess those cases.
|
199 |
+
"""
|
200 |
+
for datapoint_index in tqdm(list(range(len(self.datapoints)))):
|
201 |
+
last_vec = None
|
202 |
+
for phoneme_index, vec in enumerate(self.datapoints[datapoint_index][0]):
|
203 |
+
if last_vec is not None:
|
204 |
+
if last_vec.numpy().tolist() == vec.numpy().tolist():
|
205 |
+
# we found a case of repeating phonemes!
|
206 |
+
# now we must repair their durations by giving the first one 3/5 of their sum and the second one 2/5 (i.e. the rest)
|
207 |
+
dur_1 = self.datapoints[datapoint_index][4][phoneme_index - 1]
|
208 |
+
dur_2 = self.datapoints[datapoint_index][4][phoneme_index]
|
209 |
+
total_dur = dur_1 + dur_2
|
210 |
+
new_dur_1 = int((total_dur / 5) * 3)
|
211 |
+
new_dur_2 = total_dur - new_dur_1
|
212 |
+
self.datapoints[datapoint_index][4][phoneme_index - 1] = new_dur_1
|
213 |
+
self.datapoints[datapoint_index][4][phoneme_index] = new_dur_2
|
214 |
+
print("fix applied")
|
215 |
+
last_vec = vec
|
216 |
+
torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
|
217 |
+
print("Dataset updated!")
|
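The repair heuristic for repeated phonemes appears twice above, once at cache-creation time and once in fix_repeating_phones. It boils down to redistributing the pooled duration of two identical neighbours in a 3:2 ratio; the helper below is a standalone sketch of that idea, not a function from the codebase.

import torch

def repair_repeated_phoneme_durations(phoneme_vectors, durations):
    # phoneme_vectors: (T, feature_dim) articulatory vectors, durations: (T,) integer frame counts
    last_vec = None
    for index, vec in enumerate(phoneme_vectors):
        if last_vec is not None and torch.equal(last_vec, vec):
            # two identical phonemes in a row: give the first 3/5 of the joint duration, the second the rest
            total_dur = durations[index - 1] + durations[index]
            durations[index - 1] = int((total_dur / 5) * 3)
            durations[index] = total_dur - durations[index - 1]
        last_vec = vec
    return durations

print(repair_repeated_phoneme_durations(torch.tensor([[1., 0.], [1., 0.], [0., 1.]]),
                                         torch.tensor([4, 6, 3])))  # tensor([6, 4, 3])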
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/PitchCalculator.py
ADDED
@@ -0,0 +1,121 @@
1 |
+
# Copyright 2020 Nagoya University (Tomoki Hayashi)
|
2 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
# Adapted by Florian Lux 2021
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import pyworld
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from scipy.interpolate import interp1d
|
10 |
+
|
11 |
+
from Utility.utils import pad_list
|
12 |
+
|
13 |
+
|
14 |
+
class Dio(torch.nn.Module):
|
15 |
+
"""
|
16 |
+
F0 estimation with the dio + stonemask algorithm.
|
17 |
+
This is an F0 extractor based on the dio + stonemask algorithm
|
18 |
+
introduced in https://doi.org/10.1587/transinf.2015EDP7457
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, fs=16000, n_fft=1024, hop_length=256, f0min=40, f0max=400, use_token_averaged_f0=True,
|
22 |
+
use_continuous_f0=True, use_log_f0=True, reduction_factor=1):
|
23 |
+
super().__init__()
|
24 |
+
self.fs = fs
|
25 |
+
self.n_fft = n_fft
|
26 |
+
self.hop_length = hop_length
|
27 |
+
self.frame_period = 1000 * hop_length / fs
|
28 |
+
self.f0min = f0min
|
29 |
+
self.f0max = f0max
|
30 |
+
self.use_token_averaged_f0 = use_token_averaged_f0
|
31 |
+
self.use_continuous_f0 = use_continuous_f0
|
32 |
+
self.use_log_f0 = use_log_f0
|
33 |
+
if use_token_averaged_f0:
|
34 |
+
assert reduction_factor >= 1
|
35 |
+
self.reduction_factor = reduction_factor
|
36 |
+
|
37 |
+
def output_size(self):
|
38 |
+
return 1
|
39 |
+
|
40 |
+
def get_parameters(self):
|
41 |
+
return dict(fs=self.fs, n_fft=self.n_fft, hop_length=self.hop_length, f0min=self.f0min, f0max=self.f0max,
|
42 |
+
use_token_averaged_f0=self.use_token_averaged_f0, use_continuous_f0=self.use_continuous_f0, use_log_f0=self.use_log_f0,
|
43 |
+
reduction_factor=self.reduction_factor)
|
44 |
+
|
45 |
+
def forward(self, input_waves, input_waves_lengths=None, feats_lengths=None, durations=None,
|
46 |
+
durations_lengths=None, norm_by_average=True):
|
47 |
+
# If not provided, we assume that the inputs have the same length
|
48 |
+
if input_waves_lengths is None:
|
49 |
+
input_waves_lengths = (input_waves.new_ones(input_waves.shape[0], dtype=torch.long) * input_waves.shape[1])
|
50 |
+
|
51 |
+
# F0 extraction
|
52 |
+
pitch = [self._calculate_f0(x[:xl]) for x, xl in zip(input_waves, input_waves_lengths)]
|
53 |
+
|
54 |
+
# (Optional): Adjust length to match with the mel-spectrogram
|
55 |
+
if feats_lengths is not None:
|
56 |
+
pitch = [self._adjust_num_frames(p, fl).view(-1) for p, fl in zip(pitch, feats_lengths)]
|
57 |
+
|
58 |
+
# (Optional): Average by duration to calculate token-wise f0
|
59 |
+
if self.use_token_averaged_f0:
|
60 |
+
pitch = [self._average_by_duration(p, d).view(-1) for p, d in zip(pitch, durations)]
|
61 |
+
pitch_lengths = durations_lengths
|
62 |
+
else:
|
63 |
+
pitch_lengths = input_waves.new_tensor([len(p) for p in pitch], dtype=torch.long)
|
64 |
+
|
65 |
+
# Padding
|
66 |
+
pitch = pad_list(pitch, 0.0)
|
67 |
+
|
68 |
+
# Return with the shape (B, T, 1)
|
69 |
+
if norm_by_average:
|
70 |
+
average = pitch[0][pitch[0] != 0.0].mean()
|
71 |
+
pitch = pitch / average
|
72 |
+
return pitch.unsqueeze(-1), pitch_lengths
|
73 |
+
|
74 |
+
def _calculate_f0(self, input):
|
75 |
+
x = input.cpu().numpy().astype(np.double)
|
76 |
+
f0, timeaxis = pyworld.dio(x, self.fs, f0_floor=self.f0min, f0_ceil=self.f0max, frame_period=self.frame_period)
|
77 |
+
f0 = pyworld.stonemask(x, f0, timeaxis, self.fs)
|
78 |
+
if self.use_continuous_f0:
|
79 |
+
f0 = self._convert_to_continuous_f0(f0)
|
80 |
+
if self.use_log_f0:
|
81 |
+
nonzero_idxs = np.where(f0 != 0)[0]
|
82 |
+
f0[nonzero_idxs] = np.log(f0[nonzero_idxs])
|
83 |
+
return input.new_tensor(f0.reshape(-1), dtype=torch.float)
|
84 |
+
|
85 |
+
@staticmethod
|
86 |
+
def _adjust_num_frames(x, num_frames):
|
87 |
+
if num_frames > len(x):
|
88 |
+
x = F.pad(x, (0, num_frames - len(x)))
|
89 |
+
elif num_frames < len(x):
|
90 |
+
x = x[:num_frames]
|
91 |
+
return x
|
92 |
+
|
93 |
+
@staticmethod
|
94 |
+
def _convert_to_continuous_f0(f0: np.array):
|
95 |
+
if (f0 == 0).all():
|
96 |
+
return f0
|
97 |
+
|
98 |
+
# padding start and end of f0 sequence
|
99 |
+
start_f0 = f0[f0 != 0][0]
|
100 |
+
end_f0 = f0[f0 != 0][-1]
|
101 |
+
start_idx = np.where(f0 == start_f0)[0][0]
|
102 |
+
end_idx = np.where(f0 == end_f0)[0][-1]
|
103 |
+
f0[:start_idx] = start_f0
|
104 |
+
f0[end_idx:] = end_f0
|
105 |
+
|
106 |
+
# get non-zero frame index
|
107 |
+
nonzero_idxs = np.where(f0 != 0)[0]
|
108 |
+
|
109 |
+
# perform linear interpolation
|
110 |
+
interp_fn = interp1d(nonzero_idxs, f0[nonzero_idxs])
|
111 |
+
f0 = interp_fn(np.arange(0, f0.shape[0]))
|
112 |
+
|
113 |
+
return f0
|
114 |
+
|
115 |
+
def _average_by_duration(self, x, d):
|
116 |
+
assert 0 <= len(x) - d.sum() < self.reduction_factor
|
117 |
+
d_cumsum = F.pad(d.cumsum(dim=0), (1, 0))
|
118 |
+
x_avg = [
|
119 |
+
x[start:end].masked_select(x[start:end].gt(0.0)).mean(dim=0) if len(x[start:end].masked_select(x[start:end].gt(0.0))) != 0 else x.new_tensor(0.0)
|
120 |
+
for start, end in zip(d_cumsum[:-1], d_cumsum[1:])]
|
121 |
+
return torch.stack(x_avg)
|
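The raw F0 extraction inside _calculate_f0 rests on two pyworld calls: a coarse dio track that is then refined with stonemask. A minimal standalone sketch on a synthetic tone, with parameters mirroring the defaults above:

import numpy as np
import pyworld

fs = 16000
t = np.arange(fs) / fs
wave = 0.5 * np.sin(2 * np.pi * 120.0 * t)  # one second of a 120 Hz tone

# coarse F0 track with dio, then refinement with stonemask
f0, timeaxis = pyworld.dio(wave.astype(np.double), fs, f0_floor=40, f0_ceil=400,
                           frame_period=1000 * 256 / fs)
f0 = pyworld.stonemask(wave.astype(np.double), f0, timeaxis, fs)
print(f0.shape, f0[f0 > 0].mean())  # roughly 120 Hz in the voiced frames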
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/__init__.py
ADDED
File without changes
|
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/fastspeech2_train_loop.py
ADDED
@@ -0,0 +1,201 @@
1 |
+
import os
|
2 |
+
import time
|
3 |
+
|
4 |
+
import librosa.display as lbd
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import torch
|
7 |
+
import torch.multiprocessing
|
8 |
+
import torch.multiprocessing
|
9 |
+
from torch.cuda.amp import GradScaler
|
10 |
+
from torch.cuda.amp import autocast
|
11 |
+
from torch.nn.utils.rnn import pad_sequence
|
12 |
+
from torch.utils.data.dataloader import DataLoader
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
|
16 |
+
from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
|
17 |
+
from Utility.WarmupScheduler import WarmupScheduler
|
18 |
+
from Utility.utils import cumsum_durations
|
19 |
+
from Utility.utils import delete_old_checkpoints
|
20 |
+
from Utility.utils import get_most_recent_checkpoint
|
21 |
+
|
22 |
+
|
23 |
+
@torch.no_grad()
|
24 |
+
def plot_progress_spec(net, device, save_dir, step, lang, default_emb):
|
25 |
+
tf = ArticulatoryCombinedTextFrontend(language=lang)
|
26 |
+
sentence = ""
|
27 |
+
if lang == "en":
|
28 |
+
sentence = "This is a complex sentence, it even has a pause!"
|
29 |
+
elif lang == "de":
|
30 |
+
sentence = "Dies ist ein komplexer Satz, er hat sogar eine Pause!"
|
31 |
+
elif lang == "el":
|
32 |
+
sentence = "Αυτή είναι μια σύνθετη πρόταση, έχει ακόμη και παύση!"
|
33 |
+
elif lang == "es":
|
34 |
+
sentence = "Esta es una oración compleja, ¡incluso tiene una pausa!"
|
35 |
+
elif lang == "fi":
|
36 |
+
sentence = "Tämä on monimutkainen lause, sillä on jopa tauko!"
|
37 |
+
elif lang == "ru":
|
38 |
+
sentence = "Это сложное предложение, в нем даже есть пауза!"
|
39 |
+
elif lang == "hu":
|
40 |
+
sentence = "Ez egy összetett mondat, még szünet is van benne!"
|
41 |
+
elif lang == "nl":
|
42 |
+
sentence = "Dit is een complexe zin, er zit zelfs een pauze in!"
|
43 |
+
elif lang == "fr":
|
44 |
+
sentence = "C'est une phrase complexe, elle a même une pause !"
|
45 |
+
phoneme_vector = tf.string_to_tensor(sentence).squeeze(0).to(device)
|
46 |
+
spec, durations, *_ = net.inference(text=phoneme_vector,
|
47 |
+
return_duration_pitch_energy=True,
|
48 |
+
utterance_embedding=default_emb,
|
49 |
+
lang_id=get_language_id(lang).to(device))
|
50 |
+
spec = spec.transpose(0, 1).to("cpu").numpy()
|
51 |
+
duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
|
52 |
+
if not os.path.exists(os.path.join(save_dir, "spec")):
|
53 |
+
os.makedirs(os.path.join(save_dir, "spec"))
|
54 |
+
fig, ax = plt.subplots(nrows=1, ncols=1)
|
55 |
+
lbd.specshow(spec,
|
56 |
+
ax=ax,
|
57 |
+
sr=16000,
|
58 |
+
cmap='GnBu',
|
59 |
+
y_axis='mel',
|
60 |
+
x_axis=None,
|
61 |
+
                 hop_length=256)
    ax.yaxis.set_visible(False)
    ax.set_xticks(duration_splits, minor=True)
    ax.xaxis.grid(True, which='minor')
    ax.set_xticks(label_positions, minor=False)
    ax.set_xticklabels(tf.get_phone_string(sentence))
    ax.set_title(sentence)
    plt.savefig(os.path.join(os.path.join(save_dir, "spec"), str(step) + ".png"))
    plt.clf()
    plt.close()


def collate_and_pad(batch):
    # text, text_len, speech, speech_len, durations, energy, pitch, utterance condition, language_id
    return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True),
            torch.stack([datapoint[1] for datapoint in batch]).squeeze(1),
            pad_sequence([datapoint[2] for datapoint in batch], batch_first=True),
            torch.stack([datapoint[3] for datapoint in batch]).squeeze(1),
            pad_sequence([datapoint[4] for datapoint in batch], batch_first=True),
            pad_sequence([datapoint[5] for datapoint in batch], batch_first=True),
            pad_sequence([datapoint[6] for datapoint in batch], batch_first=True),
            torch.stack([datapoint[7] for datapoint in batch]).squeeze(),
            torch.stack([datapoint[8] for datapoint in batch]))


def train_loop(net,
               train_dataset,
               device,
               save_directory,
               batch_size=32,
               steps=300000,
               epochs_per_save=1,
               lang="en",
               lr=0.0001,
               warmup_steps=4000,
               path_to_checkpoint=None,
               fine_tune=False,
               resume=False):
    """
    Args:
        resume: whether to resume from the most recent checkpoint
        warmup_steps: how long the learning rate should increase before it reaches the specified value
        steps: how many steps to train
        lr: the initial learning rate for the optimiser
        path_to_checkpoint: reloads a checkpoint to continue training from there
        fine_tune: whether to load everything from a checkpoint, or only the model parameters
        lang: language of the synthesis
        net: model to train
        train_dataset: PyTorch Dataset object for the training data
        device: device to put the loaded tensors on
        save_directory: where to save the checkpoints
        batch_size: how many elements should be loaded at once
        epochs_per_save: how many epochs to train in between checkpoints
    """
    net = net.to(device)

    torch.multiprocessing.set_sharing_strategy('file_system')
    train_loader = DataLoader(batch_size=batch_size,
                              dataset=train_dataset,
                              drop_last=True,
                              num_workers=8,
                              pin_memory=True,
                              shuffle=True,
                              prefetch_factor=8,
                              collate_fn=collate_and_pad,
                              persistent_workers=True)
    default_embedding = None
    for index in range(20):  # slicing is not implemented for datasets, so this detour is needed
        if default_embedding is None:
            default_embedding = train_dataset[index][7].squeeze()
        else:
            default_embedding = default_embedding + train_dataset[index][7].squeeze()
    default_embedding = (default_embedding / 20).to(device)  # divide by the number of embeddings that were actually summed
    # The default speaker embedding for inference is the average of the first 20 utterance embeddings.
    # So if you train on multiple datasets combined, put a single-speaker dataset with the nicest voice
    # first into the concatenated dataset.
    step_counter = 0
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    scheduler = WarmupScheduler(optimizer, warmup_steps=warmup_steps)
    scaler = GradScaler()
    epoch = 0
    if resume:
        path_to_checkpoint = get_most_recent_checkpoint(checkpoint_dir=save_directory)
    if path_to_checkpoint is not None:
        check_dict = torch.load(path_to_checkpoint, map_location=device)
        net.load_state_dict(check_dict["model"])
        if not fine_tune:
            optimizer.load_state_dict(check_dict["optimizer"])
            scheduler.load_state_dict(check_dict["scheduler"])
            step_counter = check_dict["step_counter"]
            scaler.load_state_dict(check_dict["scaler"])
    start_time = time.time()
    while True:
        net.train()
        epoch += 1
        optimizer.zero_grad()
        train_losses_this_epoch = list()
        for batch in tqdm(train_loader):
            with autocast():
                train_loss = net(text_tensors=batch[0].to(device),
                                 text_lengths=batch[1].to(device),
                                 gold_speech=batch[2].to(device),
                                 speech_lengths=batch[3].to(device),
                                 gold_durations=batch[4].to(device),
                                 gold_pitch=batch[6].to(device),  # mind the switched order
                                 gold_energy=batch[5].to(device),  # mind the switched order
                                 utterance_embedding=batch[7].to(device),
                                 lang_ids=batch[8].to(device),
                                 return_mels=False)
                train_losses_this_epoch.append(train_loss.item())

            optimizer.zero_grad()
            scaler.scale(train_loss).backward()
            del train_loss
            step_counter += 1
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0, error_if_nonfinite=False)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        net.eval()
        if epoch % epochs_per_save == 0:
            torch.save({
                "model"       : net.state_dict(),
                "optimizer"   : optimizer.state_dict(),
                "step_counter": step_counter,
                "scaler"      : scaler.state_dict(),
                "scheduler"   : scheduler.state_dict(),
                "default_emb" : default_embedding,
                }, os.path.join(save_directory, "checkpoint_{}.pt".format(step_counter)))
            delete_old_checkpoints(save_directory, keep=5)
            plot_progress_spec(net, device, save_dir=save_directory, step=step_counter, lang=lang, default_emb=default_embedding)
            if step_counter > steps:
                # DONE
                return
        print("Epoch: {}".format(epoch))
        print("Train Loss: {}".format(sum(train_losses_this_epoch) / len(train_losses_this_epoch)))
        print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
        print("Steps: {}".format(step_counter))
        net.train()
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/fastspeech2_train_loop_ctc.py
ADDED
@@ -0,0 +1,191 @@
import os
import random
import time

import librosa.display as lbd
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing
from torch.cuda.amp import GradScaler
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
from Utility.WarmupScheduler import WarmupScheduler
from Utility.utils import cumsum_durations
from Utility.utils import delete_old_checkpoints
from Utility.utils import get_most_recent_checkpoint


def plot_progress_spec(net, device, save_dir, step, lang):
    tf = ArticulatoryCombinedTextFrontend(language=lang)
    sentence = ""
    if lang == "en":
        sentence = "This is a complex sentence, it even has a pause!"
    elif lang == "de":
        sentence = "Dies ist ein komplexer Satz, er hat sogar eine Pause!"
    elif lang == "el":
        sentence = "Αυτή είναι μια σύνθετη πρόταση, έχει ακόμη και παύση!"
    elif lang == "es":
        sentence = "Esta es una oración compleja, ¡incluso tiene una pausa!"
    elif lang == "fi":
        sentence = "Tämä on monimutkainen lause, sillä on jopa tauko!"
    elif lang == "ru":
        sentence = "Это сложное предложение, в нем даже есть пауза!"
    elif lang == "hu":
        sentence = "Ez egy összetett mondat, még szünet is van benne!"
    elif lang == "nl":
        sentence = "Dit is een complexe zin, er zit zelfs een pauze in!"
    elif lang == "fr":
        sentence = "C'est une phrase complexe, elle a même une pause !"
    phoneme_vector = tf.string_to_tensor(sentence).squeeze(0).to(device)
    spec, durations, *_ = net.inference(text=phoneme_vector, return_duration_pitch_energy=True)
    spec = spec.transpose(0, 1).to("cpu").numpy()
    duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
    if not os.path.exists(os.path.join(save_dir, "spec")):
        os.makedirs(os.path.join(save_dir, "spec"))
    fig, ax = plt.subplots(nrows=1, ncols=1)
    lbd.specshow(spec,
                 ax=ax,
                 sr=16000,
                 cmap='GnBu',
                 y_axis='mel',
                 x_axis=None,
                 hop_length=256)
    ax.yaxis.set_visible(False)
    ax.set_xticks(duration_splits, minor=True)
    ax.xaxis.grid(True, which='minor')
    ax.set_xticks(label_positions, minor=False)
    ax.set_xticklabels(tf.get_phone_string(sentence))
    ax.set_title(sentence)
    plt.savefig(os.path.join(os.path.join(save_dir, "spec"), str(step) + ".png"))
    plt.clf()
    plt.close()


def train_loop(net,
               train_sentences,
               device,
               save_directory,
               aligner_checkpoint,
               batch_size=32,
               steps=300000,
               epochs_per_save=5,
               lang="en",
               lr=0.0001,
               warmup_steps=4000,
               path_to_checkpoint=None,
               fine_tune=False,
               resume=False):
    """
    Args:
        resume: whether to resume from the most recent checkpoint
        warmup_steps: how long the learning rate should increase before it reaches the specified value
        steps: how many steps to train
        lr: the initial learning rate for the optimiser
        path_to_checkpoint: reloads a checkpoint to continue training from there
        fine_tune: whether to load everything from a checkpoint, or only the model parameters
        lang: language of the synthesis and of the train sentences
        net: model to train
        train_sentences: list of (string) sentences the CTC objective should be learned on
        device: device to put the loaded tensors on
        save_directory: where to save the checkpoints
        batch_size: how many elements should be loaded at once
        epochs_per_save: how many epochs to train in between checkpoints
    """
    net = net.to(device)

    torch.multiprocessing.set_sharing_strategy('file_system')
    text_to_art_vec = ArticulatoryCombinedTextFrontend(language=lang)
    asr_aligner = Aligner().to(device)
    check_dict = torch.load(os.path.join(aligner_checkpoint), map_location=device)
    asr_aligner.load_state_dict(check_dict["asr_model"])
    net.stop_gradient_from_energy_predictor = False
    net.stop_gradient_from_pitch_predictor = False
    step_counter = 0
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    scheduler = WarmupScheduler(optimizer, warmup_steps=warmup_steps)
    scaler = GradScaler()
    epoch = 0
    if resume:
        path_to_checkpoint = get_most_recent_checkpoint(checkpoint_dir=save_directory)
    if path_to_checkpoint is not None:
        check_dict = torch.load(path_to_checkpoint, map_location=device)
        net.load_state_dict(check_dict["model"])
        if not fine_tune:
            optimizer.load_state_dict(check_dict["optimizer"])
            scheduler.load_state_dict(check_dict["scheduler"])
            step_counter = check_dict["step_counter"]
            scaler.load_state_dict(check_dict["scaler"])
    start_time = time.time()
    while True:
        net.train()
        epoch += 1
        optimizer.zero_grad()
        train_losses_this_epoch = list()
        random.shuffle(train_sentences)
        batch_of_text_vecs = list()
        batch_of_tokens = list()

        for sentence in tqdm(train_sentences):
            if sentence.strip() == "":
                continue

            phonemes = text_to_art_vec.get_phone_string(sentence)
            # collect batch of texts
            batch_of_text_vecs.append(text_to_art_vec.string_to_tensor(phonemes, input_phonemes=True).squeeze(0).to(device))

            # collect batch of tokens
            tokens = list()
            for phone in phonemes:
                tokens.append(text_to_art_vec.phone_to_id[phone])
            tokens = torch.LongTensor(tokens).to(device)
            batch_of_tokens.append(tokens)

            if len(batch_of_tokens) == batch_size:
                token_batch = pad_sequence(batch_of_tokens, batch_first=True)
                token_lens = torch.LongTensor([len(x) for x in batch_of_tokens]).to(device)
                text_batch = pad_sequence(batch_of_text_vecs, batch_first=True)
                spec_batch, d_outs = net.batch_inference(texts=text_batch, text_lens=token_lens)
                spec_lens = torch.LongTensor([sum(x) for x in d_outs]).to(device)

                asr_pred = asr_aligner(spec_batch, spec_lens)
                train_loss = asr_aligner.ctc_loss(asr_pred.transpose(0, 1).log_softmax(2), token_batch, spec_lens, token_lens)
                train_losses_this_epoch.append(train_loss.item())

                optimizer.zero_grad()
                asr_aligner.zero_grad()
                scaler.scale(train_loss).backward()
                del train_loss
                step_counter += 1
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0, error_if_nonfinite=False)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                batch_of_tokens = list()
                batch_of_text_vecs = list()

        net.eval()
        if epoch % epochs_per_save == 0:
            torch.save({
                "model"       : net.state_dict(),
                "optimizer"   : optimizer.state_dict(),
                "step_counter": step_counter,
                "scaler"      : scaler.state_dict(),
                "scheduler"   : scheduler.state_dict(),
                }, os.path.join(save_directory, "checkpoint_{}.pt".format(step_counter)))
            delete_old_checkpoints(save_directory, keep=5)
            with torch.no_grad():
                plot_progress_spec(net, device, save_dir=save_directory, step=step_counter, lang=lang)
            if step_counter > steps:
                # DONE
                return
        print("Epoch: {}".format(epoch))
        print("Train Loss: {}".format(sum(train_losses_this_epoch) / len(train_losses_this_epoch)))
        print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
        print("Steps: {}".format(step_counter))
        net.train()
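
The asr_aligner.ctc_loss call above follows the standard torch.nn.CTCLoss contract: log-probabilities of shape (time, batch, classes) — hence the transpose(0, 1) and log_softmax(2) — plus the padded target tokens, the spectrogram frame lengths and the token lengths. The Aligner class is defined elsewhere in this commit; the self-contained sketch below uses made-up shapes and a plain torch.nn.CTCLoss (not the Aligner, whose blank index and reduction may differ) purely to illustrate that interface:

import torch

# hypothetical sizes, only for illustration
num_phone_ids = 66     # phoneme inventory size + blank
batch = 4
max_spec_frames = 120
max_token_len = 30

ctc = torch.nn.CTCLoss(blank=0, zero_infinity=True)

# (batch, time, classes) frame-wise predictions, as an ASR/aligner model would emit them
asr_pred = torch.randn(batch, max_spec_frames, num_phone_ids)
token_batch = torch.randint(low=1, high=num_phone_ids, size=(batch, max_token_len))
spec_lens = torch.full((batch,), max_spec_frames, dtype=torch.long)
token_lens = torch.full((batch,), max_token_len, dtype=torch.long)

# CTCLoss expects (time, batch, classes) log-probabilities, hence the transpose and
# log_softmax over the class dimension, exactly as in the train loop above
loss = ctc(asr_pred.transpose(0, 1).log_softmax(2), token_batch, spec_lens, token_lens)
print(loss.item())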
TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/meta_train_loop.py
ADDED
@@ -0,0 +1,211 @@
import os

import librosa.display as lbd
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from Preprocessing.ArticulatoryCombinedTextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id
from Utility.WarmupScheduler import WarmupScheduler
from Utility.path_to_transcript_dicts import *
from Utility.utils import cumsum_durations
from Utility.utils import delete_old_checkpoints
from Utility.utils import get_most_recent_checkpoint


def train_loop(net,
               datasets,
               device,
               save_directory,
               batch_size,
               steps,
               steps_per_checkpoint,
               lr,
               path_to_checkpoint,
               resume=False,
               warmup_steps=4000):
    # ============
    # Preparations
    # ============
    net = net.to(device)
    torch.multiprocessing.set_sharing_strategy('file_system')
    train_loaders = list()
    train_iters = list()
    for dataset in datasets:
        train_loaders.append(DataLoader(batch_size=batch_size,
                                        dataset=dataset,
                                        drop_last=True,
                                        num_workers=2,
                                        pin_memory=True,
                                        shuffle=True,
                                        prefetch_factor=5,
                                        collate_fn=collate_and_pad,
                                        persistent_workers=True))
        train_iters.append(iter(train_loaders[-1]))
    default_embeddings = {"en": None, "de": None, "el": None, "es": None, "fi": None, "ru": None, "hu": None, "nl": None, "fr": None}
    for index, lang in enumerate(["en", "de", "el", "es", "fi", "ru", "hu", "nl", "fr"]):
        default_embedding = None
        for datapoint in datasets[index]:
            if default_embedding is None:
                default_embedding = datapoint[7].squeeze()
            else:
                default_embedding = default_embedding + datapoint[7].squeeze()
        default_embeddings[lang] = (default_embedding / len(datasets[index])).to(device)
    optimizer = torch.optim.RAdam(net.parameters(), lr=lr, eps=1.0e-06, weight_decay=0.0)
    grad_scaler = GradScaler()
    scheduler = WarmupScheduler(optimizer, warmup_steps=warmup_steps)
    if resume:
        previous_checkpoint = get_most_recent_checkpoint(checkpoint_dir=save_directory)
        if previous_checkpoint is not None:
            path_to_checkpoint = previous_checkpoint
        else:
            raise RuntimeError(f"No checkpoint found that can be resumed from in {save_directory}")
    step_counter = 0
    train_losses_total = list()
    if path_to_checkpoint is not None:
        check_dict = torch.load(os.path.join(path_to_checkpoint), map_location=device)
        net.load_state_dict(check_dict["model"])
        if resume:
            optimizer.load_state_dict(check_dict["optimizer"])
            step_counter = check_dict["step_counter"]
            grad_scaler.load_state_dict(check_dict["scaler"])
            scheduler.load_state_dict(check_dict["scheduler"])
            if step_counter > steps:
                print("Desired steps already reached in loaded checkpoint.")
                return

    net.train()
    # =============================
    # Actual train loop starts here
    # =============================
    for step in tqdm(range(step_counter, steps)):
        batches = []
        for index in range(len(datasets)):
            # we get one batch for each task (i.e. language in this case)
            try:
                batch = next(train_iters[index])
                batches.append(batch)
            except StopIteration:
                train_iters[index] = iter(train_loaders[index])
                batch = next(train_iters[index])
                batches.append(batch)
        train_loss = 0.0
        for batch in batches:
            with autocast():
                # we sum the loss for each task, as we would do for the
                # second order regular MAML, but we do it only over one
                # step (i.e. iterations of inner loop = 1)
                train_loss = train_loss + net(text_tensors=batch[0].to(device),
                                              text_lengths=batch[1].to(device),
                                              gold_speech=batch[2].to(device),
                                              speech_lengths=batch[3].to(device),
                                              gold_durations=batch[4].to(device),
                                              gold_pitch=batch[6].to(device),  # mind the switched order
                                              gold_energy=batch[5].to(device),  # mind the switched order
                                              utterance_embedding=batch[7].to(device),
                                              lang_ids=batch[8].to(device),
                                              return_mels=False)
        # then we directly update our meta-parameters without
        # the need for any task specific parameters
        train_losses_total.append(train_loss.item())
        optimizer.zero_grad()
        grad_scaler.scale(train_loss).backward()
        grad_scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0, error_if_nonfinite=False)
        grad_scaler.step(optimizer)
        grad_scaler.update()
        scheduler.step()

        if step % steps_per_checkpoint == 0:
            # ==============================
            # Enough steps for some insights
            # ==============================
            net.eval()
            print(f"Total Loss: {round(sum(train_losses_total) / len(train_losses_total), 3)}")
            train_losses_total = list()
            torch.save({
                "model"       : net.state_dict(),
                "optimizer"   : optimizer.state_dict(),
                "scaler"      : grad_scaler.state_dict(),
                "scheduler"   : scheduler.state_dict(),
                "step_counter": step,
                "default_emb" : default_embeddings["en"]
                },
                os.path.join(save_directory, "checkpoint_{}.pt".format(step)))
            delete_old_checkpoints(save_directory, keep=5)
            for lang in ["en", "de", "el", "es", "fi", "ru", "hu", "nl", "fr"]:
                plot_progress_spec(net=net,
                                   device=device,
                                   lang=lang,
                                   save_dir=save_directory,
                                   step=step,
                                   utt_embeds=default_embeddings)
            net.train()


@torch.inference_mode()
def plot_progress_spec(net, device, save_dir, step, lang, utt_embeds):
    tf = ArticulatoryCombinedTextFrontend(language=lang)
    sentence = ""
    default_embed = utt_embeds[lang]
    if lang == "en":
        sentence = "This is a complex sentence, it even has a pause!"
    elif lang == "de":
        sentence = "Dies ist ein komplexer Satz, er hat sogar eine Pause!"
    elif lang == "el":
        sentence = "Αυτή είναι μια σύνθετη πρόταση, έχει ακόμη και παύση!"
    elif lang == "es":
        sentence = "Esta es una oración compleja, ¡incluso tiene una pausa!"
    elif lang == "fi":
        sentence = "Tämä on monimutkainen lause, sillä on jopa tauko!"
    elif lang == "ru":
        sentence = "Это сложное предложение, в нем даже есть пауза!"
    elif lang == "hu":
        sentence = "Ez egy összetett mondat, még szünet is van benne!"
    elif lang == "nl":
        sentence = "Dit is een complexe zin, er zit zelfs een pauze in!"
    elif lang == "fr":
        sentence = "C'est une phrase complexe, elle a même une pause !"
    phoneme_vector = tf.string_to_tensor(sentence).squeeze(0).to(device)
    spec, durations, *_ = net.inference(text=phoneme_vector,
                                        return_duration_pitch_energy=True,
                                        utterance_embedding=default_embed,
                                        lang_id=get_language_id(lang).to(device))
    spec = spec.transpose(0, 1).to("cpu").numpy()
    duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
    if not os.path.exists(os.path.join(save_dir, "spec")):
        os.makedirs(os.path.join(save_dir, "spec"))
    fig, ax = plt.subplots(nrows=1, ncols=1)
    lbd.specshow(spec,
                 ax=ax,
                 sr=16000,
                 cmap='GnBu',
                 y_axis='mel',
                 x_axis=None,
                 hop_length=256)
    ax.yaxis.set_visible(False)
    ax.set_xticks(duration_splits, minor=True)
    ax.xaxis.grid(True, which='minor')
    ax.set_xticks(label_positions, minor=False)
    ax.set_xticklabels(tf.get_phone_string(sentence))
    ax.set_title(sentence)
    plt.savefig(os.path.join(os.path.join(save_dir, "spec"), f"{step}_{lang}.png"))
    plt.clf()
    plt.close()


def collate_and_pad(batch):
    # text, text_len, speech, speech_len, durations, energy, pitch, utterance condition, language_id
    return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True),
            torch.stack([datapoint[1] for datapoint in batch]).squeeze(1),
            pad_sequence([datapoint[2] for datapoint in batch], batch_first=True),
            torch.stack([datapoint[3] for datapoint in batch]).squeeze(1),
            pad_sequence([datapoint[4] for datapoint in batch], batch_first=True),
            pad_sequence([datapoint[5] for datapoint in batch], batch_first=True),
            pad_sequence([datapoint[6] for datapoint in batch], batch_first=True),
            torch.stack([datapoint[7] for datapoint in batch]).squeeze(),
            torch.stack([datapoint[8] for datapoint in batch]))
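
The comments in this loop describe the language-agnostic meta-learning recipe: draw one batch per language, sum the losses, and apply a single update directly to the shared parameters, i.e. an inner loop of length one with no task-specific fast weights. The toy sketch below (a made-up regression model and random tensors standing in for the per-language TTS batches, not the FastSpeech2 model from this commit) shows just that update pattern:

import torch

torch.manual_seed(0)

# toy stand-ins: one tiny "dataset" per task (language)
tasks = [torch.randn(64, 8) for _ in range(3)]
targets = [x.sum(dim=1, keepdim=True) for x in tasks]

shared_model = torch.nn.Linear(8, 1)           # meta-parameters shared across all tasks
optimizer = torch.optim.Adam(shared_model.parameters(), lr=1e-3)

for step in range(100):
    train_loss = 0.0
    for x, y in zip(tasks, targets):
        # one batch per task, losses summed -- inner loop of length one,
        # so no task-specific parameter copies are ever created
        train_loss = train_loss + torch.nn.functional.mse_loss(shared_model(x), y)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()                           # single update of the shared parameters

Because every task contributes gradients to the same update, the shared parameters move toward a point that works reasonably for all languages at once, which is what makes the averaged per-language utterance embeddings saved above usable as inference defaults.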
TrainingInterfaces/Text_to_Spectrogram/__init__.py
ADDED
File without changes
TrainingInterfaces/__init__.py
ADDED
File without changes