Spaces:
Build error
Build error
TheComputerMan
commited on
Commit
•
9aefa26
1
Parent(s):
dcc9625
Upload InferenceFastSpeech2.py
Browse files- InferenceFastSpeech2.py +256 -0
InferenceFastSpeech2.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from Layers.Conformer import Conformer
|
6 |
+
from Layers.DurationPredictor import DurationPredictor
|
7 |
+
from Layers.LengthRegulator import LengthRegulator
|
8 |
+
from Layers.PostNet import PostNet
|
9 |
+
from Layers.VariancePredictor import VariancePredictor
|
10 |
+
from Utility.utils import make_non_pad_mask
|
11 |
+
from Utility.utils import make_pad_mask
|
12 |
+
|
13 |
+
|
14 |
+
class FastSpeech2(torch.nn.Module, ABC):
|
15 |
+
|
16 |
+
def __init__(self, # network structure related
|
17 |
+
weights,
|
18 |
+
idim=66,
|
19 |
+
odim=80,
|
20 |
+
adim=384,
|
21 |
+
aheads=4,
|
22 |
+
elayers=6,
|
23 |
+
eunits=1536,
|
24 |
+
dlayers=6,
|
25 |
+
dunits=1536,
|
26 |
+
postnet_layers=5,
|
27 |
+
postnet_chans=256,
|
28 |
+
postnet_filts=5,
|
29 |
+
positionwise_conv_kernel_size=1,
|
30 |
+
use_scaled_pos_enc=True,
|
31 |
+
use_batch_norm=True,
|
32 |
+
encoder_normalize_before=True,
|
33 |
+
decoder_normalize_before=True,
|
34 |
+
encoder_concat_after=False,
|
35 |
+
decoder_concat_after=False,
|
36 |
+
reduction_factor=1,
|
37 |
+
# encoder / decoder
|
38 |
+
use_macaron_style_in_conformer=True,
|
39 |
+
use_cnn_in_conformer=True,
|
40 |
+
conformer_enc_kernel_size=7,
|
41 |
+
conformer_dec_kernel_size=31,
|
42 |
+
# duration predictor
|
43 |
+
duration_predictor_layers=2,
|
44 |
+
duration_predictor_chans=256,
|
45 |
+
duration_predictor_kernel_size=3,
|
46 |
+
# energy predictor
|
47 |
+
energy_predictor_layers=2,
|
48 |
+
energy_predictor_chans=256,
|
49 |
+
energy_predictor_kernel_size=3,
|
50 |
+
energy_predictor_dropout=0.5,
|
51 |
+
energy_embed_kernel_size=1,
|
52 |
+
energy_embed_dropout=0.0,
|
53 |
+
stop_gradient_from_energy_predictor=True,
|
54 |
+
# pitch predictor
|
55 |
+
pitch_predictor_layers=5,
|
56 |
+
pitch_predictor_chans=256,
|
57 |
+
pitch_predictor_kernel_size=5,
|
58 |
+
pitch_predictor_dropout=0.5,
|
59 |
+
pitch_embed_kernel_size=1,
|
60 |
+
pitch_embed_dropout=0.0,
|
61 |
+
stop_gradient_from_pitch_predictor=True,
|
62 |
+
# training related
|
63 |
+
transformer_enc_dropout_rate=0.2,
|
64 |
+
transformer_enc_positional_dropout_rate=0.2,
|
65 |
+
transformer_enc_attn_dropout_rate=0.2,
|
66 |
+
transformer_dec_dropout_rate=0.2,
|
67 |
+
transformer_dec_positional_dropout_rate=0.2,
|
68 |
+
transformer_dec_attn_dropout_rate=0.2,
|
69 |
+
duration_predictor_dropout_rate=0.2,
|
70 |
+
postnet_dropout_rate=0.5,
|
71 |
+
# additional features
|
72 |
+
utt_embed_dim=704,
|
73 |
+
connect_utt_emb_at_encoder_out=True,
|
74 |
+
lang_embs=100):
|
75 |
+
super().__init__()
|
76 |
+
self.idim = idim
|
77 |
+
self.odim = odim
|
78 |
+
self.reduction_factor = reduction_factor
|
79 |
+
self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
|
80 |
+
self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
|
81 |
+
self.use_scaled_pos_enc = use_scaled_pos_enc
|
82 |
+
embed = torch.nn.Sequential(torch.nn.Linear(idim, 100),
|
83 |
+
torch.nn.Tanh(),
|
84 |
+
torch.nn.Linear(100, adim))
|
85 |
+
self.encoder = Conformer(idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers,
|
86 |
+
input_layer=embed, dropout_rate=transformer_enc_dropout_rate,
|
87 |
+
positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate,
|
88 |
+
normalize_before=encoder_normalize_before, concat_after=encoder_concat_after,
|
89 |
+
positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer,
|
90 |
+
use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_enc_kernel_size, zero_triu=False,
|
91 |
+
utt_embed=utt_embed_dim, connect_utt_emb_at_encoder_out=connect_utt_emb_at_encoder_out, lang_embs=lang_embs)
|
92 |
+
self.duration_predictor = DurationPredictor(idim=adim, n_layers=duration_predictor_layers,
|
93 |
+
n_chans=duration_predictor_chans,
|
94 |
+
kernel_size=duration_predictor_kernel_size,
|
95 |
+
dropout_rate=duration_predictor_dropout_rate, )
|
96 |
+
self.pitch_predictor = VariancePredictor(idim=adim, n_layers=pitch_predictor_layers,
|
97 |
+
n_chans=pitch_predictor_chans,
|
98 |
+
kernel_size=pitch_predictor_kernel_size,
|
99 |
+
dropout_rate=pitch_predictor_dropout)
|
100 |
+
self.pitch_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=adim,
|
101 |
+
kernel_size=pitch_embed_kernel_size,
|
102 |
+
padding=(pitch_embed_kernel_size - 1) // 2),
|
103 |
+
torch.nn.Dropout(pitch_embed_dropout))
|
104 |
+
self.energy_predictor = VariancePredictor(idim=adim, n_layers=energy_predictor_layers,
|
105 |
+
n_chans=energy_predictor_chans,
|
106 |
+
kernel_size=energy_predictor_kernel_size,
|
107 |
+
dropout_rate=energy_predictor_dropout)
|
108 |
+
self.energy_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=adim,
|
109 |
+
kernel_size=energy_embed_kernel_size,
|
110 |
+
padding=(energy_embed_kernel_size - 1) // 2),
|
111 |
+
torch.nn.Dropout(energy_embed_dropout))
|
112 |
+
self.length_regulator = LengthRegulator()
|
113 |
+
self.decoder = Conformer(idim=0,
|
114 |
+
attention_dim=adim,
|
115 |
+
attention_heads=aheads,
|
116 |
+
linear_units=dunits,
|
117 |
+
num_blocks=dlayers,
|
118 |
+
input_layer=None,
|
119 |
+
dropout_rate=transformer_dec_dropout_rate,
|
120 |
+
positional_dropout_rate=transformer_dec_positional_dropout_rate,
|
121 |
+
attention_dropout_rate=transformer_dec_attn_dropout_rate,
|
122 |
+
normalize_before=decoder_normalize_before,
|
123 |
+
concat_after=decoder_concat_after,
|
124 |
+
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
|
125 |
+
macaron_style=use_macaron_style_in_conformer,
|
126 |
+
use_cnn_module=use_cnn_in_conformer,
|
127 |
+
cnn_module_kernel=conformer_dec_kernel_size)
|
128 |
+
self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
|
129 |
+
self.postnet = PostNet(idim=idim,
|
130 |
+
odim=odim,
|
131 |
+
n_layers=postnet_layers,
|
132 |
+
n_chans=postnet_chans,
|
133 |
+
n_filts=postnet_filts,
|
134 |
+
use_batch_norm=use_batch_norm,
|
135 |
+
dropout_rate=postnet_dropout_rate)
|
136 |
+
self.load_state_dict(weights)
|
137 |
+
|
138 |
+
def _forward(self, text_tensors, text_lens, gold_speech=None, speech_lens=None,
|
139 |
+
gold_durations=None, gold_pitch=None, gold_energy=None,
|
140 |
+
is_inference=False, alpha=1.0, utterance_embedding=None, lang_ids=None):
|
141 |
+
# forward encoder
|
142 |
+
text_masks = self._source_mask(text_lens)
|
143 |
+
|
144 |
+
encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids) # (B, Tmax, adim)
|
145 |
+
|
146 |
+
# forward duration predictor and variance predictors
|
147 |
+
duration_masks = make_pad_mask(text_lens, device=text_lens.device)
|
148 |
+
|
149 |
+
if self.stop_gradient_from_pitch_predictor:
|
150 |
+
pitch_predictions = self.pitch_predictor(encoded_texts.detach(), duration_masks.unsqueeze(-1))
|
151 |
+
else:
|
152 |
+
pitch_predictions = self.pitch_predictor(encoded_texts, duration_masks.unsqueeze(-1))
|
153 |
+
|
154 |
+
if self.stop_gradient_from_energy_predictor:
|
155 |
+
energy_predictions = self.energy_predictor(encoded_texts.detach(), duration_masks.unsqueeze(-1))
|
156 |
+
else:
|
157 |
+
energy_predictions = self.energy_predictor(encoded_texts, duration_masks.unsqueeze(-1))
|
158 |
+
|
159 |
+
if is_inference:
|
160 |
+
if gold_durations is not None:
|
161 |
+
duration_predictions = gold_durations
|
162 |
+
else:
|
163 |
+
duration_predictions = self.duration_predictor.inference(encoded_texts, duration_masks)
|
164 |
+
if gold_pitch is not None:
|
165 |
+
pitch_predictions = gold_pitch
|
166 |
+
if gold_energy is not None:
|
167 |
+
energy_predictions = gold_energy
|
168 |
+
pitch_embeddings = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
|
169 |
+
energy_embeddings = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
|
170 |
+
encoded_texts = encoded_texts + energy_embeddings + pitch_embeddings
|
171 |
+
encoded_texts = self.length_regulator(encoded_texts, duration_predictions, alpha)
|
172 |
+
else:
|
173 |
+
duration_predictions = self.duration_predictor(encoded_texts, duration_masks)
|
174 |
+
|
175 |
+
# use groundtruth in training
|
176 |
+
pitch_embeddings = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2)
|
177 |
+
energy_embeddings = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2)
|
178 |
+
encoded_texts = encoded_texts + energy_embeddings + pitch_embeddings
|
179 |
+
encoded_texts = self.length_regulator(encoded_texts, gold_durations) # (B, Lmax, adim)
|
180 |
+
|
181 |
+
# forward decoder
|
182 |
+
if speech_lens is not None and not is_inference:
|
183 |
+
if self.reduction_factor > 1:
|
184 |
+
olens_in = speech_lens.new([olen // self.reduction_factor for olen in speech_lens])
|
185 |
+
else:
|
186 |
+
olens_in = speech_lens
|
187 |
+
h_masks = self._source_mask(olens_in)
|
188 |
+
else:
|
189 |
+
h_masks = None
|
190 |
+
zs, _ = self.decoder(encoded_texts, h_masks) # (B, Lmax, adim)
|
191 |
+
before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # (B, Lmax, odim)
|
192 |
+
|
193 |
+
# postnet -> (B, Lmax//r * r, odim)
|
194 |
+
after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
|
195 |
+
|
196 |
+
return before_outs, after_outs, duration_predictions, pitch_predictions, energy_predictions
|
197 |
+
|
198 |
+
@torch.no_grad()
|
199 |
+
def forward(self,
|
200 |
+
text,
|
201 |
+
speech=None,
|
202 |
+
durations=None,
|
203 |
+
pitch=None,
|
204 |
+
energy=None,
|
205 |
+
utterance_embedding=None,
|
206 |
+
return_duration_pitch_energy=False,
|
207 |
+
lang_id=None):
|
208 |
+
"""
|
209 |
+
Generate the sequence of features given the sequences of characters.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
text: Input sequence of characters
|
213 |
+
speech: Feature sequence to extract style
|
214 |
+
durations: Groundtruth of duration
|
215 |
+
pitch: Groundtruth of token-averaged pitch
|
216 |
+
energy: Groundtruth of token-averaged energy
|
217 |
+
return_duration_pitch_energy: whether to return the list of predicted durations for nicer plotting
|
218 |
+
utterance_embedding: embedding of utterance wide parameters
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
Mel Spectrogram
|
222 |
+
|
223 |
+
"""
|
224 |
+
self.eval()
|
225 |
+
# setup batch axis
|
226 |
+
ilens = torch.tensor([text.shape[0]], dtype=torch.long, device=text.device)
|
227 |
+
if speech is not None:
|
228 |
+
gold_speech = speech.unsqueeze(0)
|
229 |
+
else:
|
230 |
+
gold_speech = None
|
231 |
+
if durations is not None:
|
232 |
+
durations = durations.unsqueeze(0)
|
233 |
+
if pitch is not None:
|
234 |
+
pitch = pitch.unsqueeze(0)
|
235 |
+
if energy is not None:
|
236 |
+
energy = energy.unsqueeze(0)
|
237 |
+
if lang_id is not None:
|
238 |
+
lang_id = lang_id.unsqueeze(0)
|
239 |
+
|
240 |
+
before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(text.unsqueeze(0),
|
241 |
+
ilens,
|
242 |
+
gold_speech=gold_speech,
|
243 |
+
gold_durations=durations,
|
244 |
+
is_inference=True,
|
245 |
+
gold_pitch=pitch,
|
246 |
+
gold_energy=energy,
|
247 |
+
utterance_embedding=utterance_embedding.unsqueeze(0),
|
248 |
+
lang_ids=lang_id)
|
249 |
+
self.train()
|
250 |
+
if return_duration_pitch_energy:
|
251 |
+
return after_outs[0], d_outs[0], pitch_predictions[0], energy_predictions[0]
|
252 |
+
return after_outs[0]
|
253 |
+
|
254 |
+
def _source_mask(self, ilens):
|
255 |
+
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
|
256 |
+
return x_masks.unsqueeze(-2)
|