Spaces:
Runtime error
Runtime error
Merge branch 'main' into ps
Browse files- checkpoints/diffspeech/config.yaml +393 -0
- checkpoints/diffspeech/model_ckpt_steps_160000.ckpt +3 -0
- docs/diffspeech.md +62 -0
- docs/prepare_vocoder.md +1 -1
- egs/datasets/audio/lj/ds.yaml +29 -0
- egs/egs_bases/tts/ds.yaml +32 -0
- inference/tts/ds.py +30 -0
- modules/tts/commons/align_ops.py +2 -3
- modules/tts/diffspeech/net.py +110 -0
- modules/tts/diffspeech/shallow_diffusion_tts.py +281 -0
- tasks/tts/diffspeech.py +111 -0
checkpoints/diffspeech/config.yaml
ADDED
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
K_step: 71
|
2 |
+
accumulate_grad_batches: 1
|
3 |
+
amp: false
|
4 |
+
audio_num_mel_bins: 80
|
5 |
+
audio_sample_rate: 22050
|
6 |
+
base_config:
|
7 |
+
- egs/egs_bases/tts/ds.yaml
|
8 |
+
- ./fs2_orig.yaml
|
9 |
+
binarization_args:
|
10 |
+
min_sil_duration: 0.1
|
11 |
+
shuffle: false
|
12 |
+
test_range:
|
13 |
+
- 0
|
14 |
+
- 523
|
15 |
+
train_range:
|
16 |
+
- 871
|
17 |
+
- -1
|
18 |
+
trim_eos_bos: false
|
19 |
+
valid_range:
|
20 |
+
- 523
|
21 |
+
- 871
|
22 |
+
with_align: true
|
23 |
+
with_f0: true
|
24 |
+
with_f0cwt: true
|
25 |
+
with_linear: false
|
26 |
+
with_spk_embed: false
|
27 |
+
with_wav: false
|
28 |
+
binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
|
29 |
+
binary_data_dir: data/binary/ljspeech_cwt
|
30 |
+
check_val_every_n_epoch: 10
|
31 |
+
clip_grad_norm: 1
|
32 |
+
clip_grad_value: 0
|
33 |
+
conv_use_pos: false
|
34 |
+
cwt_std_scale: 1.0
|
35 |
+
debug: false
|
36 |
+
dec_dilations:
|
37 |
+
- 1
|
38 |
+
- 1
|
39 |
+
- 1
|
40 |
+
- 1
|
41 |
+
dec_ffn_kernel_size: 9
|
42 |
+
dec_inp_add_noise: false
|
43 |
+
dec_kernel_size: 5
|
44 |
+
dec_layers: 4
|
45 |
+
dec_post_net_kernel: 3
|
46 |
+
decay_steps: 50000
|
47 |
+
decoder_rnn_dim: 0
|
48 |
+
decoder_type: fft
|
49 |
+
diff_decoder_type: wavenet
|
50 |
+
diff_loss_type: l1
|
51 |
+
dilation_cycle_length: 1
|
52 |
+
dropout: 0.0
|
53 |
+
ds_workers: 2
|
54 |
+
dur_predictor_kernel: 3
|
55 |
+
dur_predictor_layers: 2
|
56 |
+
enc_dec_norm: ln
|
57 |
+
enc_dilations:
|
58 |
+
- 1
|
59 |
+
- 1
|
60 |
+
- 1
|
61 |
+
- 1
|
62 |
+
enc_ffn_kernel_size: 9
|
63 |
+
enc_kernel_size: 5
|
64 |
+
enc_layers: 4
|
65 |
+
enc_post_net_kernel: 3
|
66 |
+
enc_pre_ln: true
|
67 |
+
enc_prenet: true
|
68 |
+
encoder_K: 8
|
69 |
+
encoder_type: fft
|
70 |
+
endless_ds: true
|
71 |
+
eval_max_batches: -1
|
72 |
+
f0_max: 600
|
73 |
+
f0_min: 80
|
74 |
+
ffn_act: gelu
|
75 |
+
ffn_hidden_size: 1024
|
76 |
+
fft_size: 1024
|
77 |
+
fmax: 7600
|
78 |
+
fmin: 80
|
79 |
+
frames_multiple: 1
|
80 |
+
fs2_ckpt: checkpoints/fs2_exp/model_ckpt_steps_160000.ckpt
|
81 |
+
gen_dir_name: ''
|
82 |
+
griffin_lim_iters: 30
|
83 |
+
hidden_size: 256
|
84 |
+
hop_size: 256
|
85 |
+
infer: false
|
86 |
+
keep_bins: 80
|
87 |
+
lambda_commit: 0.25
|
88 |
+
lambda_energy: 0.1
|
89 |
+
lambda_f0: 1.0
|
90 |
+
lambda_ph_dur: 0.1
|
91 |
+
lambda_sent_dur: 1.0
|
92 |
+
lambda_uv: 1.0
|
93 |
+
lambda_word_dur: 1.0
|
94 |
+
layers_in_block: 2
|
95 |
+
load_ckpt: ''
|
96 |
+
loud_norm: false
|
97 |
+
lr: 0.001
|
98 |
+
max_beta: 0.06
|
99 |
+
max_epochs: 1000
|
100 |
+
max_frames: 1548
|
101 |
+
max_input_tokens: 1550
|
102 |
+
max_sentences: 128
|
103 |
+
max_tokens: 30000
|
104 |
+
max_updates: 160000
|
105 |
+
max_valid_sentences: 1
|
106 |
+
max_valid_tokens: 60000
|
107 |
+
mel_losses: l1:0.5|ssim:0.5
|
108 |
+
mel_vmax: 1.5
|
109 |
+
mel_vmin: -6
|
110 |
+
min_frames: 0
|
111 |
+
num_ckpt_keep: 3
|
112 |
+
num_heads: 2
|
113 |
+
num_sanity_val_steps: 5
|
114 |
+
num_spk: 1
|
115 |
+
num_valid_plots: 10
|
116 |
+
optimizer_adam_beta1: 0.9
|
117 |
+
optimizer_adam_beta2: 0.98
|
118 |
+
out_wav_norm: false
|
119 |
+
pitch_extractor: parselmouth
|
120 |
+
pitch_key: pitch
|
121 |
+
pitch_type: cwt
|
122 |
+
predictor_dropout: 0.5
|
123 |
+
predictor_grad: 0.1
|
124 |
+
predictor_hidden: -1
|
125 |
+
predictor_kernel: 5
|
126 |
+
predictor_layers: 2
|
127 |
+
preprocess_args:
|
128 |
+
add_eos_bos: true
|
129 |
+
mfa_group_shuffle: false
|
130 |
+
mfa_offset: 0.02
|
131 |
+
nsample_per_mfa_group: 1000
|
132 |
+
reset_phone_dict: true
|
133 |
+
reset_word_dict: true
|
134 |
+
save_sil_mask: true
|
135 |
+
txt_processor: en
|
136 |
+
use_mfa: true
|
137 |
+
vad_max_silence_length: 12
|
138 |
+
wav_processors: []
|
139 |
+
with_phsep: true
|
140 |
+
preprocess_cls: egs.datasets.audio.lj.preprocess.LJPreprocess
|
141 |
+
print_nan_grads: false
|
142 |
+
processed_data_dir: data/processed/ljspeech
|
143 |
+
profile_infer: false
|
144 |
+
raw_data_dir: data/raw/LJSpeech-1.1
|
145 |
+
ref_norm_layer: bn
|
146 |
+
rename_tmux: true
|
147 |
+
residual_channels: 256
|
148 |
+
residual_layers: 20
|
149 |
+
resume_from_checkpoint: 0
|
150 |
+
save_best: false
|
151 |
+
save_codes:
|
152 |
+
- tasks
|
153 |
+
- modules
|
154 |
+
- egs
|
155 |
+
save_f0: false
|
156 |
+
save_gt: true
|
157 |
+
schedule_type: linear
|
158 |
+
scheduler: warmup
|
159 |
+
seed: 1234
|
160 |
+
sort_by_len: true
|
161 |
+
spec_max:
|
162 |
+
- -0.5982
|
163 |
+
- -0.0778
|
164 |
+
- 0.1205
|
165 |
+
- 0.2747
|
166 |
+
- 0.4657
|
167 |
+
- 0.5123
|
168 |
+
- 0.583
|
169 |
+
- 0.7093
|
170 |
+
- 0.6461
|
171 |
+
- 0.6101
|
172 |
+
- 0.7316
|
173 |
+
- 0.7715
|
174 |
+
- 0.7681
|
175 |
+
- 0.8349
|
176 |
+
- 0.7815
|
177 |
+
- 0.7591
|
178 |
+
- 0.791
|
179 |
+
- 0.7433
|
180 |
+
- 0.7352
|
181 |
+
- 0.6869
|
182 |
+
- 0.6854
|
183 |
+
- 0.6623
|
184 |
+
- 0.5353
|
185 |
+
- 0.6492
|
186 |
+
- 0.6909
|
187 |
+
- 0.6106
|
188 |
+
- 0.5761
|
189 |
+
- 0.5236
|
190 |
+
- 0.5638
|
191 |
+
- 0.4054
|
192 |
+
- 0.4545
|
193 |
+
- 0.3407
|
194 |
+
- 0.3037
|
195 |
+
- 0.338
|
196 |
+
- 0.1599
|
197 |
+
- 0.1603
|
198 |
+
- 0.2741
|
199 |
+
- 0.213
|
200 |
+
- 0.1569
|
201 |
+
- 0.1911
|
202 |
+
- 0.2324
|
203 |
+
- 0.1586
|
204 |
+
- 0.1221
|
205 |
+
- 0.0341
|
206 |
+
- -0.0558
|
207 |
+
- 0.0553
|
208 |
+
- -0.1153
|
209 |
+
- -0.0933
|
210 |
+
- -0.1171
|
211 |
+
- -0.005
|
212 |
+
- -0.1519
|
213 |
+
- -0.1629
|
214 |
+
- -0.0522
|
215 |
+
- -0.0739
|
216 |
+
- -0.2069
|
217 |
+
- -0.2405
|
218 |
+
- -0.1244
|
219 |
+
- -0.2582
|
220 |
+
- -0.1361
|
221 |
+
- -0.1575
|
222 |
+
- -0.1442
|
223 |
+
- 0.0513
|
224 |
+
- -0.1567
|
225 |
+
- -0.2
|
226 |
+
- 0.0086
|
227 |
+
- -0.0698
|
228 |
+
- 0.1385
|
229 |
+
- 0.0941
|
230 |
+
- 0.1864
|
231 |
+
- 0.1225
|
232 |
+
- 0.1389
|
233 |
+
- 0.1382
|
234 |
+
- 0.167
|
235 |
+
- 0.1007
|
236 |
+
- 0.1444
|
237 |
+
- 0.0888
|
238 |
+
- 0.1998
|
239 |
+
- 0.228
|
240 |
+
- 0.2932
|
241 |
+
- 0.3047
|
242 |
+
spec_min:
|
243 |
+
- -4.7574
|
244 |
+
- -4.6783
|
245 |
+
- -4.6431
|
246 |
+
- -4.5832
|
247 |
+
- -4.539
|
248 |
+
- -4.6771
|
249 |
+
- -4.8089
|
250 |
+
- -4.7672
|
251 |
+
- -4.5784
|
252 |
+
- -4.7755
|
253 |
+
- -4.715
|
254 |
+
- -4.8919
|
255 |
+
- -4.8271
|
256 |
+
- -4.7389
|
257 |
+
- -4.6047
|
258 |
+
- -4.7759
|
259 |
+
- -4.6799
|
260 |
+
- -4.8201
|
261 |
+
- -4.7823
|
262 |
+
- -4.8262
|
263 |
+
- -4.7857
|
264 |
+
- -4.7545
|
265 |
+
- -4.9358
|
266 |
+
- -4.9733
|
267 |
+
- -5.1134
|
268 |
+
- -5.1395
|
269 |
+
- -4.9016
|
270 |
+
- -4.8434
|
271 |
+
- -5.0189
|
272 |
+
- -4.846
|
273 |
+
- -5.0529
|
274 |
+
- -4.951
|
275 |
+
- -5.0217
|
276 |
+
- -5.0049
|
277 |
+
- -5.1831
|
278 |
+
- -5.1445
|
279 |
+
- -5.1015
|
280 |
+
- -5.0281
|
281 |
+
- -4.9887
|
282 |
+
- -4.9916
|
283 |
+
- -4.9785
|
284 |
+
- -4.9071
|
285 |
+
- -4.9488
|
286 |
+
- -5.0342
|
287 |
+
- -4.9332
|
288 |
+
- -5.065
|
289 |
+
- -4.8924
|
290 |
+
- -5.0875
|
291 |
+
- -5.0483
|
292 |
+
- -5.0848
|
293 |
+
- -5.0655
|
294 |
+
- -5.0279
|
295 |
+
- -5.0015
|
296 |
+
- -5.0792
|
297 |
+
- -5.0636
|
298 |
+
- -5.2413
|
299 |
+
- -5.1421
|
300 |
+
- -5.171
|
301 |
+
- -5.3256
|
302 |
+
- -5.0511
|
303 |
+
- -5.1186
|
304 |
+
- -5.0057
|
305 |
+
- -5.0446
|
306 |
+
- -5.1173
|
307 |
+
- -5.0325
|
308 |
+
- -5.1085
|
309 |
+
- -5.0053
|
310 |
+
- -5.0755
|
311 |
+
- -5.1176
|
312 |
+
- -5.1004
|
313 |
+
- -5.2153
|
314 |
+
- -5.2757
|
315 |
+
- -5.3025
|
316 |
+
- -5.2867
|
317 |
+
- -5.2918
|
318 |
+
- -5.3328
|
319 |
+
- -5.2731
|
320 |
+
- -5.2985
|
321 |
+
- -5.24
|
322 |
+
- -5.2211
|
323 |
+
task_cls: tasks.tts.diffspeech.DiffSpeechTask
|
324 |
+
tb_log_interval: 100
|
325 |
+
test_ids:
|
326 |
+
- 0
|
327 |
+
- 1
|
328 |
+
- 2
|
329 |
+
- 3
|
330 |
+
- 4
|
331 |
+
- 5
|
332 |
+
- 6
|
333 |
+
- 7
|
334 |
+
- 8
|
335 |
+
- 9
|
336 |
+
- 10
|
337 |
+
- 11
|
338 |
+
- 12
|
339 |
+
- 13
|
340 |
+
- 14
|
341 |
+
- 15
|
342 |
+
- 16
|
343 |
+
- 17
|
344 |
+
- 18
|
345 |
+
- 19
|
346 |
+
- 68
|
347 |
+
- 70
|
348 |
+
- 74
|
349 |
+
- 87
|
350 |
+
- 110
|
351 |
+
- 172
|
352 |
+
- 190
|
353 |
+
- 215
|
354 |
+
- 231
|
355 |
+
- 294
|
356 |
+
- 316
|
357 |
+
- 324
|
358 |
+
- 402
|
359 |
+
- 422
|
360 |
+
- 485
|
361 |
+
- 500
|
362 |
+
- 505
|
363 |
+
- 508
|
364 |
+
- 509
|
365 |
+
- 519
|
366 |
+
test_input_yaml: ''
|
367 |
+
test_num: 100
|
368 |
+
test_set_name: test
|
369 |
+
timesteps: 100
|
370 |
+
train_set_name: train
|
371 |
+
train_sets: ''
|
372 |
+
use_energy_embed: true
|
373 |
+
use_gt_dur: false
|
374 |
+
use_gt_energy: false
|
375 |
+
use_gt_f0: false
|
376 |
+
use_pitch_embed: true
|
377 |
+
use_pos_embed: true
|
378 |
+
use_spk_embed: false
|
379 |
+
use_spk_id: false
|
380 |
+
use_uv: true
|
381 |
+
use_word_input: false
|
382 |
+
val_check_interval: 2000
|
383 |
+
valid_infer_interval: 10000
|
384 |
+
valid_monitor_key: val_loss
|
385 |
+
valid_monitor_mode: min
|
386 |
+
valid_set_name: valid
|
387 |
+
vocoder: HifiGAN
|
388 |
+
vocoder_ckpt: checkpoints/hifi_lj
|
389 |
+
warmup_updates: 4000
|
390 |
+
weight_decay: 0
|
391 |
+
win_size: 1024
|
392 |
+
word_dict_size: 10000
|
393 |
+
work_dir: checkpoints/0209_ds_1
|
checkpoints/diffspeech/model_ckpt_steps_160000.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:503f81009a75c02d868253b6fb4f1411aeaa32308b101d7804447bc583636b83
|
3 |
+
size 168816223
|
docs/diffspeech.md
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Run DiffSpeech
|
2 |
+
|
3 |
+
## Quick Start
|
4 |
+
|
5 |
+
### Install Dependencies
|
6 |
+
|
7 |
+
Install dependencies following [readme.md](../readme.md)
|
8 |
+
|
9 |
+
### Set Config Path and Experiment Name
|
10 |
+
|
11 |
+
```bash
|
12 |
+
export CONFIG_NAME=egs/datasets/audio/lj/ds.yaml
|
13 |
+
export MY_EXP_NAME=ds_exp
|
14 |
+
```
|
15 |
+
|
16 |
+
### Preprocess and binary dataset
|
17 |
+
|
18 |
+
Prepare dataset following [prepare_data.md](./prepare_data.md)
|
19 |
+
|
20 |
+
### Prepare Vocoder
|
21 |
+
|
22 |
+
Prepare vocoder following [prepare_vocoder.md](./prepare_vocoder.md)
|
23 |
+
|
24 |
+
## Training
|
25 |
+
|
26 |
+
First, you need a pre-trained FastSpeech2 checkpoint `chckpoints/fs2_exp/model_ckpt_steps_160000.ckpt`. To train a FastSpeech 2 model, run:
|
27 |
+
|
28 |
+
```bash
|
29 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config egs/datasets/audio/lj/fs2_orig.yaml --exp_name fs2_exp --reset
|
30 |
+
```
|
31 |
+
|
32 |
+
Then, run:
|
33 |
+
|
34 |
+
```bash
|
35 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config $CONFIG_NAME --exp_name $MY_EXP_NAME --reset
|
36 |
+
```
|
37 |
+
|
38 |
+
You can check the training and validation curves open Tensorboard via:
|
39 |
+
|
40 |
+
```bash
|
41 |
+
tensorboard --logdir checkpoints/$MY_EXP_NAME
|
42 |
+
```
|
43 |
+
|
44 |
+
## Inference (Testing)
|
45 |
+
|
46 |
+
```bash
|
47 |
+
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config $CONFIG_NAME --exp_name $MY_EXP_NAME --infer
|
48 |
+
```
|
49 |
+
|
50 |
+
## Citation
|
51 |
+
|
52 |
+
If you find this useful for your research, please use the following.
|
53 |
+
|
54 |
+
```bib
|
55 |
+
@article{liu2021diffsinger,
|
56 |
+
title={Diffsinger: Singing voice synthesis via shallow diffusion mechanism},
|
57 |
+
author={Liu, Jinglin and Li, Chengxi and Ren, Yi and Chen, Feiyang and Liu, Peng and Zhao, Zhou},
|
58 |
+
journal={arXiv preprint arXiv:2105.02446},
|
59 |
+
volume={2},
|
60 |
+
year={2021}
|
61 |
+
}
|
62 |
+
```
|
docs/prepare_vocoder.md
CHANGED
@@ -26,7 +26,7 @@ export MY_EXP_NAME=my_hifigan_exp
|
|
26 |
Prepare dataset following [prepare_data.md](./prepare_data.md).
|
27 |
|
28 |
If you have run the `prepare_data` step of the acoustic
|
29 |
-
model (e.g.,
|
30 |
|
31 |
```bash
|
32 |
python data_gen/tts/runs/binarize.py --config $CONFIG_NAME
|
|
|
26 |
Prepare dataset following [prepare_data.md](./prepare_data.md).
|
27 |
|
28 |
If you have run the `prepare_data` step of the acoustic
|
29 |
+
model (e.g., PortaSpeech and DiffSpeech), you only need to binarize the dataset for the vocoder training:
|
30 |
|
31 |
```bash
|
32 |
python data_gen/tts/runs/binarize.py --config $CONFIG_NAME
|
egs/datasets/audio/lj/ds.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- egs/egs_bases/tts/ds.yaml
|
3 |
+
- ./fs2_orig.yaml
|
4 |
+
|
5 |
+
fs2_ckpt: checkpoints/fs2_exp/model_ckpt_steps_160000.ckpt
|
6 |
+
|
7 |
+
# spec_min and spec_max are calculated on the training set.
|
8 |
+
spec_min: [ -4.7574, -4.6783, -4.6431, -4.5832, -4.5390, -4.6771, -4.8089, -4.7672,
|
9 |
+
-4.5784, -4.7755, -4.7150, -4.8919, -4.8271, -4.7389, -4.6047, -4.7759,
|
10 |
+
-4.6799, -4.8201, -4.7823, -4.8262, -4.7857, -4.7545, -4.9358, -4.9733,
|
11 |
+
-5.1134, -5.1395, -4.9016, -4.8434, -5.0189, -4.8460, -5.0529, -4.9510,
|
12 |
+
-5.0217, -5.0049, -5.1831, -5.1445, -5.1015, -5.0281, -4.9887, -4.9916,
|
13 |
+
-4.9785, -4.9071, -4.9488, -5.0342, -4.9332, -5.0650, -4.8924, -5.0875,
|
14 |
+
-5.0483, -5.0848, -5.0655, -5.0279, -5.0015, -5.0792, -5.0636, -5.2413,
|
15 |
+
-5.1421, -5.1710, -5.3256, -5.0511, -5.1186, -5.0057, -5.0446, -5.1173,
|
16 |
+
-5.0325, -5.1085, -5.0053, -5.0755, -5.1176, -5.1004, -5.2153, -5.2757,
|
17 |
+
-5.3025, -5.2867, -5.2918, -5.3328, -5.2731, -5.2985, -5.2400, -5.2211 ]
|
18 |
+
spec_max: [ -0.5982, -0.0778, 0.1205, 0.2747, 0.4657, 0.5123, 0.5830, 0.7093,
|
19 |
+
0.6461, 0.6101, 0.7316, 0.7715, 0.7681, 0.8349, 0.7815, 0.7591,
|
20 |
+
0.7910, 0.7433, 0.7352, 0.6869, 0.6854, 0.6623, 0.5353, 0.6492,
|
21 |
+
0.6909, 0.6106, 0.5761, 0.5236, 0.5638, 0.4054, 0.4545, 0.3407,
|
22 |
+
0.3037, 0.3380, 0.1599, 0.1603, 0.2741, 0.2130, 0.1569, 0.1911,
|
23 |
+
0.2324, 0.1586, 0.1221, 0.0341, -0.0558, 0.0553, -0.1153, -0.0933,
|
24 |
+
-0.1171, -0.0050, -0.1519, -0.1629, -0.0522, -0.0739, -0.2069, -0.2405,
|
25 |
+
-0.1244, -0.2582, -0.1361, -0.1575, -0.1442, 0.0513, -0.1567, -0.2000,
|
26 |
+
0.0086, -0.0698, 0.1385, 0.0941, 0.1864, 0.1225, 0.1389, 0.1382,
|
27 |
+
0.1670, 0.1007, 0.1444, 0.0888, 0.1998, 0.2280, 0.2932, 0.3047 ]
|
28 |
+
|
29 |
+
max_tokens: 30000
|
egs/egs_bases/tts/ds.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config: ./fs2_orig.yaml
|
2 |
+
|
3 |
+
# special configs for diffspeech
|
4 |
+
task_cls: tasks.tts.diffspeech.DiffSpeechTask
|
5 |
+
lr: 0.001
|
6 |
+
timesteps: 100
|
7 |
+
K_step: 71
|
8 |
+
diff_loss_type: l1
|
9 |
+
diff_decoder_type: 'wavenet'
|
10 |
+
schedule_type: 'linear'
|
11 |
+
max_beta: 0.06
|
12 |
+
|
13 |
+
## model configs for diffspeech
|
14 |
+
dilation_cycle_length: 1
|
15 |
+
residual_layers: 20
|
16 |
+
residual_channels: 256
|
17 |
+
decay_steps: 50000
|
18 |
+
keep_bins: 80
|
19 |
+
#content_cond_steps: [ ] # [ 0, 10000 ]
|
20 |
+
#spk_cond_steps: [ ] # [ 0, 10000 ]
|
21 |
+
#gen_tgt_spk_id: -1
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
# training configs for diffspeech
|
26 |
+
#max_sentences: 48
|
27 |
+
#num_sanity_val_steps: 1
|
28 |
+
num_valid_plots: 10
|
29 |
+
use_gt_dur: false
|
30 |
+
use_gt_f0: false
|
31 |
+
#pitch_type: cwt
|
32 |
+
max_updates: 160000
|
inference/tts/ds.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
# from inference.tts.fs import FastSpeechInfer
|
3 |
+
# from modules.tts.fs2_orig import FastSpeech2Orig
|
4 |
+
from inference.tts.base_tts_infer import BaseTTSInfer
|
5 |
+
from modules.tts.diffspeech.shallow_diffusion_tts import GaussianDiffusion
|
6 |
+
from utils.commons.ckpt_utils import load_ckpt
|
7 |
+
from utils.commons.hparams import hparams
|
8 |
+
|
9 |
+
|
10 |
+
class DiffSpeechInfer(BaseTTSInfer):
|
11 |
+
def build_model(self):
|
12 |
+
dict_size = len(self.ph_encoder)
|
13 |
+
model = GaussianDiffusion(dict_size, self.hparams)
|
14 |
+
model.eval()
|
15 |
+
load_ckpt(model, hparams['work_dir'], 'model')
|
16 |
+
return model
|
17 |
+
|
18 |
+
def forward_model(self, inp):
|
19 |
+
sample = self.input_to_batch(inp)
|
20 |
+
txt_tokens = sample['txt_tokens'] # [B, T_t]
|
21 |
+
spk_id = sample.get('spk_ids')
|
22 |
+
with torch.no_grad():
|
23 |
+
output = self.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True)
|
24 |
+
mel_out = output['mel_out']
|
25 |
+
wav_out = self.run_vocoder(mel_out)
|
26 |
+
wav_out = wav_out.cpu().numpy()
|
27 |
+
return wav_out[0]
|
28 |
+
|
29 |
+
if __name__ == '__main__':
|
30 |
+
DiffSpeechInfer.example_run()
|
modules/tts/commons/align_ops.py
CHANGED
@@ -13,9 +13,8 @@ def mel2ph_to_mel2word(mel2ph, ph2word):
|
|
13 |
|
14 |
|
15 |
def clip_mel2token_to_multiple(mel2token, frames_multiple):
|
16 |
-
|
17 |
-
|
18 |
-
mel2token = mel2token[:, :max_frames]
|
19 |
return mel2token
|
20 |
|
21 |
|
|
|
13 |
|
14 |
|
15 |
def clip_mel2token_to_multiple(mel2token, frames_multiple):
|
16 |
+
max_frames = mel2token.shape[1] // frames_multiple * frames_multiple
|
17 |
+
mel2token = mel2token[:, :max_frames]
|
|
|
18 |
return mel2token
|
19 |
|
20 |
|
modules/tts/diffspeech/net.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from math import sqrt
|
8 |
+
|
9 |
+
Linear = nn.Linear
|
10 |
+
ConvTranspose2d = nn.ConvTranspose2d
|
11 |
+
|
12 |
+
|
13 |
+
class Mish(nn.Module):
|
14 |
+
def forward(self, x):
|
15 |
+
return x * torch.tanh(F.softplus(x))
|
16 |
+
|
17 |
+
|
18 |
+
class SinusoidalPosEmb(nn.Module):
|
19 |
+
def __init__(self, dim):
|
20 |
+
super().__init__()
|
21 |
+
self.dim = dim
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
device = x.device
|
25 |
+
half_dim = self.dim // 2
|
26 |
+
emb = math.log(10000) / (half_dim - 1)
|
27 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
28 |
+
emb = x[:, None] * emb[None, :]
|
29 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
30 |
+
return emb
|
31 |
+
|
32 |
+
|
33 |
+
def Conv1d(*args, **kwargs):
|
34 |
+
layer = nn.Conv1d(*args, **kwargs)
|
35 |
+
nn.init.kaiming_normal_(layer.weight)
|
36 |
+
return layer
|
37 |
+
|
38 |
+
|
39 |
+
class ResidualBlock(nn.Module):
|
40 |
+
def __init__(self, encoder_hidden, residual_channels, dilation):
|
41 |
+
super().__init__()
|
42 |
+
self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation)
|
43 |
+
self.diffusion_projection = Linear(residual_channels, residual_channels)
|
44 |
+
self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1)
|
45 |
+
self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)
|
46 |
+
|
47 |
+
def forward(self, x, conditioner, diffusion_step):
|
48 |
+
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
|
49 |
+
conditioner = self.conditioner_projection(conditioner)
|
50 |
+
y = x + diffusion_step
|
51 |
+
|
52 |
+
y = self.dilated_conv(y) + conditioner
|
53 |
+
|
54 |
+
gate, filter = torch.chunk(y, 2, dim=1)
|
55 |
+
y = torch.sigmoid(gate) * torch.tanh(filter)
|
56 |
+
|
57 |
+
y = self.output_projection(y)
|
58 |
+
residual, skip = torch.chunk(y, 2, dim=1)
|
59 |
+
return (x + residual) / sqrt(2.0), skip
|
60 |
+
|
61 |
+
|
62 |
+
class DiffNet(nn.Module):
|
63 |
+
def __init__(self, hparams):
|
64 |
+
super().__init__()
|
65 |
+
in_dims = hparams['audio_num_mel_bins']
|
66 |
+
self.encoder_hidden = hparams['hidden_size']
|
67 |
+
self.residual_layers = hparams['residual_layers']
|
68 |
+
self.residual_channels = hparams['residual_channels']
|
69 |
+
self.dilation_cycle_length = hparams['dilation_cycle_length']
|
70 |
+
|
71 |
+
self.input_projection = Conv1d(in_dims, self.residual_channels, 1)
|
72 |
+
self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels)
|
73 |
+
dim = self.residual_channels
|
74 |
+
self.mlp = nn.Sequential(
|
75 |
+
nn.Linear(dim, dim * 4),
|
76 |
+
Mish(),
|
77 |
+
nn.Linear(dim * 4, dim)
|
78 |
+
)
|
79 |
+
self.residual_layers = nn.ModuleList([
|
80 |
+
ResidualBlock(self.encoder_hidden, self.residual_channels, 2 ** (i % self.dilation_cycle_length))
|
81 |
+
for i in range(self.residual_layers)
|
82 |
+
])
|
83 |
+
self.skip_projection = Conv1d(self.residual_channels, self.residual_channels, 1)
|
84 |
+
self.output_projection = Conv1d(self.residual_channels, in_dims, 1)
|
85 |
+
nn.init.zeros_(self.output_projection.weight)
|
86 |
+
|
87 |
+
def forward(self, spec, diffusion_step, cond):
|
88 |
+
"""
|
89 |
+
|
90 |
+
:param spec: [B, 1, M, T]
|
91 |
+
:param diffusion_step: [B, 1]
|
92 |
+
:param cond: [B, M, T]
|
93 |
+
:return:
|
94 |
+
"""
|
95 |
+
x = spec[:, 0]
|
96 |
+
x = self.input_projection(x) # x [B, residual_channel, T]
|
97 |
+
|
98 |
+
x = F.relu(x)
|
99 |
+
diffusion_step = self.diffusion_embedding(diffusion_step)
|
100 |
+
diffusion_step = self.mlp(diffusion_step)
|
101 |
+
skip = []
|
102 |
+
for layer_id, layer in enumerate(self.residual_layers):
|
103 |
+
x, skip_connection = layer(x, cond, diffusion_step)
|
104 |
+
skip.append(skip_connection)
|
105 |
+
|
106 |
+
x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
|
107 |
+
x = self.skip_projection(x)
|
108 |
+
x = F.relu(x)
|
109 |
+
x = self.output_projection(x) # [B, 80, T]
|
110 |
+
return x[:, None, :, :]
|
modules/tts/diffspeech/shallow_diffusion_tts.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
from functools import partial
|
4 |
+
from inspect import isfunction
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from modules.tts.fs2_orig import FastSpeech2Orig
|
12 |
+
from modules.tts.diffspeech.net import DiffNet
|
13 |
+
from modules.tts.commons.align_ops import expand_states
|
14 |
+
|
15 |
+
|
16 |
+
def exists(x):
|
17 |
+
return x is not None
|
18 |
+
|
19 |
+
|
20 |
+
def default(val, d):
|
21 |
+
if exists(val):
|
22 |
+
return val
|
23 |
+
return d() if isfunction(d) else d
|
24 |
+
|
25 |
+
|
26 |
+
# gaussian diffusion trainer class
|
27 |
+
|
28 |
+
def extract(a, t, x_shape):
|
29 |
+
b, *_ = t.shape
|
30 |
+
out = a.gather(-1, t)
|
31 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
32 |
+
|
33 |
+
|
34 |
+
def noise_like(shape, device, repeat=False):
|
35 |
+
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
36 |
+
noise = lambda: torch.randn(shape, device=device)
|
37 |
+
return repeat_noise() if repeat else noise()
|
38 |
+
|
39 |
+
|
40 |
+
def linear_beta_schedule(timesteps, max_beta=0.01):
|
41 |
+
"""
|
42 |
+
linear schedule
|
43 |
+
"""
|
44 |
+
betas = np.linspace(1e-4, max_beta, timesteps)
|
45 |
+
return betas
|
46 |
+
|
47 |
+
|
48 |
+
def cosine_beta_schedule(timesteps, s=0.008):
|
49 |
+
"""
|
50 |
+
cosine schedule
|
51 |
+
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
52 |
+
"""
|
53 |
+
steps = timesteps + 1
|
54 |
+
x = np.linspace(0, steps, steps)
|
55 |
+
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
|
56 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
57 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
58 |
+
return np.clip(betas, a_min=0, a_max=0.999)
|
59 |
+
|
60 |
+
|
61 |
+
beta_schedule = {
|
62 |
+
"cosine": cosine_beta_schedule,
|
63 |
+
"linear": linear_beta_schedule,
|
64 |
+
}
|
65 |
+
|
66 |
+
|
67 |
+
DIFF_DECODERS = {
|
68 |
+
'wavenet': lambda hp: DiffNet(hp),
|
69 |
+
}
|
70 |
+
|
71 |
+
|
72 |
+
class AuxModel(FastSpeech2Orig):
|
73 |
+
def forward(self, txt_tokens, mel2ph=None, spk_embed=None, spk_id=None,
|
74 |
+
f0=None, uv=None, energy=None, infer=False, **kwargs):
|
75 |
+
ret = {}
|
76 |
+
encoder_out = self.encoder(txt_tokens) # [B, T, C]
|
77 |
+
src_nonpadding = (txt_tokens > 0).float()[:, :, None]
|
78 |
+
style_embed = self.forward_style_embed(spk_embed, spk_id)
|
79 |
+
|
80 |
+
# add dur
|
81 |
+
dur_inp = (encoder_out + style_embed) * src_nonpadding
|
82 |
+
mel2ph = self.forward_dur(dur_inp, mel2ph, txt_tokens, ret)
|
83 |
+
tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
|
84 |
+
decoder_inp = decoder_inp_ = expand_states(encoder_out, mel2ph)
|
85 |
+
|
86 |
+
# add pitch and energy embed
|
87 |
+
if self.hparams['use_pitch_embed']:
|
88 |
+
pitch_inp = (decoder_inp_ + style_embed) * tgt_nonpadding
|
89 |
+
decoder_inp = decoder_inp + self.forward_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out)
|
90 |
+
|
91 |
+
# add pitch and energy embed
|
92 |
+
if self.hparams['use_energy_embed']:
|
93 |
+
energy_inp = (decoder_inp_ + style_embed) * tgt_nonpadding
|
94 |
+
decoder_inp = decoder_inp + self.forward_energy(energy_inp, energy, ret)
|
95 |
+
|
96 |
+
# decoder input
|
97 |
+
ret['decoder_inp'] = decoder_inp = (decoder_inp + style_embed) * tgt_nonpadding
|
98 |
+
if self.hparams['dec_inp_add_noise']:
|
99 |
+
B, T, _ = decoder_inp.shape
|
100 |
+
z = kwargs.get('adv_z', torch.randn([B, T, self.z_channels])).to(decoder_inp.device)
|
101 |
+
ret['adv_z'] = z
|
102 |
+
decoder_inp = torch.cat([decoder_inp, z], -1)
|
103 |
+
decoder_inp = self.dec_inp_noise_proj(decoder_inp) * tgt_nonpadding
|
104 |
+
if kwargs['skip_decoder']:
|
105 |
+
return ret
|
106 |
+
ret['mel_out'] = self.forward_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
|
107 |
+
return ret
|
108 |
+
|
109 |
+
|
110 |
+
class GaussianDiffusion(nn.Module):
|
111 |
+
def __init__(self, dict_size, hparams, out_dims=None):
|
112 |
+
super().__init__()
|
113 |
+
self.hparams = hparams
|
114 |
+
out_dims = hparams['audio_num_mel_bins']
|
115 |
+
denoise_fn = DIFF_DECODERS[hparams['diff_decoder_type']](hparams)
|
116 |
+
timesteps = hparams['timesteps']
|
117 |
+
K_step = hparams['K_step']
|
118 |
+
loss_type = hparams['diff_loss_type']
|
119 |
+
spec_min = hparams['spec_min']
|
120 |
+
spec_max = hparams['spec_max']
|
121 |
+
|
122 |
+
self.denoise_fn = denoise_fn
|
123 |
+
self.fs2 = AuxModel(dict_size, hparams)
|
124 |
+
self.mel_bins = out_dims
|
125 |
+
|
126 |
+
if hparams['schedule_type'] == 'linear':
|
127 |
+
betas = linear_beta_schedule(timesteps, hparams['max_beta'])
|
128 |
+
else:
|
129 |
+
betas = cosine_beta_schedule(timesteps)
|
130 |
+
|
131 |
+
alphas = 1. - betas
|
132 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
133 |
+
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
134 |
+
|
135 |
+
timesteps, = betas.shape
|
136 |
+
self.num_timesteps = int(timesteps)
|
137 |
+
self.K_step = K_step
|
138 |
+
self.loss_type = loss_type
|
139 |
+
|
140 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
141 |
+
|
142 |
+
self.register_buffer('betas', to_torch(betas))
|
143 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
144 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
145 |
+
|
146 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
147 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
148 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
149 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
150 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
151 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
152 |
+
|
153 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
154 |
+
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
155 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
156 |
+
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
157 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
158 |
+
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
159 |
+
self.register_buffer('posterior_mean_coef1', to_torch(
|
160 |
+
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
|
161 |
+
self.register_buffer('posterior_mean_coef2', to_torch(
|
162 |
+
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
|
163 |
+
|
164 |
+
self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
|
165 |
+
self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])
|
166 |
+
|
167 |
+
def q_mean_variance(self, x_start, t):
|
168 |
+
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
169 |
+
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
|
170 |
+
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
171 |
+
return mean, variance, log_variance
|
172 |
+
|
173 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
174 |
+
return (
|
175 |
+
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
176 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
177 |
+
)
|
178 |
+
|
179 |
+
def q_posterior(self, x_start, x_t, t):
|
180 |
+
posterior_mean = (
|
181 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
182 |
+
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
183 |
+
)
|
184 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
185 |
+
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
186 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
187 |
+
|
188 |
+
def p_mean_variance(self, x, t, cond, clip_denoised: bool):
|
189 |
+
noise_pred = self.denoise_fn(x, t, cond=cond)
|
190 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
|
191 |
+
|
192 |
+
if clip_denoised:
|
193 |
+
x_recon.clamp_(-1., 1.)
|
194 |
+
|
195 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
196 |
+
return model_mean, posterior_variance, posterior_log_variance
|
197 |
+
|
198 |
+
@torch.no_grad()
|
199 |
+
def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
|
200 |
+
b, *_, device = *x.shape, x.device
|
201 |
+
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
|
202 |
+
noise = noise_like(x.shape, device, repeat_noise)
|
203 |
+
# no noise when t == 0
|
204 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
205 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
206 |
+
|
207 |
+
def q_sample(self, x_start, t, noise=None):
|
208 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
209 |
+
return (
|
210 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
211 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
212 |
+
)
|
213 |
+
|
214 |
+
def p_losses(self, x_start, t, cond, noise=None, nonpadding=None):
|
215 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
216 |
+
|
217 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
218 |
+
x_recon = self.denoise_fn(x_noisy, t, cond)
|
219 |
+
|
220 |
+
if self.loss_type == 'l1':
|
221 |
+
if nonpadding is not None:
|
222 |
+
loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean()
|
223 |
+
else:
|
224 |
+
# print('are you sure w/o nonpadding?')
|
225 |
+
loss = (noise - x_recon).abs().mean()
|
226 |
+
|
227 |
+
elif self.loss_type == 'l2':
|
228 |
+
loss = F.mse_loss(noise, x_recon)
|
229 |
+
else:
|
230 |
+
raise NotImplementedError()
|
231 |
+
|
232 |
+
return loss
|
233 |
+
|
234 |
+
def forward(self, txt_tokens, mel2ph=None, spk_embed=None, spk_id=None,
|
235 |
+
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
236 |
+
b, *_, device = *txt_tokens.shape, txt_tokens.device
|
237 |
+
ret = self.fs2(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
|
238 |
+
f0=f0, uv=uv, energy=energy, infer=infer, skip_decoder=(not infer), **kwargs)
|
239 |
+
# (txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
240 |
+
# skip_decoder=(not infer), infer=infer, **kwargs)
|
241 |
+
cond = ret['decoder_inp'].transpose(1, 2)
|
242 |
+
|
243 |
+
if not infer:
|
244 |
+
t = torch.randint(0, self.K_step, (b,), device=device).long()
|
245 |
+
x = ref_mels
|
246 |
+
x = self.norm_spec(x)
|
247 |
+
x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
|
248 |
+
ret['diff_loss'] = self.p_losses(x, t, cond)
|
249 |
+
# nonpadding = (mel2ph != 0).float()
|
250 |
+
# ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding)
|
251 |
+
ret['mel_out'] = None
|
252 |
+
else:
|
253 |
+
ret['fs2_mel'] = ret['mel_out']
|
254 |
+
fs2_mels = ret['mel_out']
|
255 |
+
t = self.K_step
|
256 |
+
fs2_mels = self.norm_spec(fs2_mels)
|
257 |
+
fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
|
258 |
+
|
259 |
+
x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
|
260 |
+
if self.hparams.get('gaussian_start') is not None and self.hparams['gaussian_start']:
|
261 |
+
print('===> gaussian start.')
|
262 |
+
shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
|
263 |
+
x = torch.randn(shape, device=device)
|
264 |
+
for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
|
265 |
+
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
|
266 |
+
x = x[:, 0].transpose(1, 2)
|
267 |
+
ret['mel_out'] = self.denorm_spec(x)
|
268 |
+
|
269 |
+
return ret
|
270 |
+
|
271 |
+
def norm_spec(self, x):
|
272 |
+
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
|
273 |
+
|
274 |
+
def denorm_spec(self, x):
|
275 |
+
return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
|
276 |
+
|
277 |
+
def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
|
278 |
+
return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
|
279 |
+
|
280 |
+
def out2mel(self, x):
|
281 |
+
return x
|
tasks/tts/diffspeech.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from modules.tts.diffspeech.shallow_diffusion_tts import GaussianDiffusion
|
4 |
+
from tasks.tts.fs2_orig import FastSpeech2OrigTask
|
5 |
+
|
6 |
+
import utils
|
7 |
+
from utils.commons.hparams import hparams
|
8 |
+
from utils.commons.ckpt_utils import load_ckpt
|
9 |
+
from utils.audio.pitch.utils import denorm_f0
|
10 |
+
|
11 |
+
|
12 |
+
class DiffSpeechTask(FastSpeech2OrigTask):
|
13 |
+
def build_tts_model(self):
|
14 |
+
# get min and max
|
15 |
+
# import torch
|
16 |
+
# from tqdm import tqdm
|
17 |
+
# v_min = torch.ones([80]) * 100
|
18 |
+
# v_max = torch.ones([80]) * -100
|
19 |
+
# for i, ds in enumerate(tqdm(self.dataset_cls('train'))):
|
20 |
+
# v_max = torch.max(torch.max(ds['mel'].reshape(-1, 80), 0)[0], v_max)
|
21 |
+
# v_min = torch.min(torch.min(ds['mel'].reshape(-1, 80), 0)[0], v_min)
|
22 |
+
# if i % 100 == 0:
|
23 |
+
# print(i, v_min, v_max)
|
24 |
+
# print('final', v_min, v_max)
|
25 |
+
dict_size = len(self.token_encoder)
|
26 |
+
self.model = GaussianDiffusion(dict_size, hparams)
|
27 |
+
if hparams['fs2_ckpt'] != '':
|
28 |
+
load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True)
|
29 |
+
for k, v in self.model.fs2.named_parameters():
|
30 |
+
if 'predictor' not in k:
|
31 |
+
v.requires_grad = False
|
32 |
+
# or
|
33 |
+
# for k, v in self.model.fs2.named_parameters():
|
34 |
+
# v.requires_grad = False
|
35 |
+
|
36 |
+
def build_optimizer(self, model):
|
37 |
+
self.optimizer = optimizer = torch.optim.AdamW(
|
38 |
+
filter(lambda p: p.requires_grad, model.parameters()),
|
39 |
+
lr=hparams['lr'],
|
40 |
+
betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
|
41 |
+
weight_decay=hparams['weight_decay'])
|
42 |
+
return optimizer
|
43 |
+
|
44 |
+
def build_scheduler(self, optimizer):
|
45 |
+
return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)
|
46 |
+
|
47 |
+
def run_model(self, sample, infer=False, *args, **kwargs):
|
48 |
+
txt_tokens = sample['txt_tokens'] # [B, T_t]
|
49 |
+
spk_embed = sample.get('spk_embed')
|
50 |
+
spk_id = sample.get('spk_ids')
|
51 |
+
if not infer:
|
52 |
+
target = sample['mels'] # [B, T_s, 80]
|
53 |
+
mel2ph = sample['mel2ph'] # [B, T_s]
|
54 |
+
f0 = sample.get('f0')
|
55 |
+
uv = sample.get('uv')
|
56 |
+
output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
|
57 |
+
ref_mels=target, f0=f0, uv=uv, infer=False)
|
58 |
+
losses = {}
|
59 |
+
if 'diff_loss' in output:
|
60 |
+
losses['mel'] = output['diff_loss']
|
61 |
+
self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
|
62 |
+
if hparams['use_pitch_embed']:
|
63 |
+
self.add_pitch_loss(output, sample, losses)
|
64 |
+
return losses, output
|
65 |
+
else:
|
66 |
+
use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
|
67 |
+
use_gt_f0 = kwargs.get('infer_use_gt_f0', hparams['use_gt_f0'])
|
68 |
+
mel2ph, uv, f0 = None, None, None
|
69 |
+
if use_gt_dur:
|
70 |
+
mel2ph = sample['mel2ph']
|
71 |
+
if use_gt_f0:
|
72 |
+
f0 = sample['f0']
|
73 |
+
uv = sample['uv']
|
74 |
+
output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
|
75 |
+
ref_mels=None, f0=f0, uv=uv, infer=True)
|
76 |
+
return output
|
77 |
+
|
78 |
+
def save_valid_result(self, sample, batch_idx, model_out):
|
79 |
+
sr = hparams['audio_sample_rate']
|
80 |
+
f0_gt = None
|
81 |
+
# mel_out = model_out['mel_out']
|
82 |
+
if sample.get('f0') is not None:
|
83 |
+
f0_gt = denorm_f0(sample['f0'][0].cpu(), sample['uv'][0].cpu())
|
84 |
+
# self.plot_mel(batch_idx, sample['mels'], mel_out, f0s=f0_gt)
|
85 |
+
if self.global_step > 0:
|
86 |
+
# wav_pred = self.vocoder.spec2wav(mel_out[0].cpu(), f0=f0_gt)
|
87 |
+
# self.logger.add_audio(f'wav_val_{batch_idx}', wav_pred, self.global_step, sr)
|
88 |
+
# with gt duration
|
89 |
+
model_out = self.run_model(sample, infer=True, infer_use_gt_dur=True)
|
90 |
+
dur_info = self.get_plot_dur_info(sample, model_out)
|
91 |
+
del dur_info['dur_pred']
|
92 |
+
wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu(), f0=f0_gt)
|
93 |
+
self.logger.add_audio(f'wav_gdur_{batch_idx}', wav_pred, self.global_step, sr)
|
94 |
+
self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'][0], f'diffmel_gdur_{batch_idx}',
|
95 |
+
dur_info=dur_info, f0s=f0_gt)
|
96 |
+
self.plot_mel(batch_idx, sample['mels'], model_out['fs2_mel'][0], f'fs2mel_gdur_{batch_idx}',
|
97 |
+
dur_info=dur_info, f0s=f0_gt) # gt mel vs. fs2 mel
|
98 |
+
|
99 |
+
# with pred duration
|
100 |
+
if not hparams['use_gt_dur']:
|
101 |
+
model_out = self.run_model(sample, infer=True, infer_use_gt_dur=False)
|
102 |
+
dur_info = self.get_plot_dur_info(sample, model_out)
|
103 |
+
self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'][0], f'mel_pdur_{batch_idx}',
|
104 |
+
dur_info=dur_info, f0s=f0_gt)
|
105 |
+
wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu(), f0=f0_gt)
|
106 |
+
self.logger.add_audio(f'wav_pdur_{batch_idx}', wav_pred, self.global_step, sr)
|
107 |
+
# gt wav
|
108 |
+
if self.global_step <= hparams['valid_infer_interval']:
|
109 |
+
mel_gt = sample['mels'][0].cpu()
|
110 |
+
wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0_gt)
|
111 |
+
self.logger.add_audio(f'wav_gt_{batch_idx}', wav_gt, self.global_step, sr)
|