RayeRen commited on
Commit
7ebce0b
2 Parent(s): 1014fe4 bd1afa3

Merge branch 'main' into ps

Browse files
checkpoints/diffspeech/config.yaml ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ K_step: 71
2
+ accumulate_grad_batches: 1
3
+ amp: false
4
+ audio_num_mel_bins: 80
5
+ audio_sample_rate: 22050
6
+ base_config:
7
+ - egs/egs_bases/tts/ds.yaml
8
+ - ./fs2_orig.yaml
9
+ binarization_args:
10
+ min_sil_duration: 0.1
11
+ shuffle: false
12
+ test_range:
13
+ - 0
14
+ - 523
15
+ train_range:
16
+ - 871
17
+ - -1
18
+ trim_eos_bos: false
19
+ valid_range:
20
+ - 523
21
+ - 871
22
+ with_align: true
23
+ with_f0: true
24
+ with_f0cwt: true
25
+ with_linear: false
26
+ with_spk_embed: false
27
+ with_wav: false
28
+ binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
29
+ binary_data_dir: data/binary/ljspeech_cwt
30
+ check_val_every_n_epoch: 10
31
+ clip_grad_norm: 1
32
+ clip_grad_value: 0
33
+ conv_use_pos: false
34
+ cwt_std_scale: 1.0
35
+ debug: false
36
+ dec_dilations:
37
+ - 1
38
+ - 1
39
+ - 1
40
+ - 1
41
+ dec_ffn_kernel_size: 9
42
+ dec_inp_add_noise: false
43
+ dec_kernel_size: 5
44
+ dec_layers: 4
45
+ dec_post_net_kernel: 3
46
+ decay_steps: 50000
47
+ decoder_rnn_dim: 0
48
+ decoder_type: fft
49
+ diff_decoder_type: wavenet
50
+ diff_loss_type: l1
51
+ dilation_cycle_length: 1
52
+ dropout: 0.0
53
+ ds_workers: 2
54
+ dur_predictor_kernel: 3
55
+ dur_predictor_layers: 2
56
+ enc_dec_norm: ln
57
+ enc_dilations:
58
+ - 1
59
+ - 1
60
+ - 1
61
+ - 1
62
+ enc_ffn_kernel_size: 9
63
+ enc_kernel_size: 5
64
+ enc_layers: 4
65
+ enc_post_net_kernel: 3
66
+ enc_pre_ln: true
67
+ enc_prenet: true
68
+ encoder_K: 8
69
+ encoder_type: fft
70
+ endless_ds: true
71
+ eval_max_batches: -1
72
+ f0_max: 600
73
+ f0_min: 80
74
+ ffn_act: gelu
75
+ ffn_hidden_size: 1024
76
+ fft_size: 1024
77
+ fmax: 7600
78
+ fmin: 80
79
+ frames_multiple: 1
80
+ fs2_ckpt: checkpoints/fs2_exp/model_ckpt_steps_160000.ckpt
81
+ gen_dir_name: ''
82
+ griffin_lim_iters: 30
83
+ hidden_size: 256
84
+ hop_size: 256
85
+ infer: false
86
+ keep_bins: 80
87
+ lambda_commit: 0.25
88
+ lambda_energy: 0.1
89
+ lambda_f0: 1.0
90
+ lambda_ph_dur: 0.1
91
+ lambda_sent_dur: 1.0
92
+ lambda_uv: 1.0
93
+ lambda_word_dur: 1.0
94
+ layers_in_block: 2
95
+ load_ckpt: ''
96
+ loud_norm: false
97
+ lr: 0.001
98
+ max_beta: 0.06
99
+ max_epochs: 1000
100
+ max_frames: 1548
101
+ max_input_tokens: 1550
102
+ max_sentences: 128
103
+ max_tokens: 30000
104
+ max_updates: 160000
105
+ max_valid_sentences: 1
106
+ max_valid_tokens: 60000
107
+ mel_losses: l1:0.5|ssim:0.5
108
+ mel_vmax: 1.5
109
+ mel_vmin: -6
110
+ min_frames: 0
111
+ num_ckpt_keep: 3
112
+ num_heads: 2
113
+ num_sanity_val_steps: 5
114
+ num_spk: 1
115
+ num_valid_plots: 10
116
+ optimizer_adam_beta1: 0.9
117
+ optimizer_adam_beta2: 0.98
118
+ out_wav_norm: false
119
+ pitch_extractor: parselmouth
120
+ pitch_key: pitch
121
+ pitch_type: cwt
122
+ predictor_dropout: 0.5
123
+ predictor_grad: 0.1
124
+ predictor_hidden: -1
125
+ predictor_kernel: 5
126
+ predictor_layers: 2
127
+ preprocess_args:
128
+ add_eos_bos: true
129
+ mfa_group_shuffle: false
130
+ mfa_offset: 0.02
131
+ nsample_per_mfa_group: 1000
132
+ reset_phone_dict: true
133
+ reset_word_dict: true
134
+ save_sil_mask: true
135
+ txt_processor: en
136
+ use_mfa: true
137
+ vad_max_silence_length: 12
138
+ wav_processors: []
139
+ with_phsep: true
140
+ preprocess_cls: egs.datasets.audio.lj.preprocess.LJPreprocess
141
+ print_nan_grads: false
142
+ processed_data_dir: data/processed/ljspeech
143
+ profile_infer: false
144
+ raw_data_dir: data/raw/LJSpeech-1.1
145
+ ref_norm_layer: bn
146
+ rename_tmux: true
147
+ residual_channels: 256
148
+ residual_layers: 20
149
+ resume_from_checkpoint: 0
150
+ save_best: false
151
+ save_codes:
152
+ - tasks
153
+ - modules
154
+ - egs
155
+ save_f0: false
156
+ save_gt: true
157
+ schedule_type: linear
158
+ scheduler: warmup
159
+ seed: 1234
160
+ sort_by_len: true
161
+ spec_max:
162
+ - -0.5982
163
+ - -0.0778
164
+ - 0.1205
165
+ - 0.2747
166
+ - 0.4657
167
+ - 0.5123
168
+ - 0.583
169
+ - 0.7093
170
+ - 0.6461
171
+ - 0.6101
172
+ - 0.7316
173
+ - 0.7715
174
+ - 0.7681
175
+ - 0.8349
176
+ - 0.7815
177
+ - 0.7591
178
+ - 0.791
179
+ - 0.7433
180
+ - 0.7352
181
+ - 0.6869
182
+ - 0.6854
183
+ - 0.6623
184
+ - 0.5353
185
+ - 0.6492
186
+ - 0.6909
187
+ - 0.6106
188
+ - 0.5761
189
+ - 0.5236
190
+ - 0.5638
191
+ - 0.4054
192
+ - 0.4545
193
+ - 0.3407
194
+ - 0.3037
195
+ - 0.338
196
+ - 0.1599
197
+ - 0.1603
198
+ - 0.2741
199
+ - 0.213
200
+ - 0.1569
201
+ - 0.1911
202
+ - 0.2324
203
+ - 0.1586
204
+ - 0.1221
205
+ - 0.0341
206
+ - -0.0558
207
+ - 0.0553
208
+ - -0.1153
209
+ - -0.0933
210
+ - -0.1171
211
+ - -0.005
212
+ - -0.1519
213
+ - -0.1629
214
+ - -0.0522
215
+ - -0.0739
216
+ - -0.2069
217
+ - -0.2405
218
+ - -0.1244
219
+ - -0.2582
220
+ - -0.1361
221
+ - -0.1575
222
+ - -0.1442
223
+ - 0.0513
224
+ - -0.1567
225
+ - -0.2
226
+ - 0.0086
227
+ - -0.0698
228
+ - 0.1385
229
+ - 0.0941
230
+ - 0.1864
231
+ - 0.1225
232
+ - 0.1389
233
+ - 0.1382
234
+ - 0.167
235
+ - 0.1007
236
+ - 0.1444
237
+ - 0.0888
238
+ - 0.1998
239
+ - 0.228
240
+ - 0.2932
241
+ - 0.3047
242
+ spec_min:
243
+ - -4.7574
244
+ - -4.6783
245
+ - -4.6431
246
+ - -4.5832
247
+ - -4.539
248
+ - -4.6771
249
+ - -4.8089
250
+ - -4.7672
251
+ - -4.5784
252
+ - -4.7755
253
+ - -4.715
254
+ - -4.8919
255
+ - -4.8271
256
+ - -4.7389
257
+ - -4.6047
258
+ - -4.7759
259
+ - -4.6799
260
+ - -4.8201
261
+ - -4.7823
262
+ - -4.8262
263
+ - -4.7857
264
+ - -4.7545
265
+ - -4.9358
266
+ - -4.9733
267
+ - -5.1134
268
+ - -5.1395
269
+ - -4.9016
270
+ - -4.8434
271
+ - -5.0189
272
+ - -4.846
273
+ - -5.0529
274
+ - -4.951
275
+ - -5.0217
276
+ - -5.0049
277
+ - -5.1831
278
+ - -5.1445
279
+ - -5.1015
280
+ - -5.0281
281
+ - -4.9887
282
+ - -4.9916
283
+ - -4.9785
284
+ - -4.9071
285
+ - -4.9488
286
+ - -5.0342
287
+ - -4.9332
288
+ - -5.065
289
+ - -4.8924
290
+ - -5.0875
291
+ - -5.0483
292
+ - -5.0848
293
+ - -5.0655
294
+ - -5.0279
295
+ - -5.0015
296
+ - -5.0792
297
+ - -5.0636
298
+ - -5.2413
299
+ - -5.1421
300
+ - -5.171
301
+ - -5.3256
302
+ - -5.0511
303
+ - -5.1186
304
+ - -5.0057
305
+ - -5.0446
306
+ - -5.1173
307
+ - -5.0325
308
+ - -5.1085
309
+ - -5.0053
310
+ - -5.0755
311
+ - -5.1176
312
+ - -5.1004
313
+ - -5.2153
314
+ - -5.2757
315
+ - -5.3025
316
+ - -5.2867
317
+ - -5.2918
318
+ - -5.3328
319
+ - -5.2731
320
+ - -5.2985
321
+ - -5.24
322
+ - -5.2211
323
+ task_cls: tasks.tts.diffspeech.DiffSpeechTask
324
+ tb_log_interval: 100
325
+ test_ids:
326
+ - 0
327
+ - 1
328
+ - 2
329
+ - 3
330
+ - 4
331
+ - 5
332
+ - 6
333
+ - 7
334
+ - 8
335
+ - 9
336
+ - 10
337
+ - 11
338
+ - 12
339
+ - 13
340
+ - 14
341
+ - 15
342
+ - 16
343
+ - 17
344
+ - 18
345
+ - 19
346
+ - 68
347
+ - 70
348
+ - 74
349
+ - 87
350
+ - 110
351
+ - 172
352
+ - 190
353
+ - 215
354
+ - 231
355
+ - 294
356
+ - 316
357
+ - 324
358
+ - 402
359
+ - 422
360
+ - 485
361
+ - 500
362
+ - 505
363
+ - 508
364
+ - 509
365
+ - 519
366
+ test_input_yaml: ''
367
+ test_num: 100
368
+ test_set_name: test
369
+ timesteps: 100
370
+ train_set_name: train
371
+ train_sets: ''
372
+ use_energy_embed: true
373
+ use_gt_dur: false
374
+ use_gt_energy: false
375
+ use_gt_f0: false
376
+ use_pitch_embed: true
377
+ use_pos_embed: true
378
+ use_spk_embed: false
379
+ use_spk_id: false
380
+ use_uv: true
381
+ use_word_input: false
382
+ val_check_interval: 2000
383
+ valid_infer_interval: 10000
384
+ valid_monitor_key: val_loss
385
+ valid_monitor_mode: min
386
+ valid_set_name: valid
387
+ vocoder: HifiGAN
388
+ vocoder_ckpt: checkpoints/hifi_lj
389
+ warmup_updates: 4000
390
+ weight_decay: 0
391
+ win_size: 1024
392
+ word_dict_size: 10000
393
+ work_dir: checkpoints/0209_ds_1
checkpoints/diffspeech/model_ckpt_steps_160000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:503f81009a75c02d868253b6fb4f1411aeaa32308b101d7804447bc583636b83
3
+ size 168816223
docs/diffspeech.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Run DiffSpeech
2
+
3
+ ## Quick Start
4
+
5
+ ### Install Dependencies
6
+
7
+ Install dependencies following [readme.md](../readme.md)
8
+
9
+ ### Set Config Path and Experiment Name
10
+
11
+ ```bash
12
+ export CONFIG_NAME=egs/datasets/audio/lj/ds.yaml
13
+ export MY_EXP_NAME=ds_exp
14
+ ```
15
+
16
+ ### Preprocess and binary dataset
17
+
18
+ Prepare dataset following [prepare_data.md](./prepare_data.md)
19
+
20
+ ### Prepare Vocoder
21
+
22
+ Prepare vocoder following [prepare_vocoder.md](./prepare_vocoder.md)
23
+
24
+ ## Training
25
+
26
+ First, you need a pre-trained FastSpeech2 checkpoint `chckpoints/fs2_exp/model_ckpt_steps_160000.ckpt`. To train a FastSpeech 2 model, run:
27
+
28
+ ```bash
29
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config egs/datasets/audio/lj/fs2_orig.yaml --exp_name fs2_exp --reset
30
+ ```
31
+
32
+ Then, run:
33
+
34
+ ```bash
35
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config $CONFIG_NAME --exp_name $MY_EXP_NAME --reset
36
+ ```
37
+
38
+ You can check the training and validation curves open Tensorboard via:
39
+
40
+ ```bash
41
+ tensorboard --logdir checkpoints/$MY_EXP_NAME
42
+ ```
43
+
44
+ ## Inference (Testing)
45
+
46
+ ```bash
47
+ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config $CONFIG_NAME --exp_name $MY_EXP_NAME --infer
48
+ ```
49
+
50
+ ## Citation
51
+
52
+ If you find this useful for your research, please use the following.
53
+
54
+ ```bib
55
+ @article{liu2021diffsinger,
56
+ title={Diffsinger: Singing voice synthesis via shallow diffusion mechanism},
57
+ author={Liu, Jinglin and Li, Chengxi and Ren, Yi and Chen, Feiyang and Liu, Peng and Zhao, Zhou},
58
+ journal={arXiv preprint arXiv:2105.02446},
59
+ volume={2},
60
+ year={2021}
61
+ }
62
+ ```
docs/prepare_vocoder.md CHANGED
@@ -26,7 +26,7 @@ export MY_EXP_NAME=my_hifigan_exp
26
  Prepare dataset following [prepare_data.md](./prepare_data.md).
