TheComputerMan commited on
Commit
9aefa26
1 Parent(s): dcc9625

Upload InferenceFastSpeech2.py

Browse files
Files changed (1) hide show
  1. InferenceFastSpeech2.py +256 -0
InferenceFastSpeech2.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+
3
+ import torch
4
+
5
+ from Layers.Conformer import Conformer
6
+ from Layers.DurationPredictor import DurationPredictor
7
+ from Layers.LengthRegulator import LengthRegulator
8
+ from Layers.PostNet import PostNet
9
+ from Layers.VariancePredictor import VariancePredictor
10
+ from Utility.utils import make_non_pad_mask
11
+ from Utility.utils import make_pad_mask
12
+
13
+
14
+ class FastSpeech2(torch.nn.Module, ABC):
15
+
16
+ def __init__(self, # network structure related
17
+ weights,
18
+ idim=66,
19
+ odim=80,
20
+ adim=384,
21
+ aheads=4,
22
+ elayers=6,
23
+ eunits=1536,
24
+ dlayers=6,
25
+ dunits=1536,
26
+ postnet_layers=5,
27
+ postnet_chans=256,
28
+ postnet_filts=5,
29
+ positionwise_conv_kernel_size=1,
30
+ use_scaled_pos_enc=True,
31
+ use_batch_norm=True,
32
+ encoder_normalize_before=True,
33
+ decoder_normalize_before=True,
34
+ encoder_concat_after=False,
35
+ decoder_concat_after=False,
36
+ reduction_factor=1,
37
+ # encoder / decoder
38
+ use_macaron_style_in_conformer=True,
39
+ use_cnn_in_conformer=True,
40
+ conformer_enc_kernel_size=7,
41
+ conformer_dec_kernel_size=31,
42
+ # duration predictor
43
+ duration_predictor_layers=2,
44
+ duration_predictor_chans=256,
45
+ duration_predictor_kernel_size=3,
46
+ # energy predictor
47
+ energy_predictor_layers=2,
48
+ energy_predictor_chans=256,
49
+ energy_predictor_kernel_size=3,
50
+ energy_predictor_dropout=0.5,
51
+ energy_embed_kernel_size=1,
52
+ energy_embed_dropout=0.0,
53
+ stop_gradient_from_energy_predictor=True,
54
+ # pitch predictor
55
+ pitch_predictor_layers=5,
56
+ pitch_predictor_chans=256,
57
+ pitch_predictor_kernel_size=5,
58
+ pitch_predictor_dropout=0.5,
59
+ pitch_embed_kernel_size=1,
60
+ pitch_embed_dropout=0.0,
61
+ stop_gradient_from_pitch_predictor=True,
62
+ # training related
63
+ transformer_enc_dropout_rate=0.2,
64
+ transformer_enc_positional_dropout_rate=0.2,
65
+ transformer_enc_attn_dropout_rate=0.2,
66
+ transformer_dec_dropout_rate=0.2,
67
+ transformer_dec_positional_dropout_rate=0.2,
68
+ transformer_dec_attn_dropout_rate=0.2,
69
+ duration_predictor_dropout_rate=0.2,
70
+ postnet_dropout_rate=0.5,
71
+ # additional features
72
+ utt_embed_dim=704,
73
+ connect_utt_emb_at_encoder_out=True,
74
+ lang_embs=100):
75
+ super().__init__()
76
+ self.idim = idim
77
+ self.odim = odim
78
+ self.reduction_factor = reduction_factor
79
+ self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
80
+ self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
81
+ self.use_scaled_pos_enc = use_scaled_pos_enc
82
+ embed = torch.nn.Sequential(torch.nn.Linear(idim, 100),
83
+ torch.nn.Tanh(),
84
+ torch.nn.Linear(100, adim))
85
+ self.encoder = Conformer(idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers,
86
+ input_layer=embed, dropout_rate=transformer_enc_dropout_rate,
87
+ positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate,
88
+ normalize_before=encoder_normalize_before, concat_after=encoder_concat_after,
89
+ positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer,
90
+ use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_enc_kernel_size, zero_triu=False,
91
+ utt_embed=utt_embed_dim, connect_utt_emb_at_encoder_out=connect_utt_emb_at_encoder_out, lang_embs=lang_embs)
92
+ self.duration_predictor = DurationPredictor(idim=adim, n_layers=duration_predictor_layers,
93
+ n_chans=duration_predictor_chans,
94
+ kernel_size=duration_predictor_kernel_size,
95
+ dropout_rate=duration_predictor_dropout_rate, )
96
+ self.pitch_predictor = VariancePredictor(idim=adim, n_layers=pitch_predictor_layers,
97
+ n_chans=pitch_predictor_chans,
98
+ kernel_size=pitch_predictor_kernel_size,
99
+ dropout_rate=pitch_predictor_dropout)
100
+ self.pitch_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=adim,
101
+ kernel_size=pitch_embed_kernel_size,
102
+ padding=(pitch_embed_kernel_size - 1) // 2),
103
+ torch.nn.Dropout(pitch_embed_dropout))
104
+ self.energy_predictor = VariancePredictor(idim=adim, n_layers=energy_predictor_layers,
105
+ n_chans=energy_predictor_chans,
106
+ kernel_size=energy_predictor_kernel_size,
107
+ dropout_rate=energy_predictor_dropout)
108
+ self.energy_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=adim,
109
+ kernel_size=energy_embed_kernel_size,
110
+ padding=(energy_embed_kernel_size - 1) // 2),
111
+ torch.nn.Dropout(energy_embed_dropout))
112
+ self.length_regulator = LengthRegulator()
113
+ self.decoder = Conformer(idim=0,
114
+ attention_dim=adim,
115
+ attention_heads=aheads,
116
+ linear_units=dunits,
117
+ num_blocks=dlayers,
118
+ input_layer=None,
119
+ dropout_rate=transformer_dec_dropout_rate,
120
+ positional_dropout_rate=transformer_dec_positional_dropout_rate,
121
+ attention_dropout_rate=transformer_dec_attn_dropout_rate,
122
+ normalize_before=decoder_normalize_before,
123
+ concat_after=decoder_concat_after,
124
+ positionwise_conv_kernel_size=positionwise_conv_kernel_size,
125
+ macaron_style=use_macaron_style_in_conformer,
126
+ use_cnn_module=use_cnn_in_conformer,
127
+ cnn_module_kernel=conformer_dec_kernel_size)
128
+ self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
129
+ self.postnet = PostNet(idim=idim,
130
+ odim=odim,
131
+ n_layers=postnet_layers,
132
+ n_chans=postnet_chans,
133
+ n_filts=postnet_filts,
134
+ use_batch_norm=use_batch_norm,
135
+ dropout_rate=postnet_dropout_rate)
136
+ self.load_state_dict(weights)
137
+
138
+ def _forward(self, text_tensors, text_lens, gold_speech=None, speech_lens=None,
139
+ gold_durations=None, gold_pitch=None, gold_energy=None,
140
+ is_inference=False, alpha=1.0, utterance_embedding=None, lang_ids=None):
141
+ # forward encoder
142
+ text_masks = self._source_mask(text_lens)
143
+
144
+ encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids) # (B, Tmax, adim)
145
+
146
+ # forward duration predictor and variance predictors
147
+ duration_masks = make_pad_mask(text_lens, device=text_lens.device)
148
+
149
+ if self.stop_gradient_from_pitch_predictor:
150
+ pitch_predictions = self.pitch_predictor(encoded_texts.detach(), duration_masks.unsqueeze(-1))
151
+ else:
152
+ pitch_predictions = self.pitch_predictor(encoded_texts, duration_masks.unsqueeze(-1))
153
+
154
+ if self.stop_gradient_from_energy_predictor:
155
+ energy_predictions = self.energy_predictor(encoded_texts.detach(), duration_masks.unsqueeze(-1))
156
+ else:
157
+ energy_predictions = self.energy_predictor(encoded_texts, duration_masks.unsqueeze(-1))
158
+
159
+ if is_inference:
160
+ if gold_durations is not None:
161
+ duration_predictions = gold_durations
162
+ else:
163
+ duration_predictions = self.duration_predictor.inference(encoded_texts, duration_masks)
164
+ if gold_pitch is not None:
165
+ pitch_predictions = gold_pitch
166
+ if gold_energy is not None:
167
+ energy_predictions = gold_energy
168
+ pitch_embeddings = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
169
+ energy_embeddings = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
170
+ encoded_texts = encoded_texts + energy_embeddings + pitch_embeddings
171
+ encoded_texts = self.length_regulator(encoded_texts, duration_predictions, alpha)
172
+ else:
173
+ duration_predictions = self.duration_predictor(encoded_texts, duration_masks)
174
+
175
+ # use groundtruth in training
176
+ pitch_embeddings = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2)
177
+ energy_embeddings = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2)
178
+ encoded_texts = encoded_texts + energy_embeddings + pitch_embeddings
179
+ encoded_texts = self.length_regulator(encoded_texts, gold_durations) # (B, Lmax, adim)
180
+
181
+ # forward decoder
182
+ if speech_lens is not None and not is_inference:
183
+ if self.reduction_factor > 1:
184
+ olens_in = speech_lens.new([olen // self.reduction_factor for olen in speech_lens])
185
+ else:
186
+ olens_in = speech_lens
187
+ h_masks = self._source_mask(olens_in)
188
+ else:
189
+ h_masks = None
190
+ zs, _ = self.decoder(encoded_texts, h_masks) # (B, Lmax, adim)
191
+ before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # (B, Lmax, odim)
192
+
193
+ # postnet -> (B, Lmax//r * r, odim)
194
+ after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
195
+
196
+ return before_outs, after_outs, duration_predictions, pitch_predictions, energy_predictions
197
+
198
+ @torch.no_grad()
199
+ def forward(self,
200
+ text,
201
+ speech=None,
202
+ durations=None,
203
+ pitch=None,
204
+ energy=None,
205
+ utterance_embedding=None,
206
+ return_duration_pitch_energy=False,
207
+ lang_id=None):
208
+ """
209
+ Generate the sequence of features given the sequences of characters.
210
+
211
+ Args:
212
+ text: Input sequence of characters
213
+ speech: Feature sequence to extract style
214
+ durations: Groundtruth of duration
215
+ pitch: Groundtruth of token-averaged pitch
216
+ energy: Groundtruth of token-averaged energy
217
+ return_duration_pitch_energy: whether to return the list of predicted durations for nicer plotting
218
+ utterance_embedding: embedding of utterance wide parameters
219
+
220
+ Returns:
221
+ Mel Spectrogram
222
+
223
+ """
224
+ self.eval()
225
+ # setup batch axis
226
+ ilens = torch.tensor([text.shape[0]], dtype=torch.long, device=text.device)
227
+ if speech is not None:
228
+ gold_speech = speech.unsqueeze(0)
229
+ else:
230
+ gold_speech = None
231
+ if durations is not None:
232
+ durations = durations.unsqueeze(0)
233
+ if pitch is not None:
234
+ pitch = pitch.unsqueeze(0)
235
+ if energy is not None:
236
+ energy = energy.unsqueeze(0)
237
+ if lang_id is not None:
238
+ lang_id = lang_id.unsqueeze(0)
239
+
240
+ before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(text.unsqueeze(0),
241
+ ilens,
242
+ gold_speech=gold_speech,
243
+ gold_durations=durations,
244
+ is_inference=True,
245
+ gold_pitch=pitch,
246
+ gold_energy=energy,
247
+ utterance_embedding=utterance_embedding.unsqueeze(0),
248
+ lang_ids=lang_id)
249
+ self.train()
250
+ if return_duration_pitch_energy:
251
+ return after_outs[0], d_outs[0], pitch_predictions[0], energy_predictions[0]
252
+ return after_outs[0]
253
+
254
+ def _source_mask(self, ilens):
255
+ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
256
+ return x_masks.unsqueeze(-2)