27
 
28
  If you have run the `prepare_data` step of the acoustic
29
- model (e.g., FastSpeech 2 and PortaSpeech), you only need to binarize the dataset for the vocoder training:
30
 
31
  ```bash
32
  python data_gen/tts/runs/binarize.py --config $CONFIG_NAME
 
26
  Prepare dataset following [prepare_data.md](./prepare_data.md).
27
 
28
  If you have run the `prepare_data` step of the acoustic
29
+ model (e.g., PortaSpeech and DiffSpeech), you only need to binarize the dataset for the vocoder training:
30
 
31
  ```bash
32
  python data_gen/tts/runs/binarize.py --config $CONFIG_NAME
egs/datasets/audio/lj/ds.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config:
2
+ - egs/egs_bases/tts/ds.yaml
3
+ - ./fs2_orig.yaml
4
+
5
+ fs2_ckpt: checkpoints/fs2_exp/model_ckpt_steps_160000.ckpt
6
+
7
+ # spec_min and spec_max are calculated on the training set.
8
+ spec_min: [ -4.7574, -4.6783, -4.6431, -4.5832, -4.5390, -4.6771, -4.8089, -4.7672,
9
+ -4.5784, -4.7755, -4.7150, -4.8919, -4.8271, -4.7389, -4.6047, -4.7759,
10
+ -4.6799, -4.8201, -4.7823, -4.8262, -4.7857, -4.7545, -4.9358, -4.9733,
11
+ -5.1134, -5.1395, -4.9016, -4.8434, -5.0189, -4.8460, -5.0529, -4.9510,
12
+ -5.0217, -5.0049, -5.1831, -5.1445, -5.1015, -5.0281, -4.9887, -4.9916,
13
+ -4.9785, -4.9071, -4.9488, -5.0342, -4.9332, -5.0650, -4.8924, -5.0875,
14
+ -5.0483, -5.0848, -5.0655, -5.0279, -5.0015, -5.0792, -5.0636, -5.2413,
15
+ -5.1421, -5.1710, -5.3256, -5.0511, -5.1186, -5.0057, -5.0446, -5.1173,
16
+ -5.0325, -5.1085, -5.0053, -5.0755, -5.1176, -5.1004, -5.2153, -5.2757,
17
+ -5.3025, -5.2867, -5.2918, -5.3328, -5.2731, -5.2985, -5.2400, -5.2211 ]
18
+ spec_max: [ -0.5982, -0.0778, 0.1205, 0.2747, 0.4657, 0.5123, 0.5830, 0.7093,
19
+ 0.6461, 0.6101, 0.7316, 0.7715, 0.7681, 0.8349, 0.7815, 0.7591,
20
+ 0.7910, 0.7433, 0.7352, 0.6869, 0.6854, 0.6623, 0.5353, 0.6492,
21
+ 0.6909, 0.6106, 0.5761, 0.5236, 0.5638, 0.4054, 0.4545, 0.3407,
22
+ 0.3037, 0.3380, 0.1599, 0.1603, 0.2741, 0.2130, 0.1569, 0.1911,
23
+ 0.2324, 0.1586, 0.1221, 0.0341, -0.0558, 0.0553, -0.1153, -0.0933,
24
+ -0.1171, -0.0050, -0.1519, -0.1629, -0.0522, -0.0739, -0.2069, -0.2405,
25
+ -0.1244, -0.2582, -0.1361, -0.1575, -0.1442, 0.0513, -0.1567, -0.2000,
26
+ 0.0086, -0.0698, 0.1385, 0.0941, 0.1864, 0.1225, 0.1389, 0.1382,
27
+ 0.1670, 0.1007, 0.1444, 0.0888, 0.1998, 0.2280, 0.2932, 0.3047 ]
28
+
29
+ max_tokens: 30000
egs/egs_bases/tts/ds.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config: ./fs2_orig.yaml
2
+
3
+ # special configs for diffspeech
4
+ task_cls: tasks.tts.diffspeech.DiffSpeechTask
5
+ lr: 0.001
6
+ timesteps: 100
7
+ K_step: 71
8
+ diff_loss_type: l1
9
+ diff_decoder_type: 'wavenet'
10
+ schedule_type: 'linear'
11
+ max_beta: 0.06
12
+
13
+ ## model configs for diffspeech
14
+ dilation_cycle_length: 1
15
+ residual_layers: 20
16
+ residual_channels: 256
17
+ decay_steps: 50000
18
+ keep_bins: 80
19
+ #content_cond_steps: [ ] # [ 0, 10000 ]
20
+ #spk_cond_steps: [ ] # [ 0, 10000 ]
21
+ #gen_tgt_spk_id: -1
22
+
23
+
24
+
25
+ # training configs for diffspeech
26
+ #max_sentences: 48
27
+ #num_sanity_val_steps: 1
28
+ num_valid_plots: 10
29
+ use_gt_dur: false
30
+ use_gt_f0: false
31
+ #pitch_type: cwt
32
+ max_updates: 160000
inference/tts/ds.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ # from inference.tts.fs import FastSpeechInfer
3
+ # from modules.tts.fs2_orig import FastSpeech2Orig
4
+ from inference.tts.base_tts_infer import BaseTTSInfer
5
+ from modules.tts.diffspeech.shallow_diffusion_tts import GaussianDiffusion
6
+ from utils.commons.ckpt_utils import load_ckpt
7
+ from utils.commons.hparams import hparams
8
+
9
+
10
+ class DiffSpeechInfer(BaseTTSInfer):
11
+ def build_model(self):
12
+ dict_size = len(self.ph_encoder)
13
+ model = GaussianDiffusion(dict_size, self.hparams)
14
+ model.eval()
15
+ load_ckpt(model, hparams['work_dir'], 'model')
16
+ return model
17
+
18
+ def forward_model(self, inp):
19
+ sample = self.input_to_batch(inp)
20
+ txt_tokens = sample['txt_tokens'] # [B, T_t]
21
+ spk_id = sample.get('spk_ids')
22
+ with torch.no_grad():
23
+ output = self.model(txt_tokens, spk_id=spk_id, ref_mels=None, infer=True)
24
+ mel_out = output['mel_out']
25
+ wav_out = self.run_vocoder(mel_out)
26
+ wav_out = wav_out.cpu().numpy()
27
+ return wav_out[0]
28
+
29
+ if __name__ == '__main__':
30
+ DiffSpeechInfer.example_run()
modules/tts/commons/align_ops.py CHANGED
@@ -13,9 +13,8 @@ def mel2ph_to_mel2word(mel2ph, ph2word):
13
 
14
 
15
  def clip_mel2token_to_multiple(mel2token, frames_multiple):
16
- if mel2token.shape[1] % frames_multiple > 0:
17
- max_frames = mel2token.shape[1] // frames_multiple * frames_multiple
18
- mel2token = mel2token[:, :max_frames]
19
  return mel2token
20
 
21
 
 
13
 
14
 
15
  def clip_mel2token_to_multiple(mel2token, frames_multiple):
16
+ max_frames = mel2token.shape[1] // frames_multiple * frames_multiple
17
+ mel2token = mel2token[:, :max_frames]
 
18
  return mel2token
19
 
20
 
modules/tts/diffspeech/net.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from math import sqrt
8
+
9
+ Linear = nn.Linear
10
+ ConvTranspose2d = nn.ConvTranspose2d
11
+
12
+
13
+ class Mish(nn.Module):
14
+ def forward(self, x):
15
+ return x * torch.tanh(F.softplus(x))
16
+
17
+
18
+ class SinusoidalPosEmb(nn.Module):
19
+ def __init__(self, dim):
20
+ super().__init__()
21
+ self.dim = dim
22
+
23
+ def forward(self, x):
24
+ device = x.device
25
+ half_dim = self.dim // 2
26
+ emb = math.log(10000) / (half_dim - 1)
27
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
28
+ emb = x[:, None] * emb[None, :]
29
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
30
+ return emb
31
+
32
+
33
+ def Conv1d(*args, **kwargs):
34
+ layer = nn.Conv1d(*args, **kwargs)
35
+ nn.init.kaiming_normal_(layer.weight)
36
+ return layer
37
+
38
+
39
+ class ResidualBlock(nn.Module):
40
+ def __init__(self, encoder_hidden, residual_channels, dilation):
41
+ super().__init__()
42
+ self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation)
43
+ self.diffusion_projection = Linear(residual_channels, residual_channels)
44
+ self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1)
45
+ self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)
46
+
47
+ def forward(self, x, conditioner, diffusion_step):
48
+ diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
49
+ conditioner = self.conditioner_projection(conditioner)
50
+ y = x + diffusion_step
51
+
52
+ y = self.dilated_conv(y) + conditioner
53
+
54
+ gate, filter = torch.chunk(y, 2, dim=1)
55
+ y = torch.sigmoid(gate) * torch.tanh(filter)
56
+
57
+ y = self.output_projection(y)
58
+ residual, skip = torch.chunk(y, 2, dim=1)
59
+ return (x + residual) / sqrt(2.0), skip
60
+
61
+
62
+ class DiffNet(nn.Module):
63
+ def __init__(self, hparams):
64
+ super().__init__()
65
+ in_dims = hparams['audio_num_mel_bins']
66
+ self.encoder_hidden = hparams['hidden_size']
67
+ self.residual_layers = hparams['residual_layers']
68
+ self.residual_channels = hparams['residual_channels']
69
+ self.dilation_cycle_length = hparams['dilation_cycle_length']
70
+
71
+ self.input_projection = Conv1d(in_dims, self.residual_channels, 1)
72
+ self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels)
73
+ dim = self.residual_channels
74
+ self.mlp = nn.Sequential(
75
+ nn.Linear(dim, dim * 4),
76
+ Mish(),
77
+ nn.Linear(dim * 4, dim)
78
+ )
79
+ self.residual_layers = nn.ModuleList([
80
+ ResidualBlock(self.encoder_hidden, self.residual_channels, 2 ** (i % self.dilation_cycle_length))
81
+ for i in range(self.residual_layers)
82
+ ])
83
+ self.skip_projection = Conv1d(self.residual_channels, self.residual_channels, 1)
84
+ self.output_projection = Conv1d(self.residual_channels, in_dims, 1)
85
+ nn.init.zeros_(self.output_projection.weight)
86
+
87
+ def forward(self, spec, diffusion_step, cond):
88
+ """
89
+
90
+ :param spec: [B, 1, M, T]
91
+ :param diffusion_step: [B, 1]
92
+ :param cond: [B, M, T]
93
+ :return:
94
+ """
95
+ x = spec[:, 0]
96
+ x = self.input_projection(x) # x [B, residual_channel, T]
97
+
98
+ x = F.relu(x)
99
+ diffusion_step = self.diffusion_embedding(diffusion_step)
100
+ diffusion_step = self.mlp(diffusion_step)
101
+ skip = []
102
+ for layer_id, layer in enumerate(self.residual_layers):
103
+ x, skip_connection = layer(x, cond, diffusion_step)
104
+ skip.append(skip_connection)
105
+
106
+ x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
107
+ x = self.skip_projection(x)
108
+ x = F.relu(x)
109
+ x = self.output_projection(x) # [B, 80, T]
110
+ return x[:, None, :, :]
modules/tts/diffspeech/shallow_diffusion_tts.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ from functools import partial
4
+ from inspect import isfunction
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ from tqdm import tqdm
10
+
11
+ from modules.tts.fs2_orig import FastSpeech2Orig
12
+ from modules.tts.diffspeech.net import DiffNet
13
+ from modules.tts.commons.align_ops import expand_states
14
+
15
+
16
+ def exists(x):
17
+ return x is not None
18
+
19
+
20
+ def default(val, d):
21
+ if exists(val):
22
+ return val
23
+ return d() if isfunction(d) else d
24
+
25
+
26
+ # gaussian diffusion trainer class
27
+
28
+ def extract(a, t, x_shape):
29
+ b, *_ = t.shape
30
+ out = a.gather(-1, t)
31
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
32
+
33
+
34
+ def noise_like(shape, device, repeat=False):
35
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
36
+ noise = lambda: torch.randn(shape, device=device)
37
+ return repeat_noise() if repeat else noise()
38
+
39
+
40
+ def linear_beta_schedule(timesteps, max_beta=0.01):
41
+ """
42
+ linear schedule
43
+ """
44
+ betas = np.linspace(1e-4, max_beta, timesteps)
45
+ return betas
46
+
47
+
48
+ def cosine_beta_schedule(timesteps, s=0.008):
49
+ """
50
+ cosine schedule
51
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
52
+ """
53
+ steps = timesteps + 1
54
+ x = np.linspace(0, steps, steps)
55
+ alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
56
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
57
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
58
+ return np.clip(betas, a_min=0, a_max=0.999)
59
+
60
+
61
+ beta_schedule = {
62
+ "cosine": cosine_beta_schedule,
63
+ "linear": linear_beta_schedule,
64
+ }
65
+
66
+
67
+ DIFF_DECODERS = {
68
+ 'wavenet': lambda hp: DiffNet(hp),
69
+ }
70
+
71
+
72
+ class AuxModel(FastSpeech2Orig):
73
+ def forward(self, txt_tokens, mel2ph=None, spk_embed=None, spk_id=None,
74
+ f0=None, uv=None, energy=None, infer=False, **kwargs):
75
+ ret = {}
76
+ encoder_out = self.encoder(txt_tokens) # [B, T, C]
77
+ src_nonpadding = (txt_tokens > 0).float()[:, :, None]
78
+ style_embed = self.forward_style_embed(spk_embed, spk_id)
79
+
80
+ # add dur
81
+ dur_inp = (encoder_out + style_embed) * src_nonpadding
82
+ mel2ph = self.forward_dur(dur_inp, mel2ph, txt_tokens, ret)
83
+ tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
84
+ decoder_inp = decoder_inp_ = expand_states(encoder_out, mel2ph)
85
+
86
+ # add pitch and energy embed
87
+ if self.hparams['use_pitch_embed']:
88
+ pitch_inp = (decoder_inp_ + style_embed) * tgt_nonpadding
89
+ decoder_inp = decoder_inp + self.forward_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out)
90
+
91
+ # add pitch and energy embed
92
+ if self.hparams['use_energy_embed']:
93
+ energy_inp = (decoder_inp_ + style_embed) * tgt_nonpadding
94
+ decoder_inp = decoder_inp + self.forward_energy(energy_inp, energy, ret)
95
+
96
+ # decoder input
97
+ ret['decoder_inp'] = decoder_inp = (decoder_inp + style_embed) * tgt_nonpadding
98
+ if self.hparams['dec_inp_add_noise']:
99
+ B, T, _ = decoder_inp.shape
100
+ z = kwargs.get('adv_z', torch.randn([B, T, self.z_channels])).to(decoder_inp.device)
101
+ ret['adv_z'] = z
102
+ decoder_inp = torch.cat([decoder_inp, z], -1)
103
+ decoder_inp = self.dec_inp_noise_proj(decoder_inp) * tgt_nonpadding
104
+ if kwargs['skip_decoder']:
105
+ return ret
106
+ ret['mel_out'] = self.forward_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
107
+ return ret
108
+
109
+
110
+ class GaussianDiffusion(nn.Module):
111
+ def __init__(self, dict_size, hparams, out_dims=None):
112
+ super().__init__()
113
+ self.hparams = hparams
114
+ out_dims = hparams['audio_num_mel_bins']
115
+ denoise_fn = DIFF_DECODERS[hparams['diff_decoder_type']](hparams)
116
+ timesteps = hparams['timesteps']
117
+ K_step = hparams['K_step']
118
+ loss_type = hparams['diff_loss_type']
119
+ spec_min = hparams['spec_min']
120
+ spec_max = hparams['spec_max']
121
+
122
+ self.denoise_fn = denoise_fn
123
+ self.fs2 = AuxModel(dict_size, hparams)
124
+ self.mel_bins = out_dims
125
+
126
+ if hparams['schedule_type'] == 'linear':
127
+ betas = linear_beta_schedule(timesteps, hparams['max_beta'])
128
+ else:
129
+ betas = cosine_beta_schedule(timesteps)
130
+
131
+ alphas = 1. - betas
132
+ alphas_cumprod = np.cumprod(alphas, axis=0)
133
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
134
+
135
+ timesteps, = betas.shape
136
+ self.num_timesteps = int(timesteps)
137
+ self.K_step = K_step
138
+ self.loss_type = loss_type
139
+
140
+ to_torch = partial(torch.tensor, dtype=torch.float32)
141
+
142
+ self.register_buffer('betas', to_torch(betas))
143
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
144
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
145
+
146
+ # calculations for diffusion q(x_t | x_{t-1}) and others
147
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
148
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
149
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
150
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
151
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
152
+
153
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
154
+ posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
155
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
156
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
157
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
158
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
159
+ self.register_buffer('posterior_mean_coef1', to_torch(
160
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
161
+ self.register_buffer('posterior_mean_coef2', to_torch(
162
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
163
+
164
+ self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
165
+ self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])
166
+
167
+ def q_mean_variance(self, x_start, t):
168
+ mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
169
+ variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
170
+ log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
171
+ return mean, variance, log_variance
172
+
173
+ def predict_start_from_noise(self, x_t, t, noise):
174
+ return (
175
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
176
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
177
+ )
178
+
179
+ def q_posterior(self, x_start, x_t, t):
180
+ posterior_mean = (
181
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
182
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
183
+ )
184
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
185
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
186
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
187
+
188
+ def p_mean_variance(self, x, t, cond, clip_denoised: bool):
189
+ noise_pred = self.denoise_fn(x, t, cond=cond)
190
+ x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
191
+
192
+ if clip_denoised:
193
+ x_recon.clamp_(-1., 1.)
194
+
195
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
196
+ return model_mean, posterior_variance, posterior_log_variance
197
+
198
+ @torch.no_grad()
199
+ def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
200
+ b, *_, device = *x.shape, x.device
201
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
202
+ noise = noise_like(x.shape, device, repeat_noise)
203
+ # no noise when t == 0
204
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
205
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
206
+
207
+ def q_sample(self, x_start, t, noise=None):
208
+ noise = default(noise, lambda: torch.randn_like(x_start))
209
+ return (
210
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
211
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
212
+ )
213
+
214
+ def p_losses(self, x_start, t, cond, noise=None, nonpadding=None):
215
+ noise = default(noise, lambda: torch.randn_like(x_start))
216
+
217
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
218
+ x_recon = self.denoise_fn(x_noisy, t, cond)
219
+
220
+ if self.loss_type == 'l1':
221
+ if nonpadding is not None:
222
+ loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean()
223
+ else:
224
+ # print('are you sure w/o nonpadding?')
225
+ loss = (noise - x_recon).abs().mean()
226
+
227
+ elif self.loss_type == 'l2':
228
+ loss = F.mse_loss(noise, x_recon)
229
+ else:
230
+ raise NotImplementedError()
231
+
232
+ return loss
233
+
234
+ def forward(self, txt_tokens, mel2ph=None, spk_embed=None, spk_id=None,
235
+ ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
236
+ b, *_, device = *txt_tokens.shape, txt_tokens.device
237
+ ret = self.fs2(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
238
+ f0=f0, uv=uv, energy=energy, infer=infer, skip_decoder=(not infer), **kwargs)
239
+ # (txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
240
+ # skip_decoder=(not infer), infer=infer, **kwargs)
241
+ cond = ret['decoder_inp'].transpose(1, 2)
242
+
243
+ if not infer:
244
+ t = torch.randint(0, self.K_step, (b,), device=device).long()
245
+ x = ref_mels
246
+ x = self.norm_spec(x)
247
+ x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
248
+ ret['diff_loss'] = self.p_losses(x, t, cond)
249
+ # nonpadding = (mel2ph != 0).float()
250
+ # ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding)
251
+ ret['mel_out'] = None
252
+ else:
253
+ ret['fs2_mel'] = ret['mel_out']
254
+ fs2_mels = ret['mel_out']
255
+ t = self.K_step
256
+ fs2_mels = self.norm_spec(fs2_mels)
257
+ fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
258
+
259
+ x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
260
+ if self.hparams.get('gaussian_start') is not None and self.hparams['gaussian_start']:
261
+ print('===> gaussian start.')
262
+ shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
263
+ x = torch.randn(shape, device=device)
264
+ for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
265
+ x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
266
+ x = x[:, 0].transpose(1, 2)
267
+ ret['mel_out'] = self.denorm_spec(x)
268
+
269
+ return ret
270
+
271
+ def norm_spec(self, x):
272
+ return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
273
+
274
+ def denorm_spec(self, x):
275
+ return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
276
+
277
+ def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
278
+ return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
279
+
280
+ def out2mel(self, x):
281
+ return x
tasks/tts/diffspeech.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from modules.tts.diffspeech.shallow_diffusion_tts import GaussianDiffusion
4
+ from tasks.tts.fs2_orig import FastSpeech2OrigTask
5
+
6
+ import utils
7
+ from utils.commons.hparams import hparams
8
+ from utils.commons.ckpt_utils import load_ckpt
9
+ from utils.audio.pitch.utils import denorm_f0
10
+
11
+
12
+ class DiffSpeechTask(FastSpeech2OrigTask):
13
+ def build_tts_model(self):
14
+ # get min and max
15
+ # import torch
16
+ # from tqdm import tqdm
17
+ # v_min = torch.ones([80]) * 100
18
+ # v_max = torch.ones([80]) * -100
19
+ # for i, ds in enumerate(tqdm(self.dataset_cls('train'))):
20
+ # v_max = torch.max(torch.max(ds['mel'].reshape(-1, 80), 0)[0], v_max)
21
+ # v_min = torch.min(torch.min(ds['mel'].reshape(-1, 80), 0)[0], v_min)
22
+ # if i % 100 == 0:
23
+ # print(i, v_min, v_max)
24
+ # print('final', v_min, v_max)
25
+ dict_size = len(self.token_encoder)
26
+ self.model = GaussianDiffusion(dict_size, hparams)
27
+ if hparams['fs2_ckpt'] != '':
28
+ load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True)
29
+ for k, v in self.model.fs2.named_parameters():
30
+ if 'predictor' not in k:
31
+ v.requires_grad = False
32
+ # or
33
+ # for k, v in self.model.fs2.named_parameters():
34
+ # v.requires_grad = False
35
+
36
+ def build_optimizer(self, model):
37
+ self.optimizer = optimizer = torch.optim.AdamW(
38
+ filter(lambda p: p.requires_grad, model.parameters()),
39
+ lr=hparams['lr'],
40
+ betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
41
+ weight_decay=hparams['weight_decay'])
42
+ return optimizer
43
+
44
+ def build_scheduler(self, optimizer):
45
+ return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)
46
+
47
+ def run_model(self, sample, infer=False, *args, **kwargs):
48
+ txt_tokens = sample['txt_tokens'] # [B, T_t]
49
+ spk_embed = sample.get('spk_embed')
50
+ spk_id = sample.get('spk_ids')
51
+ if not infer:
52
+ target = sample['mels'] # [B, T_s, 80]
53
+ mel2ph = sample['mel2ph'] # [B, T_s]
54
+ f0 = sample.get('f0')
55
+ uv = sample.get('uv')
56
+ output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
57
+ ref_mels=target, f0=f0, uv=uv, infer=False)
58
+ losses = {}
59
+ if 'diff_loss' in output:
60
+ losses['mel'] = output['diff_loss']
61
+ self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
62
+ if hparams['use_pitch_embed']:
63
+ self.add_pitch_loss(output, sample, losses)
64
+ return losses, output
65
+ else:
66
+ use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
67
+ use_gt_f0 = kwargs.get('infer_use_gt_f0', hparams['use_gt_f0'])
68
+ mel2ph, uv, f0 = None, None, None
69
+ if use_gt_dur:
70
+ mel2ph = sample['mel2ph']
71
+ if use_gt_f0:
72
+ f0 = sample['f0']
73
+ uv = sample['uv']
74
+ output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
75
+ ref_mels=None, f0=f0, uv=uv, infer=True)
76
+ return output
77
+
78
+ def save_valid_result(self, sample, batch_idx, model_out):
79
+ sr = hparams['audio_sample_rate']
80
+ f0_gt = None
81
+ # mel_out = model_out['mel_out']
82
+ if sample.get('f0') is not None:
83
+ f0_gt = denorm_f0(sample['f0'][0].cpu(), sample['uv'][0].cpu())
84
+ # self.plot_mel(batch_idx, sample['mels'], mel_out, f0s=f0_gt)
85
+ if self.global_step > 0:
86
+ # wav_pred = self.vocoder.spec2wav(mel_out[0].cpu(), f0=f0_gt)
87
+ # self.logger.add_audio(f'wav_val_{batch_idx}', wav_pred, self.global_step, sr)
88
+ # with gt duration
89
+ model_out = self.run_model(sample, infer=True, infer_use_gt_dur=True)
90
+ dur_info = self.get_plot_dur_info(sample, model_out)
91
+ del dur_info['dur_pred']
92
+ wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu(), f0=f0_gt)
93
+ self.logger.add_audio(f'wav_gdur_{batch_idx}', wav_pred, self.global_step, sr)
94
+ self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'][0], f'diffmel_gdur_{batch_idx}',
95
+ dur_info=dur_info, f0s=f0_gt)
96
+ self.plot_mel(batch_idx, sample['mels'], model_out['fs2_mel'][0], f'fs2mel_gdur_{batch_idx}',
97
+ dur_info=dur_info, f0s=f0_gt) # gt mel vs. fs2 mel
98
+
99
+ # with pred duration
100
+ if not hparams['use_gt_dur']:
101
+ model_out = self.run_model(sample, infer=True, infer_use_gt_dur=False)
102
+ dur_info = self.get_plot_dur_info(sample, model_out)
103
+ self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'][0], f'mel_pdur_{batch_idx}',
104
+ dur_info=dur_info, f0s=f0_gt)
105
+ wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu(), f0=f0_gt)
106
+ self.logger.add_audio(f'wav_pdur_{batch_idx}', wav_pred, self.global_step, sr)
107
+ # gt wav
108
+ if self.global_step <= hparams['valid_infer_interval']:
109
+ mel_gt = sample['mels'][0].cpu()
110
+ wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0_gt)
111
+ self.logger.add_audio(f'wav_gt_{batch_idx}', wav_gt, self.global_step, sr)