Sang-Hoon Lee committed • Commit 0164e4a
1 Parent(s): 8703869
Upload 70 files
Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +6 -0
- LICENSE +21 -0
- Mels_preprocess.py +42 -0
- README.md +269 -13
- activations.py +120 -0
- alias_free_torch/__init__.py +6 -0
- alias_free_torch/act.py +28 -0
- alias_free_torch/filter.py +95 -0
- alias_free_torch/resample.py +49 -0
- app.py.py +236 -0
- attentions.py +313 -0
- commons.py +168 -0
- denoiser/config.json +28 -0
- denoiser/conformer.py +86 -0
- denoiser/g_best +3 -0
- denoiser/generator.py +193 -0
- denoiser/infer.py +33 -0
- denoiser/utils.py +55 -0
- example/reference_1.txt +1 -0
- example/reference_1.wav +0 -0
- example/reference_2.txt +1 -0
- example/reference_2.wav +0 -0
- example/reference_3.txt +1 -0
- example/reference_3.wav +0 -0
- example/reference_4.txt +1 -0
- example/reference_4.wav +0 -0
- hierspeechpp_speechsynthesizer.py +716 -0
- inference.py +220 -0
- inference.sh +12 -0
- inference_speechsr.py +94 -0
- inference_vc.py +250 -0
- inference_vc.sh +11 -0
- logs/hierspeechpp_eng_kor/config.json +67 -0
- logs/hierspeechpp_eng_kor/hierspeechpp_v1_ckpt.pth +3 -0
- logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth +3 -0
- logs/ttv_libritts_v1/config.json +55 -0
- logs/ttv_libritts_v1/ttv_lt960_ckpt.pth +3 -0
- modules.py +534 -0
- requirements.txt +15 -0
- results/reference_1.wav +0 -0
- results/reference_2.wav +0 -0
- results/reference_3.wav +0 -0
- results/reference_4.wav +0 -0
- speechsr24k/G_340000.pth +3 -0
- speechsr24k/config.json +49 -0
- speechsr24k/speechsr.py +253 -0
- speechsr48k/G_100000.pth +3 -0
- speechsr48k/config.json +48 -0
- speechsr48k/speechsr.py +252 -0
- styleencoder.py +91 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+denoiser/g_best filter=lfs diff=lfs merge=lfs -text
+ttv_v1/monotonic_align/build/temp.linux-x86_64-3.7/core.o filter=lfs diff=lfs merge=lfs -text
+ttv_v1/monotonic_align/build/temp.linux-x86_64-3.8/core.o filter=lfs diff=lfs merge=lfs -text
+ttv_v1/monotonic_align/build/temp.linux-x86_64-3.9/core.o filter=lfs diff=lfs merge=lfs -text
+ttv_v1/monotonic_align/monotonic_align/core.cpython-38-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ttv_v1/monotonic_align/monotonic_align/core.cpython-39-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Sang-Hoon Lee

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Mels_preprocess.py
ADDED
@@ -0,0 +1,42 @@
import numpy as np

np.random.seed(1234)

import torch
from torchaudio.transforms import MelSpectrogram, Spectrogram, MelScale

class MelSpectrogramFixed(torch.nn.Module):
    """In order to remove padding of torchaudio package + add log scale."""

    def __init__(self, **kwargs):
        super(MelSpectrogramFixed, self).__init__()
        self.torchaudio_backend = MelSpectrogram(**kwargs)

    def forward(self, x):
        outputs = torch.log(self.torchaudio_backend(x) + 0.001)

        return outputs[..., :-1]

class SpectrogramFixed(torch.nn.Module):
    """In order to remove padding of torchaudio package + add log10 scale."""

    def __init__(self, **kwargs):
        super(SpectrogramFixed, self).__init__()
        self.torchaudio_backend = Spectrogram(**kwargs)

    def forward(self, x):
        outputs = self.torchaudio_backend(x)

        return outputs[..., :-1]

class MelfilterFixed(torch.nn.Module):
    """In order to remove padding of torchaudio package + add log10 scale."""

    def __init__(self, **kwargs):
        super(MelfilterFixed, self).__init__()
        self.torchaudio_backend = MelScale(**kwargs)

    def forward(self, x):
        outputs = torch.log(self.torchaudio_backend(x) + 0.001)

        return outputs
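As a quick orientation for the helpers in Mels_preprocess.py above, here is a minimal usage sketch. The mel parameters below are illustrative assumptions; the actual pipeline (see app.py.py later in this diff) reads them from the model's config.json rather than hard-coding them.

```python
# Minimal sketch: computing a log-mel spectrogram with MelSpectrogramFixed.
# Parameter values are illustrative assumptions, not taken from this commit's configs.
import torch
from Mels_preprocess import MelSpectrogramFixed

mel_fn = MelSpectrogramFixed(
    sample_rate=16000,
    n_fft=1280,
    win_length=1280,
    hop_length=320,
    f_min=0,
    f_max=8000,
    n_mels=80,
    window_fn=torch.hann_window,
)

wav = torch.randn(1, 16000)   # one second of dummy 16 kHz audio, shape [B, T]
mel = mel_fn(wav)             # log-scaled mel, shape [B, n_mels, frames]
print(mel.shape)
```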
README.md
CHANGED
@@ -1,13 +1,269 @@
# HierSpeech++: Bridging the Gap between Semantic and Acoustic Representation by Hierarchical Variational Inference for Zero-shot Speech Synthesis
The official implementation of HierSpeech++ | [Paper]() | [Demo page](https://sh-lee-prml.github.io/HierSpeechpp-demo/) | [Checkpoint](https://drive.google.com/drive/folders/1-L_90BlCkbPyKWWHTUjt5Fsu3kz0du0w?usp=sharing) |

**Sang-Hoon Lee, Ha-Yeong Choi, Seung-Bin Kim, Seong-Whan Lee**

Department of Artificial Intelligence, Korea University, Seoul, Korea

## Abstract
![image](https://github.com/sh-lee-prml/HierSpeechpp/assets/56749640/732bc183-bf11-4f32-84a9-e9eab8190e1a)
<details>
<summary> [Abs.] Sorry for the too-long abstract😅 </summary>

Recently, large language model (LLM)-based speech synthesis has shown significant performance in zero-shot speech synthesis. However, such models require large-scale data and suffer from the same limitations as previous autoregressive speech models, such as slow inference speed and lack of robustness. Following VITS, the previous powerful end-to-end text-to-speech framework (by now a classic), this paper proposes HierSpeech++, a fast and strong zero-shot speech synthesizer for text-to-speech (TTS) and voice conversion (VC). In our previous works (HierSpeech and HierVST), we verified that hierarchical speech synthesis frameworks can significantly improve the robustness and expressiveness of synthetic speech by adopting a hierarchical variational autoencoder and leveraging a self-supervised speech representation as additional linguistic information to bridge the information gap between text and speech. In this work, we once again significantly improve the naturalness and speaker similarity of synthetic speech, even in zero-shot scenarios. We first introduce a multi-audio acoustic encoder for an enhanced acoustic posterior, and adopt a hierarchical adaptive waveform generator with conditional/unconditional generation. Second, we additionally utilize F0 information and introduce a source-filter theory-based multi-path semantic encoder for speaker-agnostic and speaker-related semantic representations. We also leverage a hierarchical variational autoencoder to connect these representations, and present BiT-Flow, a bidirectional normalizing flow Transformer network with AdaLN-Zero, for better speaker adaptation and reduced train-inference mismatch. Without any text transcripts, we only utilize speech data to train the speech synthesizer, for data flexibility. For text-to-speech, we introduce a text-to-vec (TTV) framework that generates a self-supervised speech representation and an F0 representation from a text representation and a prosody prompt. Then, the speech synthesizer of HierSpeech++ generates speech from the generated vector, F0, and a voice prompt. In addition, we propose a highly efficient speech super-resolution framework that upsamples waveform audio from 16 kHz to 48 kHz; this facilitates training the speech synthesizer because we can use easily available low-resolution (16 kHz) speech data for scaling up. The experimental results demonstrate that a hierarchical variational autoencoder can be a strong zero-shot speech synthesizer, beating LLM-based and diffusion-based models on TTS and VC tasks. Furthermore, we verify the data efficiency of our approach: our model trained with a small dataset still shows better naturalness and similarity than other models trained with large-scale datasets. Moreover, we achieve the first human-level quality in zero-shot speech synthesis.
</details>

This repository contains:

- 🪐 A PyTorch implementation of HierSpeech++ (TTV, Hierarchical Speech Synthesizer, SpeechSR)
- ⚡️ Pre-trained HierSpeech++ models trained on LibriTTS (Train-460, Train-960, and additional datasets)

<!--
- 💥 A Colab notebook for running pre-trained HierSpeech++ models (Soon..)
🛸 A HierSpeech++ training script (Will be released soon)
-->
## Our Previous Works
- [1] HierSpeech: Bridging the Gap between Text and Speech by Hierarchical Variational Inference using Self-supervised Representations for Speech Synthesis
- [2] HierVST: Hierarchical Adaptive Zero-shot Voice Style Transfer

This paper is an extended version of the above papers.

## Todo
### Hierarchical Speech Synthesizer
- [x] HierSpeechpp-Backbone
<!--
- [ ] HierSpeech-Lite (Fast and Efficient Zero-shot Speech Synthesizer)
- [ ] HierSinger (Zero-shot Singing Voice Synthesizer)
- [ ] HierSpeech2-24k-Large-Full (For High-resolution and High-quality Speech Synthesis)
- [ ] HierSpeech2-48k-Large-Full (For Industrial-level High-resolution and High-quality Speech Synthesis)
-->
### Text-to-Vec (TTV)
- [x] TTV-v1 (LibriTTS-train-960)
- [ ] TTV-v2 (We are currently training a multi-lingual TTV model)
<!--
- [ ] Hierarchical Text-to-Vec (For Much More Expressive Text-to-Speech)
-->
### Speech Super-resolution (16k --> 24k or 48k)
- [x] SpeechSR-24k
- [x] SpeechSR-48k
### Training code (Will be released after paper acceptance)
- [ ] TTV
- [ ] Hierarchical Speech Synthesizer
- [ ] SpeechSR
## Getting Started

### Pre-requisites
0. PyTorch >= 1.13 and torchaudio >= 0.13
1. Install requirements
```
pip install -r requirements.txt
```
2. Install Phonemizer
```
pip install phonemizer
sudo apt-get install espeak-ng
```

## Checkpoint [[Download]](https://drive.google.com/drive/folders/1-L_90BlCkbPyKWWHTUjt5Fsu3kz0du0w?usp=sharing)
### Hierarchical Speech Synthesizer
| Model |Sampling Rate|Params|Dataset|Hours|Speakers|Checkpoint|
|------|:---:|:---:|:---:|:---:|:---:|:---:|
| HierSpeech2|16 kHz|97M| LibriTTS (train-460) |245|1,151|[[Download]](https://drive.google.com/drive/folders/14FTu0ZWux0zAD7ev4O1l6lKslQcdmebL?usp=sharing)|
| HierSpeech2|16 kHz|97M| LibriTTS (train-960) |555|2,311|[[Download]](https://drive.google.com/drive/folders/1sFQP-8iS8z9ofCkE7szXNM_JEy4nKg41?usp=drive_link)|
| HierSpeech2|16 kHz|97M| LibriTTS (train-960), Libri-light (Small, Medium), Expresso, MMS (Kor), NIKL (Kor)|2,796| 7,299 |[[Download]](https://drive.google.com/drive/folders/14jaDUBgrjVA7bCODJqAEirDwRlvJe272?usp=drive_link)|

<!--
| HierSpeech2-Lite|16 kHz|-| LibriTTS (train-960) |-|
| HierSpeech2-Lite|16 kHz|-| LibriTTS (train-960), NIKL, AudioBook-Korean |-|
| HierSpeech2-Large-CL|16 kHz|200M| LibriTTS (train-960), Libri-Light, NIKL, AudioBook-Korean, Japanese, Chinese, CSS, MLS |-|
-->

### TTV
| Model |Language|Params|Dataset|Hours|Speakers|Checkpoint|
|------|:---:|:---:|:---:|:---:|:---:|:---:|
| TTV |Eng|107M| LibriTTS (train-960) |555|2,311|[[Download]](https://drive.google.com/drive/folders/1QiFFdPhqhiLFo8VXc0x7cFHKXArx7Xza?usp=drive_link)|

<!--
| TTV |Kor|100M| NIKL |114|118|-|
| TTV |Eng|50M| LibriTTS (train-960) |555|2,311|-|
| TTV-Large |Eng|100M| LibriTTS (train-960) |555|2,311|-|
| TTV-Lite |Eng|10M| LibriTTS (train-960) |555|2,311|-|
| TTV |Kor|50M| NIKL |114|118|-|
-->
### SpeechSR
| Model |Sampling Rate|Params|Dataset |Checkpoint|
|------|:---:|:---:|:---:|:---:|
| SpeechSR-24k |16 kHz --> 24 kHz|0.13M| LibriTTS (train-960), MMS (Kor) |[speechsr24k](https://github.com/sh-lee-prml/HierSpeechpp/blob/main/speechsr24k/G_340000.pth)|
| SpeechSR-48k |16 kHz --> 48 kHz|0.13M| MMS (Kor), Expresso (Eng), VCTK (Eng)|[speechsr48k](https://github.com/sh-lee-prml/HierSpeechpp/blob/main/speechsr48k/G_100000.pth)|

## Text-to-Speech
```
sh inference.sh

# --ckpt "logs/hierspeechpp_libritts460/hierspeechpp_lt460_ckpt.pth" \ LibriTTS-460
# --ckpt "logs/hierspeechpp_libritts960/hierspeechpp_lt960_ckpt.pth" \ LibriTTS-960
# --ckpt "logs/hierspeechpp_eng_kor/hierspeechpp_v1_ckpt.pth" \ Large_v1 epoch 60 (paper version)
# --ckpt "logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth" \ Large_v2 epoch 110 (08. Nov. 2023)

CUDA_VISIBLE_DEVICES=0 python3 inference.py \
    --ckpt "logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth" \
    --ckpt_text2w2v "logs/ttv_libritts_v1/ttv_lt960_ckpt.pth" \
    --output_dir "tts_results_eng_kor_v2" \
    --noise_scale_vc "0.333" \
    --noise_scale_ttv "0.333" \
    --denoise_ratio "0"

```
- For better robustness, we recommend a noise_scale of 0.333
- For better expressiveness, we recommend a noise_scale of 0.667
- Find the best parameters for your style prompt 😵
### Noise Control
```
# without denoiser
--denoise_ratio "0"
# with denoiser
--denoise_ratio "1"
# Mixup (recommended: 0.6~0.8)
--denoise_ratio "0.8"
```
## Voice Conversion
- This method only utilizes the hierarchical speech synthesizer for voice conversion.
```
sh inference_vc.sh

# --ckpt "logs/hierspeechpp_libritts460/hierspeechpp_lt460_ckpt.pth" \ LibriTTS-460
# --ckpt "logs/hierspeechpp_libritts960/hierspeechpp_lt960_ckpt.pth" \ LibriTTS-960
# --ckpt "logs/hierspeechpp_eng_kor/hierspeechpp_v1_ckpt.pth" \ Large_v1 epoch 60 (paper version)
# --ckpt "logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth" \ Large_v2 epoch 110 (08. Nov. 2023)

CUDA_VISIBLE_DEVICES=0 python3 inference_vc.py \
    --ckpt "logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth" \
    --output_dir "vc_results_eng_kor_v2" \
    --noise_scale_vc "0.333" \
    --noise_scale_ttv "0.333" \
    --denoise_ratio "0"
```
- For better robustness, we recommend a noise_scale of 0.333
- For better expressiveness, we recommend a noise_scale of 0.667
- Find the best parameters for your style prompt 😵
- Voice conversion is vulnerable to a noisy target prompt, so we recommend using the denoiser with noisy prompts.
- For noisy source speech, an incorrect F0 may be extracted by YAAPT, resulting in quality degradation.

## Speech Super-resolution
- SpeechSR-24k and SpeechSR-48k are provided in the TTS pipeline. If you want to use SpeechSR only, please refer to inference_speechsr.py.
- To change the output resolution, add one of the following:
```
--output_sr "48000" # Default
--output_sr "24000" #
--output_sr "16000" # without super-resolution.
```
## Speech Denoising for Noise-free Speech Synthesis (Only used in the Speaker Encoder during Inference)
- For a denoised style prompt, we utilize the denoiser [(MP-SENet)](https://github.com/yxlu-0102/MP-SENet).
- When using a long reference audio, this model can run out of memory, so we plan to train a memory-efficient speech denoiser in the future.
- If you run into this problem, we recommend using a clean reference audio, denoising the audio before the TTS pipeline, or denoising it on the CPU (but this will be slow 😥), as sketched below.

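Below is a minimal sketch of denoising a reference prompt on the CPU before running the TTS pipeline. It reuses `MPNet` and the `denoise` helper shipped in this repo and pads the prompt to a multiple of 1600 samples as app.py does; whether `denoise` runs unmodified with CPU tensors is our assumption (it may contain hard-coded `.cuda()` calls that need removing).

```python
# Sketch: CPU denoising of a style prompt before TTS.
# Assumptions: denoiser/config.json and denoiser/g_best are present, and
# denoiser.infer.denoise works with CPU tensors.
import torch
import torchaudio
import utils
from denoiser.generator import MPNet
from denoiser.infer import denoise

hps_denoiser = utils.get_hparams_from_file('denoiser/config.json')
denoiser = MPNet(hps_denoiser)                       # keep the model on the CPU
state = torch.load('denoiser/g_best', map_location='cpu')
denoiser.load_state_dict(state['generator'])
denoiser.eval()

audio, sr = torchaudio.load('example/reference_1.wav')
if sr != 16000:
    audio = torchaudio.functional.resample(audio, sr, 16000, resampling_method="kaiser_window")

# Pad to a multiple of 1600 samples (synthesizer hop 320 x denoiser hop 400), as app.py does.
orig_len = audio.shape[-1]
pad = (orig_len // 1600 + 1) * 1600 - orig_len
audio = torch.nn.functional.pad(audio, (0, pad), mode='constant')

with torch.no_grad():
    denoised = denoise(audio.squeeze(0), denoiser, hps_denoiser)

# Drop the padding again and save the denoised prompt for the TTS pipeline.
torchaudio.save('reference_1_denoised.wav', denoised[:, :orig_len].cpu(), 16000)
```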
## TTV-v2
- TTV-v1 is a simple model, only slightly modified from VITS. Although this simple TTV can synthesize speech with high quality and high speaker similarity, we think there is room for improvement in expressiveness, such as prosody modeling.
- For TTV-v2, we modify some components and the training process (model size: 107M --> 278M):
1. Intermediate hidden size: 256 --> 384
2. Loss masking for the wav2vec reconstruction loss (in v1, I left out masking the loss over zero-padded sequences 😥); see the sketch after this list.
3. For long-sentence generation, we fine-tune the model on the full LibriTTS train set without data filtering (learning rate decreased to 2e-5 with a batch size of 8 per GPU).

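A minimal sketch of the loss masking mentioned in item 2 above, under our own naming assumptions (`pred`/`target` are [B, C, T] wav2vec features, `lengths` holds the number of valid frames per utterance); this is illustrative and not the repo's training code.

```python
# Illustrative sketch: exclude zero-padded frames from the wav2vec reconstruction loss.
import torch

def masked_w2v_loss(pred, target, lengths):
    # pred, target: [B, C, T]; lengths: [B] valid frame counts per utterance.
    t = pred.size(-1)
    mask = torch.arange(t, device=pred.device)[None, :] < lengths[:, None]   # [B, T]
    mask = mask.unsqueeze(1).to(pred.dtype)                                  # [B, 1, T]
    diff = (pred - target).abs() * mask
    # Average only over valid (unpadded) positions.
    return diff.sum() / (mask.sum() * pred.size(1))
```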
## GAN VS Diffusion
<details>
<summary> [Read More] </summary>
We do not think it can be confirmed yet which is better. Each type of model has many advantages, so you can utilize either for your own purposes, and both lines of research should be actively pursued simultaneously.

### GAN (Specifically, GAN-based End-to-End Speech Synthesis Models)
- (pros) Fast Inference Speed
- (pros) High-quality Audio
- (cons) Slow Training Speed (Over 7~20 Days)
- (cons) Lower Voice Style Transfer Performance than Diffusion Models
- (cons) Perceptually High-quality but Over-smoothed Audio because of the Information Bottleneck from Sampling a Low-dimensional Latent Variable

### Diffusion (Diffusion-based Mel-spectrogram Generation Models)
- (pros) Fast Training Speed (within 3 Days)
- (pros) High-quality Voice Style Transfer
- (cons) Slow Inference Speed
- (cons) Lower Audio Quality than End-to-End Speech Synthesis Models

### (In this work) Our Approaches for GAN-based End-to-End Speech Synthesis Models
- Improving Voice Style Transfer Performance in End-to-End Speech Synthesis Models for OOD (Zero-shot Voice Style Transfer for Novel Speakers)
- Improving Audio Quality beyond Perceptual Quality for Much More High-fidelity Audio Generation

### (Our other works) Diffusion-based Mel-spectrogram Generation Models
- DDDM-VC: Disentangled Denoising Diffusion Models for High-quality and High-diversity Speech Synthesis
- Diff-HierVC: Hierarchical Diffusion-based Speech Synthesis Model with Diffusion-based Pitch Modeling

### Our Goals
- Integrating these models for High-quality, High-diversity, and High-fidelity Speech Synthesis
</details>

## LLM-based Models
We hope to compare against LLM-based models as zero-shot TTS baselines. However, there is no publicly available official implementation of an LLM-based TTS model. Unfortunately, unofficial implementations perform poorly in zero-shot TTS, so we hope the authors will release their models for fair comparison, reproducibility, and the benefit of the speech community. To be honest, I could not stand the inference speed, which is almost 1,000 times slower than end-to-end models; it takes 5 days to synthesize the full sentences of the LibriTTS test subsets, and even then the audio quality is poor. I hope they will release their official source code soon.

In my very personal opinion, VITS is still the best TTS model I have ever seen. I acknowledge that LLM-based models have powerful potential thanks to the creative generative performance that large-scale data enables, but not yet.

## Limitations of our work
- Slow training speed and relatively large model size (compared with VITS) --> Future work: a lightweight, fast training pipeline and a much larger model...
- Cannot generate realistic background sound --> Future work: adding an audio generation part by disentangling speech and sound.
- Cannot generate speech from overly long sentences because of our training setting. We expect that increasing the maximum length would improve performance; however, we do not have GPUs with 80 GB 😢
```
# Data filtering for limited computation resources.
wav_min = 32
wav_max = 600 # 12s
text_min = 1
text_max = 200
```
TTV v2 may reduce this issue significantly...!

## Results [[Download]](https://drive.google.com/drive/folders/1xCrZQy9s5MT38RMQxKAtkoWUgxT5qYYW?usp=sharing)
We have attached all samples from LibriTTS test-clean and test-other.

## Reference
<details>
<summary> [Read More] </summary>

### Our Previous Works
- HierSpeech/HierSpeech-U for the Hierarchical Speech Synthesis Framework: https://openreview.net/forum?id=awdyRVnfQKX
- HierVST for the Baseline Speech Backbone: https://www.isca-speech.org/archive/interspeech_2023/lee23i_interspeech.html
- DDDM-VC: https://dddm-vc.github.io/
- Diff-HierVC: https://diff-hiervc.github.io/

### Baseline Models
- VITS: https://github.com/jaywalnut310/vits
- NaturalSpeech
- NANSY for Audio Perturbation: https://github.com/revsic/torch-nansy
- Speech Resynthesis: https://github.com/facebookresearch/speech-resynthesis

### Waveform Generators for High-quality Audio Generation
- HiFi-GAN: https://github.com/jik876/hifi-gan
- BigVGAN for the High-quality Generator: https://arxiv.org/abs/2206.04658
- UnivNET: https://github.com/mindslab-ai/univnet
- EnCodec: https://github.com/facebookresearch/encodec

### Self-supervised Speech Models
- Wav2Vec 2.0: https://arxiv.org/abs/2006.11477
- XLS-R: https://huggingface.co/facebook/wav2vec2-xls-r-300m
- MMS: https://huggingface.co/facebook/facebook/mms-300m

### Other Large Language Model-based Speech Synthesis Models
- VALL-E & VALL-E-X
- SPEAR-TTS
- NaturalSpeech 2
- Make-a-Voice
- MEGA-TTS & MEGA-TTS 2
- UniAudio

### AdaLN-zero
- DiT: https://github.com/facebookresearch/DiT

Thanks for all these nice works.
</details>

## LICENSE
- Code in this repo: MIT License
- Model Weights: CC-BY-NC-4.0 license
activations.py
ADDED
@@ -0,0 +1,120 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.

import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class SnakeBeta(nn.Module):
    '''
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        '''
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
alias_free_torch/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

from .filter import *
from .resample import *
from .act import *
alias_free_torch/act.py
ADDED
@@ -0,0 +1,28 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(self,
                 activation,
                 up_ratio: int = 2,
                 down_ratio: int = 2,
                 up_kernel_size: int = 12,
                 down_kernel_size: int = 12):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B,C,T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x
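For orientation, the sketch below wraps the SnakeBeta activation from activations.py with Activation1d, which is the BigVGAN-style anti-aliased usage these modules are adapted for; the channel count and tensor sizes are illustrative assumptions, not values from this repo's configs.

```python
# Sketch: alias-free application of SnakeBeta (upsample -> activation -> downsample).
import torch
from activations import SnakeBeta
from alias_free_torch.act import Activation1d

channels = 256
act = Activation1d(activation=SnakeBeta(channels, alpha_logscale=True))

x = torch.randn(4, channels, 1024)   # [B, C, T]
y = act(x)                           # same [B, C, T] shape as the input
print(y.shape)
```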
alias_free_torch/filter.py
ADDED
@@ -0,0 +1,95 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if 'sinc' in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(x == 0,
                           torch.tensor(1., device=x.device, dtype=x.dtype),
                           torch.sin(math.pi * x) / math.pi / x)


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
    even = (kernel_size % 2 == 0)
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.:
        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
    else:
        beta = 0.
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = (torch.arange(-half_size, half_size) + 0.5)
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()
    filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(self,
                 cutoff=0.5,
                 half_width=0.6,
                 stride: int = 1,
                 padding: bool = True,
                 padding_mode: str = 'replicate',
                 kernel_size: int = 12):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < -0.:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = (kernel_size % 2 == 0)
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right),
                      mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1),
                       stride=self.stride, groups=C)

        return out
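As an illustrative check of the filter construction above (the values here are our own, not from the repo's configs): the Kaiser-windowed sinc taps are normalized to a DC gain of 1, and LowPassFilter1d decimates the time axis by its stride.

```python
# Illustrative sanity check of kaiser_sinc_filter1d and LowPassFilter1d.
import torch
from alias_free_torch.filter import kaiser_sinc_filter1d, LowPassFilter1d

taps = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(taps.shape, float(taps.sum()))   # torch.Size([1, 1, 12]), ~1.0 (unit DC gain)

lowpass = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
y = lowpass(torch.randn(1, 3, 1024))   # [B, C, T] -> [B, C, T // stride]
print(y.shape)                         # torch.Size([1, 3, 512])
```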
alias_free_torch/resample.py
ADDED
@@ -0,0 +1,49 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
                                      half_width=0.6 / ratio,
                                      kernel_size=self.kernel_size)
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode='replicate')
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
        x = x[..., self.pad_left:-self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
                                       half_width=0.6 / ratio,
                                       stride=ratio,
                                       kernel_size=self.kernel_size)

    def forward(self, x):
        xx = self.lowpass(x)

        return xx
app.py.py
ADDED
@@ -0,0 +1,236 @@
import os
import torch
import argparse
import numpy as np
from scipy.io.wavfile import write
import torchaudio
import utils
from Mels_preprocess import MelSpectrogramFixed

from hierspeechpp_speechsynthesizer import (
    SynthesizerTrn
)
from ttv_v1.text import text_to_sequence
from ttv_v1.t2w2v_transformer import SynthesizerTrn as Text2W2V
from speechsr24k.speechsr import SynthesizerTrn as SpeechSR24
from speechsr48k.speechsr import SynthesizerTrn as SpeechSR48
from denoiser.generator import MPNet
from denoiser.infer import denoise

import gradio as gr

def load_text(fp):
    with open(fp, 'r') as f:
        filelist = [line.strip() for line in f.readlines()]
    return filelist

def load_checkpoint(filepath, device):
    print(filepath)
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict

def get_param_num(model):
    num_param = sum(param.numel() for param in model.parameters())
    return num_param

def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

def add_blank_token(text):
    text_norm = intersperse(text, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm

def tts(text,
        prompt,
        ttv_temperature,
        vc_temperature,
        duration_temperature,
        duration_length,
        denoise_ratio,
        random_seed):

    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    np.random.seed(random_seed)

    text_len = len(text)
    if text_len > 200:
        raise gr.Error("Text length limited to 200 characters for this demo. Current text length is " + str(text_len))
    else:
        text = text_to_sequence(str(text), ["english_cleaners2"])

        token = add_blank_token(text).unsqueeze(0).cuda()
        token_length = torch.LongTensor([token.size(-1)]).cuda()

    # Prompt load
    # sample_rate, audio = prompt
    # audio = torch.FloatTensor([audio]).cuda()
    # if audio.shape[0] != 1:
    #     audio = audio[:1, :]
    # audio = audio / 32768
    audio, sample_rate = torchaudio.load(prompt)

    # support only single channel

    # Resampling
    if sample_rate != 16000:
        audio = torchaudio.functional.resample(audio, sample_rate, 16000, resampling_method="kaiser_window")

    # We utilize a hop size of 320 but the denoiser uses a hop size of 400, so we pad to a multiple of 1600
    ori_prompt_len = audio.shape[-1]
    p = (ori_prompt_len // 1600 + 1) * 1600 - ori_prompt_len
    audio = torch.nn.functional.pad(audio, (0, p), mode='constant').data

    # If you have a memory issue while denoising the prompt, try to denoise the prompt with the CPU before TTS
    # We plan to replace this with a memory-efficient denoiser
    if denoise_ratio == 0:  # no denoising requested: duplicate the original prompt
        audio = torch.cat([audio.cuda(), audio.cuda()], dim=0)
    else:
        with torch.no_grad():
            if ori_prompt_len > 80000:
                denoised_audio = []
                for i in range(ori_prompt_len // 80000):
                    denoised_audio.append(denoise(audio.squeeze(0).cuda()[i*80000:(i+1)*80000], denoiser, hps_denoiser))

                denoised_audio.append(denoise(audio.squeeze(0).cuda()[(i+1)*80000:], denoiser, hps_denoiser))
                denoised_audio = torch.cat(denoised_audio, dim=1)
            else:
                denoised_audio = denoise(audio.squeeze(0).cuda(), denoiser, hps_denoiser)

            audio = torch.cat([audio.cuda(), denoised_audio[:, :audio.shape[-1]]], dim=0)

    audio = audio[:, :ori_prompt_len]  # 20231108 We found that a large amount of padding decreases performance, so we remove the padding after denoising.

    if audio.shape[-1] < 48000:
        audio = torch.cat([audio, audio, audio, audio, audio], dim=1)

    src_mel = mel_fn(audio.cuda())

    src_length = torch.LongTensor([src_mel.size(2)]).to(device)
    src_length2 = torch.cat([src_length, src_length], dim=0)

    ## TTV (Text --> W2V, F0)
    with torch.no_grad():
        w2v_x, pitch = text2w2v.infer_noise_control(token, token_length, src_mel, src_length2,
                                                    noise_scale=ttv_temperature, noise_scale_w=duration_temperature,
                                                    length_scale=duration_length, denoise_ratio=denoise_ratio)
        src_length = torch.LongTensor([w2v_x.size(2)]).cuda()

        pitch[pitch < torch.log(torch.tensor([55]).cuda())] = 0

        ## Hierarchical Speech Synthesizer (W2V, F0 --> 16k Audio)
        converted_audio = \
            net_g.voice_conversion_noise_control(w2v_x, src_length, src_mel, src_length2, pitch, noise_scale=vc_temperature, denoise_ratio=denoise_ratio)

        converted_audio = speechsr(converted_audio)

    converted_audio = converted_audio.squeeze()

    converted_audio = converted_audio / (torch.abs(converted_audio).max()) * 32767.0 * 0.999
    converted_audio = converted_audio.cpu().numpy().astype('int16')

    write('output.wav', 48000, converted_audio)
    return 'output.wav'

def main():
    print('Initializing Inference Process..')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_prompt', default='example/steve-jobs-2005.wav')
    parser.add_argument('--input_txt', default='example/abstract.txt')
    parser.add_argument('--output_dir', default='output')
    parser.add_argument('--ckpt', default='./logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth')
    parser.add_argument('--ckpt_text2w2v', '-ct', help='text2w2v checkpoint path', default='./logs/ttv_libritts_v1/ttv_lt960_ckpt.pth')
    parser.add_argument('--ckpt_sr', type=str, default='./speechsr24k/G_340000.pth')
    parser.add_argument('--ckpt_sr48', type=str, default='./speechsr48k/G_100000.pth')
    parser.add_argument('--denoiser_ckpt', type=str, default='denoiser/g_best')
    parser.add_argument('--scale_norm', type=str, default='max')
    parser.add_argument('--output_sr', type=float, default=48000)
    parser.add_argument('--noise_scale_ttv', type=float, default=0.333)
    parser.add_argument('--noise_scale_vc', type=float, default=0.333)
    parser.add_argument('--denoise_ratio', type=float, default=0.8)
    parser.add_argument('--duration_ratio', type=float, default=0.8)
    parser.add_argument('--seed', type=int, default=1111)
    a = parser.parse_args()

    global device, hps, hps_t2w2v, h_sr, h_sr48, hps_denoiser
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json'))
    hps_t2w2v = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_text2w2v)[0], 'config.json'))
    h_sr = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr)[0], 'config.json'))
    h_sr48 = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr48)[0], 'config.json'))
    hps_denoiser = utils.get_hparams_from_file(os.path.join(os.path.split(a.denoiser_ckpt)[0], 'config.json'))

    global mel_fn, net_g, text2w2v, speechsr, denoiser

    mel_fn = MelSpectrogramFixed(
        sample_rate=hps.data.sampling_rate,
        n_fft=hps.data.filter_length,
        win_length=hps.data.win_length,
        hop_length=hps.data.hop_length,
        f_min=hps.data.mel_fmin,
        f_max=hps.data.mel_fmax,
        n_mels=hps.data.n_mel_channels,
        window_fn=torch.hann_window
    ).cuda()

    net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
                           hps.train.segment_size // hps.data.hop_length,
                           **hps.model).cuda()
    net_g.load_state_dict(torch.load(a.ckpt))
    _ = net_g.eval()

    text2w2v = Text2W2V(hps.data.filter_length // 2 + 1,
                        hps.train.segment_size // hps.data.hop_length,
                        **hps_t2w2v.model).cuda()
    text2w2v.load_state_dict(torch.load(a.ckpt_text2w2v))
    text2w2v.eval()

    speechsr = SpeechSR48(h_sr48.data.n_mel_channels,
                          h_sr48.train.segment_size // h_sr48.data.hop_length,
                          **h_sr48.model).cuda()
    utils.load_checkpoint(a.ckpt_sr48, speechsr, None)
    speechsr.eval()

    denoiser = MPNet(hps_denoiser).cuda()
    state_dict = load_checkpoint(a.denoiser_ckpt, device)
    denoiser.load_state_dict(state_dict['generator'])
    denoiser.eval()

    demo_play = gr.Interface(fn=tts,
                             inputs=[gr.Textbox(max_lines=6, label="Input Text", value="HierSpeech is a zero shot speech synthesis model, which can generate high-quality audio", info="Up to 200 characters"),
                                     gr.Audio(type='filepath', value="./example/3_rick_gt.wav"),
                                     gr.Slider(0, 1, 0.333),
                                     gr.Slider(0, 1, 0.333),
                                     gr.Slider(0, 1, 1.0),
                                     gr.Slider(0.5, 2, 1.0),
                                     gr.Slider(0, 1, 0),
                                     gr.Slider(0, 9999, 1111)],
                             outputs='audio',
                             title='HierSpeech++',
                             description='''<div>
                             <p style="text-align: left"> HierSpeech++ is a zero-shot speech synthesis model.</p>
                             <p style="text-align: left"> Our model is trained with the LibriTTS dataset, so this model only supports English. We will release a multi-lingual HierSpeech++ soon.</p>
                             <p style="text-align: left"> <a href="https://sh-lee-prml.github.io/HierSpeechpp-demo/">[Demo Page]</a> <a href="https://github.com/sh-lee-prml/HierSpeechpp">[Source Code]</a></p>
                             </div>''',
                             examples=[["HierSpeech is a zero-shot speech synthesis model, which can generate high-quality audio", "./example/3_rick_gt.wav", 0.333, 0.333, 1.0, 1.0, 0, 1111],
                                       ["HierSpeech is a zero-shot speech synthesis model, which can generate high-quality audio", "./example/ex01_whisper_00359.wav", 0.333, 0.333, 1.0, 1.0, 0, 1111],
                                       ["Hi there, I'm your new voice clone. Try your best to upload quality audio", "./example/female.wav", 0.333, 0.333, 1.0, 1.0, 0, 1111],
                                       ["Hello I'm HierSpeech++", "./example/reference_1.wav", 0.333, 0.333, 1.0, 1.0, 0, 1111],
                                       ]
                             )
    demo_play.launch(share=True, server_port=8888)

if __name__ == '__main__':
    main()
attentions.py
ADDED
@@ -0,0 +1,313 @@
import copy
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

import commons
import modules
from modules import LayerNorm


class Encoder(nn.Module):
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4,
                 **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout,
                                                       window_size=window_size))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class Decoder(nn.Module):
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
                 proximal_bias=False, proximal_init=True, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout,
                                   proximal_bias=proximal_bias, proximal_init=proximal_init))
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True,
                 block_length=None, proximal_bias=False, proximal_init=False):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels ** -0.5
            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert t_s == t_t, "Local attention is only available for self-attention."
                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
+
x: [b, h, l, m]
|
193 |
+
y: [h or 1, m, d]
|
194 |
+
ret: [b, h, l, d]
|
195 |
+
"""
|
196 |
+
ret = torch.matmul(x, y.unsqueeze(0))
|
197 |
+
return ret
|
198 |
+
|
199 |
+
def _matmul_with_relative_keys(self, x, y):
|
200 |
+
"""
|
201 |
+
x: [b, h, l, d]
|
202 |
+
y: [h or 1, m, d]
|
203 |
+
ret: [b, h, l, m]
|
204 |
+
"""
|
205 |
+
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
206 |
+
return ret
|
207 |
+
|
208 |
+
def _get_relative_embeddings(self, relative_embeddings, length):
|
209 |
+
max_relative_position = 2 * self.window_size + 1
|
210 |
+
# Pad first before slice to avoid using cond ops.
|
211 |
+
pad_length = max(length - (self.window_size + 1), 0)
|
212 |
+
slice_start_position = max((self.window_size + 1) - length, 0)
|
213 |
+
slice_end_position = slice_start_position + 2 * length - 1
|
214 |
+
if pad_length > 0:
|
215 |
+
padded_relative_embeddings = F.pad(
|
216 |
+
relative_embeddings,
|
217 |
+
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
|
218 |
+
else:
|
219 |
+
padded_relative_embeddings = relative_embeddings
|
220 |
+
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
|
221 |
+
return used_relative_embeddings
|
222 |
+
|
223 |
+
def _relative_position_to_absolute_position(self, x):
|
224 |
+
"""
|
225 |
+
x: [b, h, l, 2*l-1]
|
226 |
+
ret: [b, h, l, l]
|
227 |
+
"""
|
228 |
+
batch, heads, length, _ = x.size()
|
229 |
+
# Concat columns of pad to shift from relative to absolute indexing.
|
230 |
+
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
|
231 |
+
|
232 |
+
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
233 |
+
x_flat = x.view([batch, heads, length * 2 * length])
|
234 |
+
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
|
235 |
+
|
236 |
+
# Reshape and slice out the padded elements.
|
237 |
+
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
|
238 |
+
return x_final
|
239 |
+
|
240 |
+
def _absolute_position_to_relative_position(self, x):
|
241 |
+
"""
|
242 |
+
x: [b, h, l, l]
|
243 |
+
ret: [b, h, l, 2*l-1]
|
244 |
+
"""
|
245 |
+
batch, heads, length, _ = x.size()
|
246 |
+
# padd along column
|
247 |
+
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
|
248 |
+
x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
|
249 |
+
# add 0's in the beginning that will skew the elements after reshape
|
250 |
+
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
251 |
+
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
252 |
+
return x_final
|
253 |
+
|
254 |
+
def _attention_bias_proximal(self, length):
|
255 |
+
"""Bias for self-attention to encourage attention to close positions.
|
256 |
+
Args:
|
257 |
+
length: an integer scalar.
|
258 |
+
Returns:
|
259 |
+
a Tensor with shape [1, 1, length, length]
|
260 |
+
"""
|
261 |
+
r = torch.arange(length, dtype=torch.float32)
|
262 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
263 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
264 |
+
|
265 |
+
|
266 |
+
class FFN(nn.Module):
|
267 |
+
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None,
|
268 |
+
causal=False):
|
269 |
+
super().__init__()
|
270 |
+
self.in_channels = in_channels
|
271 |
+
self.out_channels = out_channels
|
272 |
+
self.filter_channels = filter_channels
|
273 |
+
self.kernel_size = kernel_size
|
274 |
+
self.p_dropout = p_dropout
|
275 |
+
self.activation = activation
|
276 |
+
self.causal = causal
|
277 |
+
|
278 |
+
if causal:
|
279 |
+
self.padding = self._causal_padding
|
280 |
+
else:
|
281 |
+
self.padding = self._same_padding
|
282 |
+
|
283 |
+
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
284 |
+
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
285 |
+
self.drop = nn.Dropout(p_dropout)
|
286 |
+
|
287 |
+
def forward(self, x, x_mask):
|
288 |
+
x = self.conv_1(self.padding(x * x_mask))
|
289 |
+
if self.activation == "gelu":
|
290 |
+
x = x * torch.sigmoid(1.702 * x)
|
291 |
+
else:
|
292 |
+
x = torch.relu(x)
|
293 |
+
x = self.drop(x)
|
294 |
+
x = self.conv_2(self.padding(x * x_mask))
|
295 |
+
return x * x_mask
|
296 |
+
|
297 |
+
def _causal_padding(self, x):
|
298 |
+
if self.kernel_size == 1:
|
299 |
+
return x
|
300 |
+
pad_l = self.kernel_size - 1
|
301 |
+
pad_r = 0
|
302 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
303 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
304 |
+
return x
|
305 |
+
|
306 |
+
def _same_padding(self, x):
|
307 |
+
if self.kernel_size == 1:
|
308 |
+
return x
|
309 |
+
pad_l = (self.kernel_size - 1) // 2
|
310 |
+
pad_r = self.kernel_size // 2
|
311 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
312 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
313 |
+
return x
|
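The block above is the standard VITS-style attention stack with windowed relative-position terms. As a quick sanity check, the MultiHeadAttention module can be exercised on dummy tensors; this is a minimal sketch, not part of the upload, and the tensor sizes and the import path (run from the repository root) are assumptions.

# Shape-check sketch (illustrative only, not from this commit).
import torch
from attentions import MultiHeadAttention

B, C, T = 2, 192, 50                      # batch, channels, frames (made-up sizes)
attn = MultiHeadAttention(C, C, n_heads=2, window_size=4)

x = torch.randn(B, C, T)                  # inputs are channel-first: [B, C, T]
x_mask = torch.ones(B, 1, T)              # 1 = valid frame, 0 = padding
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)   # [B, 1, T, T]

y = attn(x, x, attn_mask=attn_mask)       # self-attention: query and context are the same
print(y.shape)                            # torch.Size([2, 192, 50])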
commons.py
ADDED
@@ -0,0 +1,168 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def slice_segments_audio(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(
        length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (num_timescales - 1))
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1. / norm_type)
    return total_norm
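The helpers above are the usual VITS-style masking and alignment utilities: sequence_mask builds padding masks from lengths, and generate_path expands per-token durations into a hard monotonic alignment. The following is a minimal sketch (not part of the commit); the duration values are made up purely for shape checking.

# Illustrative use of sequence_mask and generate_path.
import torch
import commons

durations = torch.tensor([[[2, 1, 3]]], dtype=torch.float)   # [b=1, 1, t_x=3] frames per token
t_y = int(durations.sum())                                   # total output frames = 6

x_mask = commons.sequence_mask(torch.tensor([3]), 3).unsqueeze(1).float()       # [1, 1, 3]
y_mask = commons.sequence_mask(torch.tensor([t_y]), t_y).unsqueeze(1).float()   # [1, 1, 6]
attn_mask = y_mask.unsqueeze(-1) * x_mask.unsqueeze(2)       # [1, 1, 6, 3]

path = commons.generate_path(durations, attn_mask)           # hard alignment, [1, 1, 6, 3]
print(path.squeeze())   # each output frame attends to exactly one input token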
denoiser/config.json
ADDED
@@ -0,0 +1,28 @@
{
    "num_gpus": 0,
    "batch_size": 4,
    "learning_rate": 0.0005,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.99,
    "seed": 1234,

    "dense_channel": 64,
    "compress_factor": 0.3,
    "num_tsconformers": 4,
    "beta": 2.0,

    "sampling_rate": 16000,
    "segment_size": 32000,
    "n_fft": 400,
    "hop_size": 100,
    "win_size": 400,

    "num_workers": 4,

    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1
    }
}
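The denoiser modules read these fields as attributes (h.dense_channel, h.n_fft, ...), so the JSON above is typically wrapped in an attribute-style namespace before being passed around. The snippet below is a sketch of one common way to do that; the SimpleNamespace wrapper is an assumption, not something shown in this commit.

# Load denoiser/config.json into an attribute-accessible object.
import json
from types import SimpleNamespace

with open("denoiser/config.json") as f:
    h = json.load(f, object_hook=lambda d: SimpleNamespace(**d))

print(h.dense_channel, h.n_fft, h.compress_factor)   # 64 400 0.3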
denoiser/conformer.py
ADDED
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)

class FeedForwardModule(nn.Module):
    def __init__(self, dim, mult=4, dropout=0):
        super(FeedForwardModule, self).__init__()
        self.ffm = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * mult),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ffm(x)


class ConformerConvModule(nn.Module):
    def __init__(self, dim, expansion_factor=2, kernel_size=31, dropout=0.):
        super(ConformerConvModule, self).__init__()
        inner_dim = dim * expansion_factor
        self.ccm = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim*2, 1),
            nn.GLU(dim=1),
            nn.Conv1d(inner_dim, inner_dim, kernel_size=kernel_size,
                      padding=get_padding(kernel_size), groups=inner_dim),  # DepthWiseConv1d
            nn.BatchNorm1d(inner_dim),
            nn.SiLU(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ccm(x)


class AttentionModule(nn.Module):
    def __init__(self, dim, n_head=8, dropout=0.):
        super(AttentionModule, self).__init__()
        self.attn = nn.MultiheadAttention(dim, n_head, dropout=dropout)
        self.layernorm = nn.LayerNorm(dim)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        x = self.layernorm(x)
        x, _ = self.attn(x, x, x,
                         attn_mask=attn_mask,
                         key_padding_mask=key_padding_mask)
        return x


class ConformerBlock(nn.Module):
    def __init__(self, dim, n_head=8, ffm_mult=4, ccm_expansion_factor=2, ccm_kernel_size=31,
                 ffm_dropout=0., attn_dropout=0., ccm_dropout=0.):
        super(ConformerBlock, self).__init__()
        self.ffm1 = FeedForwardModule(dim, ffm_mult, dropout=ffm_dropout)
        self.attn = AttentionModule(dim, n_head, dropout=attn_dropout)
        self.ccm = ConformerConvModule(dim, ccm_expansion_factor, ccm_kernel_size, dropout=ccm_dropout)
        self.ffm2 = FeedForwardModule(dim, ffm_mult, dropout=ffm_dropout)
        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = x + 0.5 * self.ffm1(x)
        x = x + self.attn(x)
        x = x + self.ccm(x)
        x = x + 0.5 * self.ffm2(x)
        x = self.post_norm(x)
        return x


def main():
    x = torch.ones(10, 100, 64)
    conformer = ConformerBlock(dim=64)
    print(conformer(x))


if __name__ == '__main__':
    main()
denoiser/g_best
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0088cd06068a97b97cc13fe10fe155ea5c24beea79564b2162fab22a79dc9dc5
size 8350488
denoiser/generator.py
ADDED
@@ -0,0 +1,193 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from denoiser.conformer import ConformerBlock
from denoiser.utils import get_padding_2d, LearnableSigmoid_2d
from pesq import pesq
from joblib import Parallel, delayed

class DenseBlock(nn.Module):
    def __init__(self, h, kernel_size=(3, 3), depth=4):
        super(DenseBlock, self).__init__()
        self.h = h
        self.depth = depth
        self.dense_block = nn.ModuleList([])
        for i in range(depth):
            dil = 2 ** i
            dense_conv = nn.Sequential(
                nn.Conv2d(h.dense_channel*(i+1), h.dense_channel, kernel_size, dilation=(dil, 1),
                          padding=get_padding_2d(kernel_size, (dil, 1))),
                nn.InstanceNorm2d(h.dense_channel, affine=True),
                nn.PReLU(h.dense_channel)
            )
            self.dense_block.append(dense_conv)

    def forward(self, x):
        skip = x
        for i in range(self.depth):
            x = self.dense_block[i](skip)
            skip = torch.cat([x, skip], dim=1)
        return x


class DenseEncoder(nn.Module):
    def __init__(self, h, in_channel):
        super(DenseEncoder, self).__init__()
        self.h = h
        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(in_channel, h.dense_channel, (1, 1)),
            nn.InstanceNorm2d(h.dense_channel, affine=True),
            nn.PReLU(h.dense_channel))

        self.dense_block = DenseBlock(h, depth=4)  # [b, h.dense_channel, ndim_time, h.n_fft//2+1]

        self.dense_conv_2 = nn.Sequential(
            nn.Conv2d(h.dense_channel, h.dense_channel, (1, 3), (1, 2)),
            nn.InstanceNorm2d(h.dense_channel, affine=True),
            nn.PReLU(h.dense_channel))

    def forward(self, x):
        x = self.dense_conv_1(x)  # [b, 64, T, F]
        x = self.dense_block(x)   # [b, 64, T, F]
        x = self.dense_conv_2(x)  # [b, 64, T, F//2]
        return x


class MaskDecoder(nn.Module):
    def __init__(self, h, out_channel=1):
        super(MaskDecoder, self).__init__()
        self.dense_block = DenseBlock(h, depth=4)
        self.mask_conv = nn.Sequential(
            nn.ConvTranspose2d(h.dense_channel, h.dense_channel, (1, 3), (1, 2)),
            nn.Conv2d(h.dense_channel, out_channel, (1, 1)),
            nn.InstanceNorm2d(out_channel, affine=True),
            nn.PReLU(out_channel),
            nn.Conv2d(out_channel, out_channel, (1, 1))
        )
        self.lsigmoid = LearnableSigmoid_2d(h.n_fft//2+1, beta=h.beta)

    def forward(self, x):
        x = self.dense_block(x)
        x = self.mask_conv(x)
        x = x.permute(0, 3, 2, 1).squeeze(-1)
        x = self.lsigmoid(x).permute(0, 2, 1).unsqueeze(1)
        return x


class PhaseDecoder(nn.Module):
    def __init__(self, h, out_channel=1):
        super(PhaseDecoder, self).__init__()
        self.dense_block = DenseBlock(h, depth=4)
        self.phase_conv = nn.Sequential(
            nn.ConvTranspose2d(h.dense_channel, h.dense_channel, (1, 3), (1, 2)),
            nn.InstanceNorm2d(h.dense_channel, affine=True),
            nn.PReLU(h.dense_channel)
        )
        self.phase_conv_r = nn.Conv2d(h.dense_channel, out_channel, (1, 1))
        self.phase_conv_i = nn.Conv2d(h.dense_channel, out_channel, (1, 1))

    def forward(self, x):
        x = self.dense_block(x)
        x = self.phase_conv(x)
        x_r = self.phase_conv_r(x)
        x_i = self.phase_conv_i(x)
        x = torch.atan2(x_i, x_r)
        return x


class TSConformerBlock(nn.Module):
    def __init__(self, h):
        super(TSConformerBlock, self).__init__()
        self.h = h
        self.time_conformer = ConformerBlock(dim=h.dense_channel, n_head=4, ccm_kernel_size=31,
                                             ffm_dropout=0.2, attn_dropout=0.2)
        self.freq_conformer = ConformerBlock(dim=h.dense_channel, n_head=4, ccm_kernel_size=31,
                                             ffm_dropout=0.2, attn_dropout=0.2)

    def forward(self, x):
        b, c, t, f = x.size()
        x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c)
        x = self.time_conformer(x) + x
        x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c)
        x = self.freq_conformer(x) + x
        x = x.view(b, t, f, c).permute(0, 3, 1, 2)
        return x


class MPNet(nn.Module):
    def __init__(self, h, num_tscblocks=4):
        super(MPNet, self).__init__()
        self.h = h
        self.num_tscblocks = num_tscblocks
        self.dense_encoder = DenseEncoder(h, in_channel=2)

        self.TSConformer = nn.ModuleList([])
        for i in range(num_tscblocks):
            self.TSConformer.append(TSConformerBlock(h))

        self.mask_decoder = MaskDecoder(h, out_channel=1)
        self.phase_decoder = PhaseDecoder(h, out_channel=1)

    def forward(self, noisy_mag, noisy_pha):  # [B, F, T]
        noisy_mag = noisy_mag.unsqueeze(-1).permute(0, 3, 2, 1)  # [B, 1, T, F]
        noisy_pha = noisy_pha.unsqueeze(-1).permute(0, 3, 2, 1)  # [B, 1, T, F]
        x = torch.cat((noisy_mag, noisy_pha), dim=1)  # [B, 2, T, F]
        x = self.dense_encoder(x)

        for i in range(self.num_tscblocks):
            x = self.TSConformer[i](x)

        denoised_mag = (noisy_mag * self.mask_decoder(x)).permute(0, 3, 2, 1).squeeze(-1)
        denoised_pha = self.phase_decoder(x).permute(0, 3, 2, 1).squeeze(-1)
        denoised_com = torch.stack((denoised_mag*torch.cos(denoised_pha),
                                    denoised_mag*torch.sin(denoised_pha)), dim=-1)

        return denoised_mag, denoised_pha, denoised_com


def phase_losses(phase_r, phase_g, h):

    dim_freq = h.n_fft // 2 + 1
    dim_time = phase_r.size(-1)

    gd_matrix = (torch.triu(torch.ones(dim_freq, dim_freq), diagonal=1) - torch.triu(torch.ones(dim_freq, dim_freq), diagonal=2) - torch.eye(dim_freq)).to(phase_g.device)
    gd_r = torch.matmul(phase_r.permute(0, 2, 1), gd_matrix)
    gd_g = torch.matmul(phase_g.permute(0, 2, 1), gd_matrix)

    iaf_matrix = (torch.triu(torch.ones(dim_time, dim_time), diagonal=1) - torch.triu(torch.ones(dim_time, dim_time), diagonal=2) - torch.eye(dim_time)).to(phase_g.device)
    iaf_r = torch.matmul(phase_r, iaf_matrix)
    iaf_g = torch.matmul(phase_g, iaf_matrix)

    ip_loss = torch.mean(anti_wrapping_function(phase_r-phase_g))
    gd_loss = torch.mean(anti_wrapping_function(gd_r-gd_g))
    iaf_loss = torch.mean(anti_wrapping_function(iaf_r-iaf_g))

    return ip_loss, gd_loss, iaf_loss


def anti_wrapping_function(x):

    return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)


def pesq_score(utts_r, utts_g, h):

    pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)(
                            utts_r[i].squeeze().cpu().numpy(),
                            utts_g[i].squeeze().cpu().numpy(),
                            h.sampling_rate)
                          for i in range(len(utts_r)))
    pesq_score = np.mean(pesq_score)

    return pesq_score


def eval_pesq(clean_utt, esti_utt, sr):
    try:
        pesq_score = pesq(sr, clean_utt, esti_utt)
    except:
        # error can happen due to silent period
        pesq_score = -1

    return pesq_score
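MPNet takes a compressed magnitude spectrogram and a phase spectrogram, both shaped [B, F, T], and returns denoised magnitude, phase, and a stacked complex representation. The sketch below is a shape check on random tensors sized to the denoiser config (n_fft=400 gives 201 frequency bins); it is illustrative only and the frame count is an arbitrary assumption.

# Shape-check sketch (not part of the upload).
import json
import torch
from types import SimpleNamespace
from denoiser.generator import MPNet

with open("denoiser/config.json") as f:
    h = json.load(f, object_hook=lambda d: SimpleNamespace(**d))

model = MPNet(h).eval()
mag = torch.rand(1, h.n_fft // 2 + 1, 120)        # [B, F=201, T=120 frames]
pha = torch.rand(1, h.n_fft // 2 + 1, 120) * 3.14

with torch.no_grad():
    mag_g, pha_g, com_g = model(mag, pha)
print(mag_g.shape, pha_g.shape, com_g.shape)      # [1, 201, 120] [1, 201, 120] [1, 201, 120, 2]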
denoiser/infer.py
ADDED
@@ -0,0 +1,33 @@
import torch

def denoise(noisy_wav, model, hps):
    norm_factor = torch.sqrt(len(noisy_wav) / torch.sum(noisy_wav ** 2.0)).to(noisy_wav.device)
    noisy_wav = (noisy_wav * norm_factor).unsqueeze(0)
    noisy_amp, noisy_pha, noisy_com = mag_pha_stft(noisy_wav, hps.n_fft, hps.hop_size, hps.win_size, hps.compress_factor)
    amp_g, pha_g, com_g = model(noisy_amp, noisy_pha)
    audio_g = mag_pha_istft(amp_g, pha_g, hps.n_fft, hps.hop_size, hps.win_size, hps.compress_factor)
    audio_g = audio_g / norm_factor
    return audio_g

def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True):

    hann_window = torch.hann_window(win_size).to(y.device)
    stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
                           center=center, pad_mode='reflect', normalized=False, return_complex=True)
    mag = torch.abs(stft_spec)
    pha = torch.angle(stft_spec)
    # Magnitude Compression
    mag = torch.pow(mag, compress_factor)
    com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1)

    return mag, pha, com

def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
    # Magnitude Decompression
    mag = torch.pow(mag, (1.0/compress_factor))
    com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha))
    hann_window = torch.hann_window(win_size).to(com.device)
    wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)

    return wav
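denoise() expects a 1-D 16 kHz waveform tensor, normalizes its energy, runs the compressed-STFT magnitude/phase through the model, and inverts the result. The following usage sketch wires it to the MPNet weights in denoiser/g_best; the 'generator' checkpoint key and the SimpleNamespace config wrapper follow the original MP-SENet recipe and are assumptions, not something this commit guarantees.

# Usage sketch (assumptions flagged inline).
import json
import torch
import torchaudio
from types import SimpleNamespace
from denoiser.generator import MPNet
from denoiser.utils import load_checkpoint
from denoiser.infer import denoise

device = torch.device('cpu')
with open("denoiser/config.json") as f:
    hps = json.load(f, object_hook=lambda d: SimpleNamespace(**d))

model = MPNet(hps).to(device)
state = load_checkpoint("denoiser/g_best", device)
model.load_state_dict(state['generator'])   # checkpoint key name is an assumption
model.eval()

wav, sr = torchaudio.load("example/reference_1.wav")   # expected to be 16 kHz mono
with torch.no_grad():
    clean = denoise(wav.squeeze(0).to(device), model, hps)
torchaudio.save("denoised.wav", clean.cpu(), hps.sampling_rate)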
denoiser/utils.py
ADDED
@@ -0,0 +1,55 @@
import glob
import os
import torch
import torch.nn as nn

def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


def get_padding_2d(kernel_size, dilation=(1, 1)):
    return (int((kernel_size[0]*dilation[0] - dilation[0])/2), int((kernel_size[1]*dilation[1] - dilation[1])/2))


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def save_checkpoint(filepath, obj):
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")


def scan_checkpoint(cp_dir, prefix):
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]


class LearnableSigmoid_1d(nn.Module):
    def __init__(self, in_features, beta=1):
        super().__init__()
        self.beta = beta
        self.slope = nn.Parameter(torch.ones(in_features))
        self.slope.requiresGrad = True

    def forward(self, x):
        return self.beta * torch.sigmoid(self.slope * x)


class LearnableSigmoid_2d(nn.Module):
    def __init__(self, in_features, beta=1):
        super().__init__()
        self.beta = beta
        self.slope = nn.Parameter(torch.ones(in_features, 1))
        self.slope.requiresGrad = True

    def forward(self, x):
        return self.beta * torch.sigmoid(self.slope * x)
example/reference_1.txt
ADDED
@@ -0,0 +1 @@
And lay me down in my cold bed and leave my shining lot.
example/reference_1.wav
ADDED
Binary file (96 kB). View file
example/reference_2.txt
ADDED
@@ -0,0 +1 @@
Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech.
example/reference_2.wav
ADDED
Binary file (96 kB). View file
example/reference_3.txt
ADDED
@@ -0,0 +1 @@
The army found the people in poverty and left them in comparative wealth.
example/reference_3.wav
ADDED
Binary file (96 kB). View file
example/reference_4.txt
ADDED
@@ -0,0 +1 @@
Thus did this humane and right minded father comfort his unhappy daughter, and her mother embracing her again, did all she could to soothe her feelings.
example/reference_4.wav
ADDED
Binary file (96 kB). View file
hierspeechpp_speechsynthesizer.py
ADDED
@@ -0,0 +1,716 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
import modules
|
5 |
+
import attentions
|
6 |
+
|
7 |
+
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
8 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
9 |
+
from commons import init_weights, get_padding
|
10 |
+
|
11 |
+
import torchaudio
|
12 |
+
from einops import rearrange
|
13 |
+
import transformers
|
14 |
+
import math
|
15 |
+
from styleencoder import StyleEncoder
|
16 |
+
import commons
|
17 |
+
|
18 |
+
from alias_free_torch import *
|
19 |
+
import activations
|
20 |
+
|
21 |
+
class Wav2vec2(torch.nn.Module):
|
22 |
+
def __init__(self, layer=7, w2v='mms'):
|
23 |
+
|
24 |
+
"""we use the intermediate features of mms-300m.
|
25 |
+
More specifically, we used the output from the 7th layer of the 24-layer transformer encoder.
|
26 |
+
"""
|
27 |
+
super().__init__()
|
28 |
+
|
29 |
+
if w2v == 'mms':
|
30 |
+
self.wav2vec2 = transformers.Wav2Vec2ForPreTraining.from_pretrained("facebook/mms-300m")
|
31 |
+
else:
|
32 |
+
self.wav2vec2 = transformers.Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-xls-r-300m")
|
33 |
+
|
34 |
+
for param in self.wav2vec2.parameters():
|
35 |
+
param.requires_grad = False
|
36 |
+
param.grad = None
|
37 |
+
self.wav2vec2.eval()
|
38 |
+
self.feature_layer = layer
|
39 |
+
|
40 |
+
@torch.no_grad()
|
41 |
+
def forward(self, x):
|
42 |
+
"""
|
43 |
+
Args:
|
44 |
+
x: torch.Tensor of shape (B x t)
|
45 |
+
Returns:
|
46 |
+
y: torch.Tensor of shape(B x C x t)
|
47 |
+
"""
|
48 |
+
outputs = self.wav2vec2(x.squeeze(1), output_hidden_states=True)
|
49 |
+
y = outputs.hidden_states[self.feature_layer] # B x t x C(1024)
|
50 |
+
y = y.permute((0, 2, 1)) # B x t x C -> B x C x t
|
51 |
+
return y
|
52 |
+
|
53 |
+
class ResidualCouplingBlock_Transformer(nn.Module):
|
54 |
+
def __init__(self,
|
55 |
+
channels,
|
56 |
+
hidden_channels,
|
57 |
+
kernel_size,
|
58 |
+
dilation_rate,
|
59 |
+
n_layers=3,
|
60 |
+
n_flows=4,
|
61 |
+
gin_channels=0):
|
62 |
+
super().__init__()
|
63 |
+
self.channels = channels
|
64 |
+
self.hidden_channels = hidden_channels
|
65 |
+
self.kernel_size = kernel_size
|
66 |
+
self.dilation_rate = dilation_rate
|
67 |
+
self.n_layers = n_layers
|
68 |
+
self.n_flows = n_flows
|
69 |
+
self.gin_channels = gin_channels
|
70 |
+
self.cond_block = torch.nn.Sequential(torch.nn.Linear(gin_channels, 4 * hidden_channels),
|
71 |
+
nn.SiLU(), torch.nn.Linear(4 * hidden_channels, hidden_channels))
|
72 |
+
|
73 |
+
self.flows = nn.ModuleList()
|
74 |
+
for i in range(n_flows):
|
75 |
+
self.flows.append(modules.ResidualCouplingLayer_Transformer_simple(channels, hidden_channels, kernel_size, dilation_rate, n_layers, mean_only=True))
|
76 |
+
self.flows.append(modules.Flip())
|
77 |
+
|
78 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
79 |
+
|
80 |
+
g = self.cond_block(g.squeeze(2))
|
81 |
+
|
82 |
+
if not reverse:
|
83 |
+
for flow in self.flows:
|
84 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
85 |
+
else:
|
86 |
+
for flow in reversed(self.flows):
|
87 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
88 |
+
return x
|
89 |
+
|
90 |
+
class PosteriorAudioEncoder(nn.Module):
|
91 |
+
def __init__(self,
|
92 |
+
in_channels,
|
93 |
+
out_channels,
|
94 |
+
hidden_channels,
|
95 |
+
kernel_size,
|
96 |
+
dilation_rate,
|
97 |
+
n_layers,
|
98 |
+
gin_channels=0):
|
99 |
+
super().__init__()
|
100 |
+
self.in_channels = in_channels
|
101 |
+
self.out_channels = out_channels
|
102 |
+
self.hidden_channels = hidden_channels
|
103 |
+
self.kernel_size = kernel_size
|
104 |
+
self.dilation_rate = dilation_rate
|
105 |
+
self.n_layers = n_layers
|
106 |
+
self.gin_channels = gin_channels
|
107 |
+
|
108 |
+
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
109 |
+
self.down_pre = nn.Conv1d(1, 16, 7, 1, padding=3)
|
110 |
+
self.resblocks = nn.ModuleList()
|
111 |
+
downsample_rates = [8,5,4,2]
|
112 |
+
downsample_kernel_sizes = [17, 10, 8, 4]
|
113 |
+
ch = [16, 32, 64, 128, 192]
|
114 |
+
|
115 |
+
resblock = AMPBlock1
|
116 |
+
resblock_kernel_sizes = [3,7,11]
|
117 |
+
resblock_dilation_sizes = [[1,3,5], [1,3,5], [1,3,5]]
|
118 |
+
self.num_kernels = 3
|
119 |
+
self.downs = nn.ModuleList()
|
120 |
+
for i, (u, k) in enumerate(zip(downsample_rates, downsample_kernel_sizes)):
|
121 |
+
self.downs.append(weight_norm(
|
122 |
+
Conv1d(ch[i], ch[i+1], k, u, padding=(k-1)//2)))
|
123 |
+
for i in range(4):
|
124 |
+
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
125 |
+
self.resblocks.append(resblock(ch[i+1], k, d, activation="snakebeta"))
|
126 |
+
|
127 |
+
activation_post = activations.SnakeBeta(ch[i+1], alpha_logscale=True)
|
128 |
+
self.activation_post = Activation1d(activation=activation_post)
|
129 |
+
|
130 |
+
self.conv_post = Conv1d(ch[i+1], hidden_channels, 7, 1, padding=3)
|
131 |
+
|
132 |
+
|
133 |
+
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
|
134 |
+
self.proj = nn.Conv1d(hidden_channels*2, out_channels * 2, 1)
|
135 |
+
|
136 |
+
def forward(self, x, x_audio, x_mask, g=None):
|
137 |
+
|
138 |
+
x_audio = self.down_pre(x_audio)
|
139 |
+
|
140 |
+
for i in range(4):
|
141 |
+
|
142 |
+
x_audio = self.downs[i](x_audio)
|
143 |
+
|
144 |
+
xs = None
|
145 |
+
for j in range(self.num_kernels):
|
146 |
+
if xs is None:
|
147 |
+
xs = self.resblocks[i*self.num_kernels+j](x_audio)
|
148 |
+
else:
|
149 |
+
xs += self.resblocks[i*self.num_kernels+j](x_audio)
|
150 |
+
x_audio = xs / self.num_kernels
|
151 |
+
|
152 |
+
x_audio = self.activation_post(x_audio)
|
153 |
+
x_audio = self.conv_post(x_audio)
|
154 |
+
|
155 |
+
x = self.pre(x) * x_mask
|
156 |
+
x = self.enc(x, x_mask, g=g)
|
157 |
+
|
158 |
+
x_audio = x_audio * x_mask
|
159 |
+
|
160 |
+
x = torch.cat([x, x_audio], dim=1)
|
161 |
+
|
162 |
+
stats = self.proj(x) * x_mask
|
163 |
+
|
164 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
165 |
+
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
166 |
+
return z, m, logs
|
167 |
+
|
168 |
+
class PosteriorSFEncoder(nn.Module):
|
169 |
+
def __init__(self,
|
170 |
+
src_channels,
|
171 |
+
out_channels,
|
172 |
+
hidden_channels,
|
173 |
+
kernel_size,
|
174 |
+
dilation_rate,
|
175 |
+
n_layers,
|
176 |
+
gin_channels=0):
|
177 |
+
super().__init__()
|
178 |
+
|
179 |
+
self.out_channels = out_channels
|
180 |
+
self.hidden_channels = hidden_channels
|
181 |
+
self.kernel_size = kernel_size
|
182 |
+
self.dilation_rate = dilation_rate
|
183 |
+
self.n_layers = n_layers
|
184 |
+
self.gin_channels = gin_channels
|
185 |
+
|
186 |
+
self.pre_source = nn.Conv1d(src_channels, hidden_channels, 1)
|
187 |
+
self.pre_filter = nn.Conv1d(1, hidden_channels, kernel_size=9, stride=4, padding=4)
|
188 |
+
self.source_enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers//2, gin_channels=gin_channels)
|
189 |
+
self.filter_enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers//2, gin_channels=gin_channels)
|
190 |
+
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers//2, gin_channels=gin_channels)
|
191 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
192 |
+
|
193 |
+
def forward(self, x_src, x_ftr, x_mask, g=None):
|
194 |
+
|
195 |
+
x_src = self.pre_source(x_src) * x_mask
|
196 |
+
x_ftr = self.pre_filter(x_ftr) * x_mask
|
197 |
+
x_src = self.source_enc(x_src, x_mask, g=g)
|
198 |
+
x_ftr = self.filter_enc(x_ftr, x_mask, g=g)
|
199 |
+
x = self.enc(x_src+x_ftr, x_mask, g=g)
|
200 |
+
stats = self.proj(x) * x_mask
|
201 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
202 |
+
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
203 |
+
return z, m, logs
|
204 |
+
|
205 |
+
|
206 |
+
class MelDecoder(nn.Module):
|
207 |
+
def __init__(self,
|
208 |
+
hidden_channels,
|
209 |
+
filter_channels,
|
210 |
+
n_heads,
|
211 |
+
n_layers,
|
212 |
+
kernel_size,
|
213 |
+
p_dropout,
|
214 |
+
mel_size=20,
|
215 |
+
gin_channels=0):
|
216 |
+
super().__init__()
|
217 |
+
|
218 |
+
self.hidden_channels = hidden_channels
|
219 |
+
self.filter_channels = filter_channels
|
220 |
+
self.n_heads = n_heads
|
221 |
+
self.n_layers = n_layers
|
222 |
+
self.kernel_size = kernel_size
|
223 |
+
self.p_dropout = p_dropout
|
224 |
+
|
225 |
+
self.conv_pre = Conv1d(hidden_channels, hidden_channels, 3, 1, padding=1)
|
226 |
+
|
227 |
+
self.encoder = attentions.Encoder(
|
228 |
+
hidden_channels,
|
229 |
+
filter_channels,
|
230 |
+
n_heads,
|
231 |
+
n_layers,
|
232 |
+
kernel_size,
|
233 |
+
p_dropout)
|
234 |
+
|
235 |
+
self.proj= nn.Conv1d(hidden_channels, mel_size, 1, bias=False)
|
236 |
+
|
237 |
+
if gin_channels != 0:
|
238 |
+
self.cond = nn.Conv1d(gin_channels, hidden_channels, 1)
|
239 |
+
|
240 |
+
def forward(self, x, x_mask, g=None):
|
241 |
+
|
242 |
+
x = self.conv_pre(x*x_mask)
|
243 |
+
if g is not None:
|
244 |
+
x = x + self.cond(g)
|
245 |
+
|
246 |
+
x = self.encoder(x * x_mask, x_mask)
|
247 |
+
x = self.proj(x) * x_mask
|
248 |
+
|
249 |
+
return x
|
250 |
+
|
251 |
+
class SourceNetwork(nn.Module):
|
252 |
+
def __init__(self, upsample_initial_channel=256):
|
253 |
+
super().__init__()
|
254 |
+
|
255 |
+
resblock_kernel_sizes = [3,5,7]
|
256 |
+
upsample_rates = [2,2]
|
257 |
+
initial_channel = 192
|
258 |
+
upsample_initial_channel = upsample_initial_channel
|
259 |
+
upsample_kernel_sizes = [4,4]
|
260 |
+
resblock_dilation_sizes = [[1,3,5], [1,3,5], [1,3,5]]
|
261 |
+
|
262 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
263 |
+
self.num_upsamples = len(upsample_rates)
|
264 |
+
|
265 |
+
self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3))
|
266 |
+
resblock = AMPBlock1
|
267 |
+
|
268 |
+
self.ups = nn.ModuleList()
|
269 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
270 |
+
self.ups.append(weight_norm(
|
271 |
+
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
|
272 |
+
k, u, padding=(k-u)//2)))
|
273 |
+
|
274 |
+
self.resblocks = nn.ModuleList()
|
275 |
+
for i in range(len(self.ups)):
|
276 |
+
ch = upsample_initial_channel//(2**(i+1))
|
277 |
+
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
278 |
+
self.resblocks.append(resblock(ch, k, d, activation="snakebeta"))
|
279 |
+
|
280 |
+
activation_post = activations.SnakeBeta(ch, alpha_logscale=True)
|
281 |
+
self.activation_post = Activation1d(activation=activation_post)
|
282 |
+
|
283 |
+
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
284 |
+
|
285 |
+
self.cond = Conv1d(256, upsample_initial_channel, 1)
|
286 |
+
|
287 |
+
self.ups.apply(init_weights)
|
288 |
+
|
289 |
+
|
290 |
+
def forward(self, x, g):
|
291 |
+
|
292 |
+
x = self.conv_pre(x) + self.cond(g)
|
293 |
+
|
294 |
+
for i in range(self.num_upsamples):
|
295 |
+
|
296 |
+
x = self.ups[i](x)
|
297 |
+
xs = None
|
298 |
+
for j in range(self.num_kernels):
|
299 |
+
if xs is None:
|
300 |
+
xs = self.resblocks[i*self.num_kernels+j](x)
|
301 |
+
else:
|
302 |
+
xs += self.resblocks[i*self.num_kernels+j](x)
|
303 |
+
x = xs / self.num_kernels
|
304 |
+
|
305 |
+
x = self.activation_post(x)
|
306 |
+
## Predictor
|
307 |
+
x_ = self.conv_post(x)
|
308 |
+
return x, x_
|
309 |
+
|
310 |
+
def remove_weight_norm(self):
|
311 |
+
print('Removing weight norm...')
|
312 |
+
for l in self.ups:
|
313 |
+
remove_weight_norm(l)
|
314 |
+
for l in self.resblocks:
|
315 |
+
l.remove_weight_norm()
|
316 |
+
|
317 |
+
class DBlock(nn.Module):
|
318 |
+
def __init__(self, input_size, hidden_size, factor):
|
319 |
+
super().__init__()
|
320 |
+
self.factor = factor
|
321 |
+
self.residual_dense = weight_norm(Conv1d(input_size, hidden_size, 1))
|
322 |
+
self.conv = nn.ModuleList([
|
323 |
+
weight_norm(Conv1d(input_size, hidden_size, 3, dilation=1, padding=1)),
|
324 |
+
weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2)),
|
325 |
+
weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4)),
|
326 |
+
])
|
327 |
+
self.conv.apply(init_weights)
|
328 |
+
def forward(self, x):
|
329 |
+
size = x.shape[-1] // self.factor
|
330 |
+
|
331 |
+
residual = self.residual_dense(x)
|
332 |
+
residual = F.interpolate(residual, size=size)
|
333 |
+
|
334 |
+
x = F.interpolate(x, size=size)
|
335 |
+
for layer in self.conv:
|
336 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
337 |
+
x = layer(x)
|
338 |
+
|
339 |
+
return x + residual
|
340 |
+
def remove_weight_norm(self):
|
341 |
+
for l in self.conv:
|
342 |
+
remove_weight_norm(l)
|
343 |
+
|
344 |
+
class AMPBlock1(torch.nn.Module):
|
345 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
|
346 |
+
super(AMPBlock1, self).__init__()
|
347 |
+
|
348 |
+
self.convs1 = nn.ModuleList([
|
349 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
350 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
351 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
352 |
+
padding=get_padding(kernel_size, dilation[1]))),
|
353 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
354 |
+
padding=get_padding(kernel_size, dilation[2])))
|
355 |
+
])
|
356 |
+
self.convs1.apply(init_weights)
|
357 |
+
|
358 |
+
self.convs2 = nn.ModuleList([
|
359 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
360 |
+
padding=get_padding(kernel_size, 1))),
|
361 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
362 |
+
padding=get_padding(kernel_size, 1))),
|
363 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
364 |
+
padding=get_padding(kernel_size, 1)))
|
365 |
+
])
|
366 |
+
self.convs2.apply(init_weights)
|
367 |
+
|
368 |
+
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
369 |
+
|
370 |
+
|
371 |
+
self.activations = nn.ModuleList([
|
372 |
+
Activation1d(
|
373 |
+
activation=activations.SnakeBeta(channels, alpha_logscale=True))
|
374 |
+
for _ in range(self.num_layers)
|
375 |
+
])
|
376 |
+
|
377 |
+
def forward(self, x):
|
378 |
+
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
379 |
+
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
380 |
+
xt = a1(x)
|
381 |
+
xt = c1(xt)
|
382 |
+
xt = a2(xt)
|
383 |
+
xt = c2(xt)
|
384 |
+
x = xt + x
|
385 |
+
|
386 |
+
return x
|
387 |
+
|
388 |
+
def remove_weight_norm(self):
|
389 |
+
for l in self.convs1:
|
390 |
+
remove_weight_norm(l)
|
391 |
+
for l in self.convs2:
|
392 |
+
remove_weight_norm(l)
|
393 |
+
|
394 |
+
class Generator(torch.nn.Module):
|
395 |
+
def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=256):
|
396 |
+
super(Generator, self).__init__()
|
397 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
398 |
+
self.num_upsamples = len(upsample_rates)
|
399 |
+
|
400 |
+
|
401 |
+
self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3))
|
402 |
+
resblock = AMPBlock1
|
403 |
+
|
404 |
+
self.ups = nn.ModuleList()
|
405 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
406 |
+
self.ups.append(weight_norm(
|
407 |
+
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
|
408 |
+
k, u, padding=(k-u)//2)))
|
409 |
+
|
410 |
+
self.resblocks = nn.ModuleList()
|
411 |
+
for i in range(len(self.ups)):
|
412 |
+
ch = upsample_initial_channel//(2**(i+1))
|
413 |
+
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
414 |
+
self.resblocks.append(resblock(ch, k, d, activation="snakebeta"))
|
415 |
+
|
416 |
+
activation_post = activations.SnakeBeta(ch, alpha_logscale=True)
|
417 |
+
self.activation_post = Activation1d(activation=activation_post)
|
418 |
+
|
419 |
+
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
420 |
+
self.ups.apply(init_weights)
|
421 |
+
|
422 |
+
if gin_channels != 0:
|
423 |
+
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
424 |
+
|
425 |
+
self.downs = DBlock(upsample_initial_channel//8, upsample_initial_channel, 4)
|
426 |
+
self.proj = Conv1d(upsample_initial_channel//8, upsample_initial_channel//2, 7, 1, padding=3)
|
427 |
+
|
428 |
+
def forward(self, x, pitch, g=None):
|
429 |
+
|
430 |
+
x = self.conv_pre(x) + self.downs(pitch) + self.cond(g)
|
431 |
+
|
432 |
+
for i in range(self.num_upsamples):
|
433 |
+
|
434 |
+
x = self.ups[i](x)
|
435 |
+
|
436 |
+
if i == 0:
|
437 |
+
pitch = self.proj(pitch)
|
438 |
+
x = x + pitch
|
439 |
+
|
440 |
+
xs = None
|
441 |
+
for j in range(self.num_kernels):
|
442 |
+
if xs is None:
|
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        for l in self.downs:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorR(torch.nn.Module):
    def __init__(self, resolution, use_spectral_norm=False):
        super(DiscriminatorR, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        n_fft, hop_length, win_length = resolution
        self.spec_transform = torchaudio.transforms.Spectrogram(
            n_fft=n_fft, hop_length=hop_length, win_length=win_length, window_fn=torch.hann_window,
            normalized=True, center=False, pad_mode=None, power=None)

        self.convs = nn.ModuleList([
            norm_f(nn.Conv2d(2, 32, (3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(2,1), padding=(2, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(4,1), padding=(4, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
        ])
        self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(self, y):
        fmap = []

        x = self.spec_transform(y)  # [B, 2, Freq, Frames, 2]
        x = torch.cat([x.real, x.imag], dim=1)
        x = rearrange(x, 'b c w t -> b c t w')

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2,3,5,7,11]
        resolutions = [[2048, 512, 2048], [1024, 256, 1024], [512, 128, 512], [256, 64, 256], [128, 32, 128]]

        discs = [DiscriminatorR(resolutions[i], use_spectral_norm=use_spectral_norm) for i in range(len(resolutions))]
        discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]

        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(self,
                 spec_channels,
                 segment_size,
                 inter_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout,
                 resblock,
                 resblock_kernel_sizes,
                 resblock_dilation_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 upsample_kernel_sizes,
                 gin_channels=256,
                 prosody_size=20,
                 uncond_ratio=0.,
                 cfg=False,
                 **kwargs):

        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.mel_size = prosody_size

        self.enc_p_l = PosteriorSFEncoder(1024, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
        self.flow_l = ResidualCouplingBlock_Transformer(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)

        self.enc_p = PosteriorSFEncoder(1024, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
        self.enc_q = PosteriorAudioEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
        self.flow = ResidualCouplingBlock_Transformer(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)

        self.mel_decoder = MelDecoder(inter_channels,
                                      filter_channels,
                                      n_heads=2,
                                      n_layers=2,
                                      kernel_size=5,
                                      p_dropout=0.1,
                                      mel_size=self.mel_size,
                                      gin_channels=gin_channels)

        self.dec = Generator(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
        self.sn = SourceNetwork(upsample_initial_channel//2)
        self.emb_g = StyleEncoder(in_dim=80, hidden_dim=256, out_dim=gin_channels)

        if cfg:
            self.emb = torch.nn.Embedding(1, 256)
            torch.nn.init.normal_(self.emb.weight, 0.0, 256 ** -0.5)
            self.null = torch.LongTensor([0]).cuda()
            self.uncond_ratio = uncond_ratio
        self.cfg = cfg

    @torch.no_grad()
    def infer(self, x_mel, w2v, length, f0):

        x_mask = torch.unsqueeze(commons.sequence_mask(length, x_mel.size(2)), 1).to(x_mel.dtype)

        # Speaker embedding from mel (Style Encoder)
        g = self.emb_g(x_mel, x_mask).unsqueeze(-1)

        z, _, _ = self.enc_p_l(w2v, f0, x_mask, g=g)

        z = self.flow_l(z, x_mask, g=g, reverse=True)
        z = self.flow(z, x_mask, g=g, reverse=True)

        e, e_ = self.sn(z, g)
        o = self.dec(z, e, g=g)

        return o, e_

    @torch.no_grad()
    def voice_conversion(self, src, src_length, trg_mel, trg_length, f0, noise_scale=0.333, uncond=False):

        trg_mask = torch.unsqueeze(commons.sequence_mask(trg_length, trg_mel.size(2)), 1).to(trg_mel.dtype)
        g = self.emb_g(trg_mel, trg_mask).unsqueeze(-1)

        y_mask = torch.unsqueeze(commons.sequence_mask(src_length, src.size(2)), 1).to(trg_mel.dtype)
        z, m_p, logs_p = self.enc_p_l(src, f0, y_mask, g=g)

        z = (m_p + torch.randn_like(m_p) * torch.exp(logs_p)*noise_scale) * y_mask

        z = self.flow_l(z, y_mask, g=g, reverse=True)
        z = self.flow(z, y_mask, g=g, reverse=True)

        if uncond:
            null_emb = self.emb(self.null) * math.sqrt(256)
            g = null_emb.unsqueeze(-1)

        e, _ = self.sn(z, g)
        o = self.dec(z, e, g=g)

        return o

    @torch.no_grad()
    def voice_conversion_noise_control(self, src, src_length, trg_mel, trg_length, f0, noise_scale=0.333, uncond=False, denoise_ratio=0):

        trg_mask = torch.unsqueeze(commons.sequence_mask(trg_length, trg_mel.size(2)), 1).to(trg_mel.dtype)
        g = self.emb_g(trg_mel, trg_mask).unsqueeze(-1)

        g_org, g_denoise = g[:1, :, :], g[1:, :, :]

        g_interpolation = (1-denoise_ratio)*g_org + denoise_ratio*g_denoise

        y_mask = torch.unsqueeze(commons.sequence_mask(src_length, src.size(2)), 1).to(trg_mel.dtype)
        z, m_p, logs_p = self.enc_p_l(src, f0, y_mask, g=g_interpolation)

        z = (m_p + torch.randn_like(m_p) * torch.exp(logs_p)*noise_scale) * y_mask

        z = self.flow_l(z, y_mask, g=g_interpolation, reverse=True)
        z = self.flow(z, y_mask, g=g_interpolation, reverse=True)

        if uncond:
            null_emb = self.emb(self.null) * math.sqrt(256)
            g = null_emb.unsqueeze(-1)

        e, _ = self.sn(z, g_interpolation)
        o = self.dec(z, e, g=g_interpolation)

        return o

    @torch.no_grad()
    def f0_extraction(self, x_linear, x_mel, length, x_audio, noise_scale=0.333):

        x_mask = torch.unsqueeze(commons.sequence_mask(length, x_mel.size(2)), 1).to(x_mel.dtype)

        # Speaker embedding from mel (Style Encoder)
        g = self.emb_g(x_mel, x_mask).unsqueeze(-1)

        # posterior encoder from linear spec.
        _, m_q, logs_q = self.enc_q(x_linear, x_audio, x_mask, g=g)
        z = (m_q + torch.randn_like(m_q) * torch.exp(logs_q)*noise_scale)

        # Source Networks
        _, e_ = self.sn(z, g)

        return e_
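A side note on voice_conversion_noise_control above: denoise_ratio only interpolates between the style embedding of the raw prompt and that of its denoised copy, which the inference scripts stack along the batch dimension before calling emb_g. A minimal, self-contained sketch of that step (the 256-dim style size follows gin_channels in the config; the tensors here are random stand-ins):

    import torch

    denoise_ratio = 0.8                        # 0.0 = raw prompt only, 1.0 = denoised prompt only
    g = torch.randn(2, 256, 1)                 # stand-in for emb_g output on [raw; denoised] prompt
    g_org, g_denoise = g[:1, :, :], g[1:, :, :]
    g_interpolation = (1 - denoise_ratio) * g_org + denoise_ratio * g_denoise
    print(g_interpolation.shape)               # torch.Size([1, 256, 1])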
inference.py
ADDED
@@ -0,0 +1,220 @@
import os
import torch
import argparse
import numpy as np
from scipy.io.wavfile import write
import torchaudio
import utils
from Mels_preprocess import MelSpectrogramFixed

from hierspeechpp_speechsynthesizer import (
    SynthesizerTrn
)
from ttv_v1.text import text_to_sequence
from ttv_v1.t2w2v_transformer import SynthesizerTrn as Text2W2V
from speechsr24k.speechsr import SynthesizerTrn as AudioSR
from speechsr48k.speechsr import SynthesizerTrn as AudioSR48
from denoiser.generator import MPNet
from denoiser.infer import denoise

seed = 1111
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

def load_text(fp):
    with open(fp, 'r') as f:
        filelist = [line.strip() for line in f.readlines()]
    return filelist

def load_checkpoint(filepath, device):
    print(filepath)
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict

def get_param_num(model):
    num_param = sum(param.numel() for param in model.parameters())
    return num_param

def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

def add_blank_token(text):
    text_norm = intersperse(text, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm

def tts(text, a, hierspeech):

    net_g, text2w2v, audiosr, denoiser, mel_fn = hierspeech

    os.makedirs(a.output_dir, exist_ok=True)
    text = text_to_sequence(str(text), ["english_cleaners2"])
    token = add_blank_token(text).unsqueeze(0).cuda()
    token_length = torch.LongTensor([token.size(-1)]).cuda()

    # Prompt load
    audio, sample_rate = torchaudio.load(a.input_prompt)

    # support only single channel
    audio = audio[:1,:]
    # Resampling
    if sample_rate != 16000:
        audio = torchaudio.functional.resample(audio, sample_rate, 16000, resampling_method="kaiser_window")
    if a.scale_norm == 'prompt':
        prompt_audio_max = torch.max(audio.abs())

    # We use a hop size of 320, but the denoiser uses a hop size of 400, so we pad to a multiple of 1600 (their least common multiple)
    ori_prompt_len = audio.shape[-1]
    p = (ori_prompt_len // 1600 + 1) * 1600 - ori_prompt_len
    audio = torch.nn.functional.pad(audio, (0, p), mode='constant').data

    file_name = os.path.splitext(os.path.basename(a.input_prompt))[0]

    # If you run into memory issues while denoising the prompt, try denoising it on CPU before TTS
    # We plan to replace this with a memory-efficient denoiser
    if a.denoise_ratio == 0:
        audio = torch.cat([audio.cuda(), audio.cuda()], dim=0)
    else:
        with torch.no_grad():
            denoised_audio = denoise(audio.squeeze(0).cuda(), denoiser, hps_denoiser)
        audio = torch.cat([audio.cuda(), denoised_audio[:,:audio.shape[-1]]], dim=0)

    audio = audio[:,:ori_prompt_len]  # 20231108 We found that a large amount of padding degrades quality, so we remove the padding after denoising

    src_mel = mel_fn(audio.cuda())

    src_length = torch.LongTensor([src_mel.size(2)]).to(device)
    src_length2 = torch.cat([src_length, src_length], dim=0)

    ## TTV (Text --> W2V, F0)
    with torch.no_grad():
        w2v_x, pitch = text2w2v.infer_noise_control(token, token_length, src_mel, src_length2, noise_scale=a.noise_scale_ttv, denoise_ratio=a.denoise_ratio)

        src_length = torch.LongTensor([w2v_x.size(2)]).cuda()

        ## Pitch Clipping
        pitch[pitch<torch.log(torch.tensor([55]).cuda())] = 0

        ## Hierarchical Speech Synthesizer (W2V, F0 --> 16k Audio)
        converted_audio = \
            net_g.voice_conversion_noise_control(w2v_x, src_length, src_mel, src_length2, pitch, noise_scale=a.noise_scale_vc, denoise_ratio=a.denoise_ratio)

        ## SpeechSR (Optional) (16k Audio --> 24k or 48k Audio)
        if a.output_sr in [48000, 24000]:
            converted_audio = audiosr(converted_audio)

    converted_audio = converted_audio.squeeze()

    if a.scale_norm == 'prompt':
        converted_audio = converted_audio / (torch.abs(converted_audio).max()) * 32767.0 * prompt_audio_max
    else:
        converted_audio = converted_audio / (torch.abs(converted_audio).max()) * 32767.0 * 0.999

    converted_audio = converted_audio.cpu().numpy().astype('int16')

    file_name2 = "{}.wav".format(file_name)
    output_file = os.path.join(a.output_dir, file_name2)

    if a.output_sr == 48000:
        write(output_file, 48000, converted_audio)
    elif a.output_sr == 24000:
        write(output_file, 24000, converted_audio)
    else:
        write(output_file, 16000, converted_audio)

def model_load(a):
    mel_fn = MelSpectrogramFixed(
        sample_rate=hps.data.sampling_rate,
        n_fft=hps.data.filter_length,
        win_length=hps.data.win_length,
        hop_length=hps.data.hop_length,
        f_min=hps.data.mel_fmin,
        f_max=hps.data.mel_fmax,
        n_mels=hps.data.n_mel_channels,
        window_fn=torch.hann_window
    ).cuda()

    net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
                           hps.train.segment_size // hps.data.hop_length,
                           **hps.model).cuda()
    net_g.load_state_dict(torch.load(a.ckpt))
    _ = net_g.eval()

    text2w2v = Text2W2V(hps.data.filter_length // 2 + 1,
                        hps.train.segment_size // hps.data.hop_length,
                        **hps_t2w2v.model).cuda()
    text2w2v.load_state_dict(torch.load(a.ckpt_text2w2v))
    text2w2v.eval()

    if a.output_sr == 48000:
        audiosr = AudioSR48(h_sr48.data.n_mel_channels,
                            h_sr48.train.segment_size // h_sr48.data.hop_length,
                            **h_sr48.model).cuda()
        utils.load_checkpoint(a.ckpt_sr48, audiosr, None)
        audiosr.eval()

    elif a.output_sr == 24000:
        audiosr = AudioSR(h_sr.data.n_mel_channels,
                          h_sr.train.segment_size // h_sr.data.hop_length,
                          **h_sr.model).cuda()
        utils.load_checkpoint(a.ckpt_sr, audiosr, None)
        audiosr.eval()

    else:
        audiosr = None

    denoiser = MPNet(hps_denoiser).cuda()
    state_dict = load_checkpoint(a.denoiser_ckpt, device)
    denoiser.load_state_dict(state_dict['generator'])
    denoiser.eval()
    return net_g, text2w2v, audiosr, denoiser, mel_fn

def inference(a):

    hierspeech = model_load(a)
    # Input Text
    text = load_text(a.input_txt)
    # text = "hello I'm hierspeech"

    tts(text, a, hierspeech)

def main():
    print('Initializing Inference Process..')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_prompt', default='example/reference_4.wav')
    parser.add_argument('--input_txt', default='example/reference_4.txt')
    parser.add_argument('--output_dir', default='output')
    parser.add_argument('--ckpt', default='./logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth')
    parser.add_argument('--ckpt_text2w2v', '-ct', help='text2w2v checkpoint path', default='./logs/ttv_libritts_v1/ttv_lt960_ckpt.pth')
    parser.add_argument('--ckpt_sr', type=str, default='./speechsr24k/G_340000.pth')
    parser.add_argument('--ckpt_sr48', type=str, default='./speechsr48k/G_100000.pth')
    parser.add_argument('--denoiser_ckpt', type=str, default='denoiser/g_best')
    parser.add_argument('--scale_norm', type=str, default='max')
    parser.add_argument('--output_sr', type=float, default=48000)
    parser.add_argument('--noise_scale_ttv', type=float, default=0.333)
    parser.add_argument('--noise_scale_vc', type=float, default=0.333)
    parser.add_argument('--denoise_ratio', type=float, default=0.8)
    a = parser.parse_args()

    global device, hps, hps_t2w2v, h_sr, h_sr48, hps_denoiser
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json'))
    hps_t2w2v = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_text2w2v)[0], 'config.json'))
    h_sr = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr)[0], 'config.json'))
    h_sr48 = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr48)[0], 'config.json'))
    hps_denoiser = utils.get_hparams_from_file(os.path.join(os.path.split(a.denoiser_ckpt)[0], 'config.json'))

    inference(a)

if __name__ == '__main__':
    main()
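One detail in tts() worth spelling out: the prompt is padded to a multiple of 1600 samples because the synthesizer works on a hop of 320 while the denoiser works on a hop of 400, and 1600 is their least common multiple. A small stand-alone sketch of that padding arithmetic (the 48000-sample prompt length is just an example value):

    import torch

    hop_synth, hop_denoise = 320, 400
    assert 1600 % hop_synth == 0 and 1600 % hop_denoise == 0    # 1600 = lcm(320, 400)

    audio = torch.zeros(1, 48000)                               # example 3 s prompt at 16 kHz
    ori_prompt_len = audio.shape[-1]
    p = (ori_prompt_len // 1600 + 1) * 1600 - ori_prompt_len    # pad up to the next multiple of 1600
    audio = torch.nn.functional.pad(audio, (0, p), mode='constant')
    print(ori_prompt_len, p, audio.shape[-1])                   # 48000 1600 49600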
inference.sh
ADDED
@@ -0,0 +1,12 @@
# --ckpt "logs/hierspeechpp_libritts460/hierspeechpp_lt460_ckpt.pth" \ LibriTTS-460
# --ckpt "logs/hierspeechpp_libritts960/hierspeechpp_lt960_ckpt.pth" \ LibriTTS-960
# --ckpt "logs/hierspeechpp_eng_kor/hierspeechpp_v1_ckpt.pth" \ Large_v1 epoch 60 (paper version)
# --ckpt "logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth" \ Large_v2 epoch 110 (08. Nov. 2023)

CUDA_VISIBLE_DEVICES=0 python3 inference.py \
    --ckpt "logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth" \
    --ckpt_text2w2v "logs/ttv_libritts_v1/ttv_lt960_ckpt.pth" \
    --output_dir "tts_results_eng_kor_v2" \
    --noise_scale_vc "0.333" \
    --noise_scale_ttv "0.333" \
    --denoise_ratio "0"
inference_speechsr.py
ADDED
@@ -0,0 +1,94 @@
import os
import torch
import argparse
import numpy as np
from scipy.io.wavfile import write
import torchaudio
import utils

from speechsr24k.speechsr import SynthesizerTrn as SpeechSR24
from speechsr48k.speechsr import SynthesizerTrn as SpeechSR48

seed = 1111
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

def get_param_num(model):
    num_param = sum(param.numel() for param in model.parameters())
    return num_param

def SuperResoltuion(a, hierspeech):

    speechsr = hierspeech

    os.makedirs(a.output_dir, exist_ok=True)

    # Prompt load
    audio, sample_rate = torchaudio.load(a.input_speech)

    # support only single channel
    audio = audio[:1,:]
    # Resampling
    if sample_rate != 16000:
        audio = torchaudio.functional.resample(audio, sample_rate, 16000, resampling_method="kaiser_window")
    file_name = os.path.splitext(os.path.basename(a.input_speech))[0]
    ## SpeechSR (Optional) (16k Audio --> 24k or 48k Audio)
    with torch.no_grad():
        converted_audio = speechsr(audio.unsqueeze(1).cuda())
        converted_audio = converted_audio.squeeze()
        converted_audio = converted_audio / (torch.abs(converted_audio).max()) * 0.999 * 32767.0
        converted_audio = converted_audio.cpu().numpy().astype('int16')

    file_name2 = "{}.wav".format(file_name)
    output_file = os.path.join(a.output_dir, file_name2)

    if a.output_sr == 48000:
        write(output_file, 48000, converted_audio)
    else:
        write(output_file, 24000, converted_audio)


def model_load(a):
    if a.output_sr == 48000:
        speechsr = SpeechSR48(h_sr48.data.n_mel_channels,
                              h_sr48.train.segment_size // h_sr48.data.hop_length,
                              **h_sr48.model).cuda()
        utils.load_checkpoint(a.ckpt_sr48, speechsr, None)
        speechsr.eval()
    else:
        # 24000 Hz
        speechsr = SpeechSR24(h_sr.data.n_mel_channels,
                              h_sr.train.segment_size // h_sr.data.hop_length,
                              **h_sr.model).cuda()
        utils.load_checkpoint(a.ckpt_sr, speechsr, None)
        speechsr.eval()
    return speechsr

def inference(a):

    speechsr = model_load(a)
    SuperResoltuion(a, speechsr)

def main():
    print('Initializing Inference Process..')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_speech', default='example/reference_4.wav')
    parser.add_argument('--output_dir', default='SR_results')
    parser.add_argument('--ckpt_sr', type=str, default='./speechsr24k/G_340000.pth')
    parser.add_argument('--ckpt_sr48', type=str, default='./speechsr48k/G_100000.pth')
    parser.add_argument('--output_sr', type=float, default=48000)
    a = parser.parse_args()

    global device, h_sr, h_sr48
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    h_sr = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr)[0], 'config.json'))
    h_sr48 = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr48)[0], 'config.json'))

    inference(a)

if __name__ == '__main__':
    main()
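For programmatic use of the upsampler outside this CLI, the same calls from model_load/SuperResoltuion above can be strung together; a rough sketch, assuming the bundled speechsr48k checkpoint and config are present and a CUDA device is available:

    import os
    import torch
    import torchaudio
    import utils
    from speechsr48k.speechsr import SynthesizerTrn as SpeechSR48

    ckpt_sr48 = './speechsr48k/G_100000.pth'
    h_sr48 = utils.get_hparams_from_file(os.path.join(os.path.split(ckpt_sr48)[0], 'config.json'))

    speechsr = SpeechSR48(h_sr48.data.n_mel_channels,
                          h_sr48.train.segment_size // h_sr48.data.hop_length,
                          **h_sr48.model).cuda()
    utils.load_checkpoint(ckpt_sr48, speechsr, None)
    speechsr.eval()

    audio, sr = torchaudio.load('example/reference_2.wav')           # 16 kHz, single channel
    with torch.no_grad():
        audio48 = speechsr(audio[:1].unsqueeze(1).cuda()).squeeze()  # 16 kHz -> 48 kHz waveform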
inference_vc.py
ADDED
@@ -0,0 +1,250 @@
import os
import torch
import argparse
import numpy as np
from scipy.io.wavfile import write
import torchaudio
import utils
from Mels_preprocess import MelSpectrogramFixed
from torch.nn import functional as F
from hierspeechpp_speechsynthesizer import (
    SynthesizerTrn, Wav2vec2
)
from ttv_v1.text import text_to_sequence
from ttv_v1.t2w2v_transformer import SynthesizerTrn as Text2W2V
from speechsr24k.speechsr import SynthesizerTrn as SpeechSR24
from speechsr48k.speechsr import SynthesizerTrn as SpeechSR48
from denoiser.generator import MPNet
from denoiser.infer import denoise

import amfm_decompy.basic_tools as basic
import amfm_decompy.pYAAPT as pYAAPT

seed = 1111
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

def get_yaapt_f0(audio, rate=16000, interp=False):
    frame_length = 20.0
    to_pad = int(frame_length / 1000 * rate) // 2

    f0s = []
    for y in audio.astype(np.float64):
        y_pad = np.pad(y.squeeze(), (to_pad, to_pad), "constant", constant_values=0)
        signal = basic.SignalObj(y_pad, rate)
        pitch = pYAAPT.yaapt(signal, **{'frame_length': frame_length, 'frame_space': 5.0, 'nccf_thresh1': 0.25,
                                        'tda_frame_length': 25.0, 'f0_max': 1100})
        if interp:
            f0s += [pitch.samp_interp[None, None, :]]
        else:
            f0s += [pitch.samp_values[None, None, :]]
    f0 = np.vstack(f0s)
    return f0

def load_text(fp):
    with open(fp, 'r') as f:
        filelist = [line.strip() for line in f.readlines()]
    return filelist

def load_checkpoint(filepath, device):
    print(filepath)
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict

def get_param_num(model):
    num_param = sum(param.numel() for param in model.parameters())
    return num_param

def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

def add_blank_token(text):
    text_norm = intersperse(text, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm

def VC(a, hierspeech):

    net_g, speechsr, denoiser, mel_fn, w2v = hierspeech

    os.makedirs(a.output_dir, exist_ok=True)

    source_audio, sample_rate = torchaudio.load(a.source_speech)
    if sample_rate != 16000:
        source_audio = torchaudio.functional.resample(source_audio, sample_rate, 16000, resampling_method="kaiser_window")
    p = (source_audio.shape[-1] // 1280 + 1) * 1280 - source_audio.shape[-1]
    source_audio = torch.nn.functional.pad(source_audio, (0, p), mode='constant').data
    file_name_s = os.path.splitext(os.path.basename(a.source_speech))[0]

    try:
        f0 = get_yaapt_f0(source_audio.numpy())
    except:
        f0 = np.zeros((1, 1, source_audio.shape[-1] // 80))
    f0 = f0.astype(np.float32)
    f0 = f0.squeeze(0)

    ii = f0 != 0
    f0[ii] = (f0[ii] - f0[ii].mean()) / f0[ii].std()

    y_pad = F.pad(source_audio, (40, 40), "reflect")
    x_w2v = w2v(y_pad.cuda())
    x_length = torch.LongTensor([x_w2v.size(2)]).to(device)

    # Prompt load
    target_audio, sample_rate = torchaudio.load(a.target_speech)
    # support only single channel
    target_audio = target_audio[:1,:]
    # Resampling
    if sample_rate != 16000:
        target_audio = torchaudio.functional.resample(target_audio, sample_rate, 16000, resampling_method="kaiser_window")
    if a.scale_norm == 'prompt':
        prompt_audio_max = torch.max(target_audio.abs())
    try:
        t_f0 = get_yaapt_f0(target_audio.numpy())
    except:
        t_f0 = np.zeros((1, 1, target_audio.shape[-1] // 80))
    t_f0 = t_f0.astype(np.float32)
    t_f0 = t_f0.squeeze(0)
    j = t_f0 != 0

    f0[ii] = ((f0[ii] * t_f0[j].std()) + t_f0[j].mean()).clip(min=0)
    denorm_f0 = torch.log(torch.FloatTensor(f0+1).cuda())
    # We use a hop size of 320, but the denoiser uses a hop size of 400, so we pad to a multiple of 1600 (their least common multiple)
    ori_prompt_len = target_audio.shape[-1]
    p = (ori_prompt_len // 1600 + 1) * 1600 - ori_prompt_len
    target_audio = torch.nn.functional.pad(target_audio, (0, p), mode='constant').data

    file_name_t = os.path.splitext(os.path.basename(a.target_speech))[0]

    # If you run into memory issues while denoising the prompt, try denoising it on CPU before TTS
    # We plan to replace this with a memory-efficient denoiser
    if a.denoise_ratio == 0:
        target_audio = torch.cat([target_audio.cuda(), target_audio.cuda()], dim=0)
    else:
        with torch.no_grad():
            denoised_audio = denoise(target_audio.squeeze(0).cuda(), denoiser, hps_denoiser)
        target_audio = torch.cat([target_audio.cuda(), denoised_audio[:,:target_audio.shape[-1]]], dim=0)

    target_audio = target_audio[:,:ori_prompt_len]  # 20231108 We found that a large amount of padding degrades quality, so we remove the padding after denoising

    trg_mel = mel_fn(target_audio.cuda())

    trg_length = torch.LongTensor([trg_mel.size(2)]).to(device)
    trg_length2 = torch.cat([trg_length, trg_length], dim=0)

    with torch.no_grad():

        ## Hierarchical Speech Synthesizer (W2V, F0 --> 16k Audio)
        converted_audio = \
            net_g.voice_conversion_noise_control(x_w2v, x_length, trg_mel, trg_length2, denorm_f0, noise_scale=a.noise_scale_vc, denoise_ratio=a.denoise_ratio)

        ## SpeechSR (Optional) (16k Audio --> 24k or 48k Audio)
        if a.output_sr in [48000, 24000]:
            converted_audio = speechsr(converted_audio)

    converted_audio = converted_audio.squeeze()

    if a.scale_norm == 'prompt':
        converted_audio = converted_audio / (torch.abs(converted_audio).max()) * 32767.0 * prompt_audio_max
    else:
        converted_audio = converted_audio / (torch.abs(converted_audio).max()) * 32767.0 * 0.999

    converted_audio = converted_audio.cpu().numpy().astype('int16')

    file_name2 = "{}.wav".format(file_name_s+"_to_"+file_name_t)
    output_file = os.path.join(a.output_dir, file_name2)

    if a.output_sr == 48000:
        write(output_file, 48000, converted_audio)
    elif a.output_sr == 24000:
        write(output_file, 24000, converted_audio)
    else:
        write(output_file, 16000, converted_audio)

def model_load(a):
    mel_fn = MelSpectrogramFixed(
        sample_rate=hps.data.sampling_rate,
        n_fft=hps.data.filter_length,
        win_length=hps.data.win_length,
        hop_length=hps.data.hop_length,
        f_min=hps.data.mel_fmin,
        f_max=hps.data.mel_fmax,
        n_mels=hps.data.n_mel_channels,
        window_fn=torch.hann_window
    ).cuda()
    w2v = Wav2vec2().cuda()

    net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
                           hps.train.segment_size // hps.data.hop_length,
                           **hps.model).cuda()
    net_g.load_state_dict(torch.load(a.ckpt))
    _ = net_g.eval()

    if a.output_sr == 48000:
        speechsr = SpeechSR48(h_sr48.data.n_mel_channels,
                              h_sr48.train.segment_size // h_sr48.data.hop_length,
                              **h_sr48.model).cuda()
        utils.load_checkpoint(a.ckpt_sr48, speechsr, None)
        speechsr.eval()

    elif a.output_sr == 24000:
        speechsr = SpeechSR24(h_sr.data.n_mel_channels,
                              h_sr.train.segment_size // h_sr.data.hop_length,
                              **h_sr.model).cuda()
        utils.load_checkpoint(a.ckpt_sr, speechsr, None)
        speechsr.eval()

    else:
        speechsr = None

    denoiser = MPNet(hps_denoiser).cuda()
    state_dict = load_checkpoint(a.denoiser_ckpt, device)
    denoiser.load_state_dict(state_dict['generator'])
    denoiser.eval()
    return net_g, speechsr, denoiser, mel_fn, w2v

def inference(a):

    hierspeech = model_load(a)

    VC(a, hierspeech)

def main():
    print('Initializing Inference Process..')

    parser = argparse.ArgumentParser()
    parser.add_argument('--source_speech', default='example/reference_2.wav')
    parser.add_argument('--target_speech', default='example/reference_1.wav')
    parser.add_argument('--output_dir', default='output')
    parser.add_argument('--ckpt', default='./logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth')
    parser.add_argument('--ckpt_sr', type=str, default='./speechsr24k/G_340000.pth')
    parser.add_argument('--ckpt_sr48', type=str, default='./speechsr48k/G_100000.pth')
    parser.add_argument('--denoiser_ckpt', type=str, default='denoiser/g_best')
    parser.add_argument('--scale_norm', type=str, default='max')
    parser.add_argument('--output_sr', type=float, default=48000)
    parser.add_argument('--noise_scale_ttv', type=float, default=0.333)
    parser.add_argument('--noise_scale_vc', type=float, default=0.333)
    parser.add_argument('--denoise_ratio', type=float, default=0.8)
    a = parser.parse_args()

    global device, hps, h_sr, h_sr48, hps_denoiser
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json'))
    h_sr = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr)[0], 'config.json'))
    h_sr48 = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr48)[0], 'config.json'))
    hps_denoiser = utils.get_hparams_from_file(os.path.join(os.path.split(a.denoiser_ckpt)[0], 'config.json'))

    inference(a)

if __name__ == '__main__':
    main()
inference_vc.sh
ADDED
@@ -0,0 +1,11 @@
# --ckpt "logs/hierspeechpp_libritts460/hierspeechpp_lt460_ckpt.pth" \ LibriTTS-460
# --ckpt "logs/hierspeechpp_libritts960/hierspeechpp_lt960_ckpt.pth" \ LibriTTS-960
# --ckpt "logs/hierspeechpp_eng_kor/hierspeechpp_v1_ckpt.pth" \ Large_v1 epoch 60 (paper version)
# --ckpt "logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth" \ Large_v2 epoch 110 (08. Nov. 2023)

CUDA_VISIBLE_DEVICES=0 python3 inference_vc.py \
    --ckpt "logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth" \
    --output_dir "vc_results_eng_kor_v2" \
    --noise_scale_vc "0.333" \
    --noise_scale_ttv "0.333" \
    --denoise_ratio "0"
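The F0 handling in inference_vc.py above amounts to a statistics transfer: the voiced frames of the source contour are z-normalized, rescaled to the target speaker's mean and standard deviation, clipped at zero, and finally passed to the synthesizer as log(f0 + 1). A compact numpy sketch of just that transfer, with made-up contours standing in for the YAAPT output:

    import numpy as np

    f0 = np.array([[0., 120., 130., 125., 0.]], dtype=np.float32)      # source contour (0 = unvoiced)
    t_f0 = np.array([[0., 200., 210., 190., 205.]], dtype=np.float32)  # target-speaker contour

    ii = f0 != 0
    j = t_f0 != 0
    f0[ii] = (f0[ii] - f0[ii].mean()) / f0[ii].std()                   # normalize voiced source frames
    f0[ii] = ((f0[ii] * t_f0[j].std()) + t_f0[j].mean()).clip(min=0)   # rescale to target statistics

    log_f0 = np.log(f0 + 1)                                            # what the synthesizer receives
    print(np.round(f0, 1))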
logs/hierspeechpp_eng_kor/config.json
ADDED
@@ -0,0 +1,67 @@
{
  "train": {
    "log_interval": 200,
    "eval_interval": 10000,
    "save_interval": 10000,
    "seed": 1234,
    "epochs": 20000,
    "learning_rate": 1e-4,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 20,
    "fp16_run": false,
    "lr_decay": 0.999,
    "segment_size": 61440,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45,
    "c_kl": 1.0,
    "c_bi_kl": 0.5,
    "c_mixup": 1,
    "c_pho": 45.0,
    "c_f0": 45,
    "aug": true
  },
  "data": {
    "train_filelist_path": "filelists/train_eng_kor_wav.txt",
    "test_filelist_path": "filelists/test_hier_16k_72_wav.txt",
    "text_cleaners": ["english_cleaners2"],
    "max_wav_value": 32768.0,
    "sampling_rate": 16000,
    "filter_length": 1280,
    "hop_length": 320,
    "win_length": 1280,
    "n_mel_channels": 80,
    "mel_fmin": 0,
    "mel_fmax": 8000,
    "add_blank": true,
    "n_speakers": 0,
    "cleaned_text": true,
    "aug_rate": 1.0,
    "top_db": 20
  },
  "model": {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "upsample_rates": [4,5,4,2,2],
    "upsample_initial_channel": 512,
    "upsample_kernel_sizes": [8,11,8,4,4],
    "n_layers_q": 3,
    "use_spectral_norm": false,
    "wav2vec_feature_layer": 7,
    "gin_channels": 256,
    "cfg": true,
    "uncond_ratio": 0.1,
    "prosody_size": 20
  }
}
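This config.json lives next to the checkpoints in the same folder, which is how the inference scripts locate it: they split the checkpoint path and read the sibling config through the repository's utils.get_hparams_from_file helper, then access its fields as attributes. A minimal sketch of that lookup:

    import os
    import utils

    ckpt = './logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth'
    config_path = os.path.join(os.path.split(ckpt)[0], 'config.json')  # ./logs/hierspeechpp_eng_kor/config.json
    hps = utils.get_hparams_from_file(config_path)

    print(hps.data.sampling_rate, hps.data.hop_length)                 # 16000 320
    print(hps.model.gin_channels, hps.model.prosody_size)              # 256 20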
logs/hierspeechpp_eng_kor/hierspeechpp_v1_ckpt.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0229501cd273135bce9ee0be09ddef18e079cb8dbe61cda665fec73afbfad0d4
size 388970075
logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54d2aff8f92ff66669ef8a48ee2b604b52190fa2c421e8a20c39596cdebf9c43
size 388970075
logs/ttv_libritts_v1/config.json
ADDED
@@ -0,0 +1,55 @@
{
  "train": {
    "log_interval": 200,
    "eval_interval": 10000,
    "save_interval": 10000,
    "seed": 1234,
    "epochs": 20000,
    "learning_rate": 2e-4,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 32,
    "fp16_run": false,
    "lr_decay": 0.999,
    "segment_size": 192000,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 10,
    "c_kl": 0.1,
    "c_pho": 45.0,
    "c_f0": 1,
    "aug": true
  },
  "data": {
    "train_filelist_path": "filelists/train_wav.txt",
    "test_filelist_path": "filelists/test_wav.txt",
    "text_cleaners": ["english_cleaners2"],
    "max_wav_value": 32768.0,
    "sampling_rate": 16000,
    "filter_length": 1280,
    "hop_length": 320,
    "win_length": 1280,
    "n_mel_channels": 80,
    "mel_fmin": 0,
    "mel_fmax": 8000,
    "add_blank": true,
    "n_speakers": 0,
    "cleaned_text": true,
    "aug_rate": 1.0,
    "top_db": 20
  },
  "model": {
    "inter_channels": 256,
    "hidden_channels": 256,
    "filter_channels": 1024,
    "n_heads": 4,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "use_spectral_norm": false
  }
}
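A quick sanity check on the two configs above: with a 16 kHz sampling rate and a hop of 320, both models run at 50 frames per second, so the synthesizer's 61440-sample training segment corresponds to 192 frames (3.84 s) and the TTV's 192000-sample segment to 600 frames (12 s):

    sampling_rate, hop_length = 16000, 320
    frames_per_second = sampling_rate / hop_length             # 50.0

    for name, segment_size in [('synthesizer', 61440), ('ttv', 192000)]:
        frames = segment_size // hop_length
        seconds = segment_size / sampling_rate
        print(name, frames, seconds)                            # 192 frames / 3.84 s, then 600 frames / 12.0 s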
logs/ttv_libritts_v1/ttv_lt960_ckpt.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2e408e2c6a72551da0ca616f43e0f989449ff5e0d0ca6229b27dea3dc36aa6c1
size 434482782
modules.py
ADDED
@@ -0,0 +1,534 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
from torch.nn import Conv1d
|
7 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
8 |
+
|
9 |
+
import commons
|
10 |
+
from commons import init_weights, get_padding
|
11 |
+
from transforms import piecewise_rational_quadratic_transform
|
12 |
+
|
13 |
+
from timm.models.vision_transformer import Attention
|
14 |
+
from itertools import repeat
|
15 |
+
import collections.abc
|
16 |
+
|
17 |
+
LRELU_SLOPE = 0.1
|
18 |
+
|
19 |
+
class LayerNorm(nn.Module):
|
20 |
+
def __init__(self, channels, eps=1e-5):
|
21 |
+
super().__init__()
|
22 |
+
self.channels = channels
|
23 |
+
self.eps = eps
|
24 |
+
|
25 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
26 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
x = x.transpose(1, -1)
|
30 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
31 |
+
return x.transpose(1, -1)
|
32 |
+
|
33 |
+
|
34 |
+
class ConvReluNorm(nn.Module):
|
35 |
+
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
|
36 |
+
super().__init__()
|
37 |
+
self.in_channels = in_channels
|
38 |
+
self.hidden_channels = hidden_channels
|
39 |
+
self.out_channels = out_channels
|
40 |
+
self.kernel_size = kernel_size
|
41 |
+
self.n_layers = n_layers
|
42 |
+
self.p_dropout = p_dropout
|
43 |
+
assert n_layers > 1, "Number of layers should be larger than 0."
|
44 |
+
|
45 |
+
self.conv_layers = nn.ModuleList()
|
46 |
+
self.norm_layers = nn.ModuleList()
|
47 |
+
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
48 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
49 |
+
self.relu_drop = nn.Sequential(
|
50 |
+
nn.ReLU(),
|
51 |
+
nn.Dropout(p_dropout))
|
52 |
+
for _ in range(n_layers - 1):
|
53 |
+
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
54 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
55 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
56 |
+
self.proj.weight.data.zero_()
|
57 |
+
self.proj.bias.data.zero_()
|
58 |
+
|
59 |
+
def forward(self, x, x_mask):
|
60 |
+
x_org = x
|
61 |
+
for i in range(self.n_layers):
|
62 |
+
x = self.conv_layers[i](x * x_mask)
|
63 |
+
x = self.norm_layers[i](x)
|
64 |
+
x = self.relu_drop(x)
|
65 |
+
x = x_org + self.proj(x)
|
66 |
+
return x * x_mask
|
67 |
+
|
68 |
+
|
69 |
+
class DDSConv(nn.Module):
|
70 |
+
"""
|
71 |
+
Dialted and Depth-Separable Convolution
|
72 |
+
"""
|
73 |
+
|
74 |
+
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
|
75 |
+
super().__init__()
|
76 |
+
self.channels = channels
|
77 |
+
self.kernel_size = kernel_size
|
78 |
+
self.n_layers = n_layers
|
79 |
+
self.p_dropout = p_dropout
|
80 |
+
|
81 |
+
self.drop = nn.Dropout(p_dropout)
|
82 |
+
self.convs_sep = nn.ModuleList()
|
83 |
+
self.convs_1x1 = nn.ModuleList()
|
84 |
+
self.norms_1 = nn.ModuleList()
|
85 |
+
self.norms_2 = nn.ModuleList()
|
86 |
+
for i in range(n_layers):
|
87 |
+
dilation = kernel_size ** i
|
88 |
+
padding = (kernel_size * dilation - dilation) // 2
|
89 |
+
self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
|
90 |
+
groups=channels, dilation=dilation, padding=padding
|
91 |
+
))
|
92 |
+
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
93 |
+
self.norms_1.append(LayerNorm(channels))
|
94 |
+
self.norms_2.append(LayerNorm(channels))
|
95 |
+
|
96 |
+
def forward(self, x, x_mask, g=None):
|
97 |
+
if g is not None:
|
98 |
+
x = x + g
|
99 |
+
for i in range(self.n_layers):
|
100 |
+
y = self.convs_sep[i](x * x_mask)
|
101 |
+
y = self.norms_1[i](y)
|
102 |
+
y = F.gelu(y)
|
103 |
+
y = self.convs_1x1[i](y)
|
104 |
+
y = self.norms_2[i](y)
|
105 |
+
y = F.gelu(y)
|
106 |
+
y = self.drop(y)
|
107 |
+
x = x + y
|
108 |
+
return x * x_mask
|
109 |
+
|
110 |
+
|
111 |
+
class WN(torch.nn.Module):
|
112 |
+
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
|
113 |
+
super(WN, self).__init__()
|
114 |
+
assert (kernel_size % 2 == 1)
|
115 |
+
self.hidden_channels = hidden_channels
|
116 |
+
self.kernel_size = kernel_size,
|
117 |
+
self.dilation_rate = dilation_rate
|
118 |
+
self.n_layers = n_layers
|
119 |
+
self.gin_channels = gin_channels
|
120 |
+
self.p_dropout = p_dropout
|
121 |
+
|
122 |
+
self.in_layers = torch.nn.ModuleList()
|
123 |
+
self.res_skip_layers = torch.nn.ModuleList()
|
124 |
+
self.drop = nn.Dropout(p_dropout)
|
125 |
+
|
126 |
+
if gin_channels != 0:
|
127 |
+
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
|
128 |
+
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
|
129 |
+
|
130 |
+
for i in range(n_layers):
|
131 |
+
dilation = dilation_rate ** i
|
132 |
+
padding = int((kernel_size * dilation - dilation) / 2)
|
133 |
+
in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size,
|
134 |
+
dilation=dilation, padding=padding)
|
135 |
+
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
|
136 |
+
self.in_layers.append(in_layer)
|
137 |
+
|
138 |
+
# last one is not necessary
|
139 |
+
if i < n_layers - 1:
|
140 |
+
res_skip_channels = 2 * hidden_channels
|
141 |
+
else:
|
142 |
+
res_skip_channels = hidden_channels
|
143 |
+
|
144 |
+
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
145 |
+
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
|
146 |
+
self.res_skip_layers.append(res_skip_layer)
|
147 |
+
|
148 |
+
def forward(self, x, x_mask, g=None, **kwargs):
|
149 |
+
output = torch.zeros_like(x)
|
150 |
+
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
151 |
+
|
152 |
+
if g is not None:
|
153 |
+
g = self.cond_layer(g)
|
154 |
+
|
155 |
+
for i in range(self.n_layers):
|
156 |
+
x_in = self.in_layers[i](x)
|
157 |
+
if g is not None:
|
158 |
+
cond_offset = i * 2 * self.hidden_channels
|
159 |
+
g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
|
160 |
+
else:
|
161 |
+
g_l = torch.zeros_like(x_in)
|
162 |
+
|
163 |
+
acts = commons.fused_add_tanh_sigmoid_multiply(
|
164 |
+
x_in,
|
165 |
+
g_l,
|
166 |
+
n_channels_tensor)
|
167 |
+
acts = self.drop(acts)
|
168 |
+
|
169 |
+
res_skip_acts = self.res_skip_layers[i](acts)
|
170 |
+
if i < self.n_layers - 1:
|
171 |
+
res_acts = res_skip_acts[:, :self.hidden_channels, :]
|
172 |
+
x = (x + res_acts) * x_mask
|
173 |
+
output = output + res_skip_acts[:, self.hidden_channels:, :]
|
174 |
+
else:
|
175 |
+
output = output + res_skip_acts
|
176 |
+
return output * x_mask
|
177 |
+
|
178 |
+
def remove_weight_norm(self):
|
179 |
+
if self.gin_channels != 0:
|
180 |
+
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
181 |
+
for l in self.in_layers:
|
182 |
+
torch.nn.utils.remove_weight_norm(l)
|
183 |
+
for l in self.res_skip_layers:
|
184 |
+
torch.nn.utils.remove_weight_norm(l)
|
185 |
+
|
186 |
+
|
187 |
+
class ResBlock1(torch.nn.Module):
|
188 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
189 |
+
super(ResBlock1, self).__init__()
|
190 |
+
self.convs1 = nn.ModuleList([
|
191 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
192 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
193 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
194 |
+
padding=get_padding(kernel_size, dilation[1]))),
|
195 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
196 |
+
padding=get_padding(kernel_size, dilation[2])))
|
197 |
+
])
|
198 |
+
self.convs1.apply(init_weights)
|
199 |
+
|
200 |
+
self.convs2 = nn.ModuleList([
|
201 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
202 |
+
padding=get_padding(kernel_size, 1))),
|
203 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
204 |
+
padding=get_padding(kernel_size, 1))),
|
205 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
206 |
+
padding=get_padding(kernel_size, 1)))
|
207 |
+
])
|
208 |
+
self.convs2.apply(init_weights)
|
209 |
+
|
210 |
+
def forward(self, x, x_mask=None):
|
211 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
212 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
213 |
+
if x_mask is not None:
|
214 |
+
xt = xt * x_mask
|
215 |
+
xt = c1(xt)
|
216 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
217 |
+
if x_mask is not None:
|
218 |
+
xt = xt * x_mask
|
219 |
+
xt = c2(xt)
|
220 |
+
x = xt + x
|
221 |
+
if x_mask is not None:
|
222 |
+
x = x * x_mask
|
223 |
+
return x
|
224 |
+
|
225 |
+
def remove_weight_norm(self):
|
226 |
+
for l in self.convs1:
|
227 |
+
remove_weight_norm(l)
|
228 |
+
for l in self.convs2:
|
229 |
+
remove_weight_norm(l)
|
230 |
+
|
231 |
+
|
232 |
+
class ResBlock2(torch.nn.Module):
|
233 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
234 |
+
super(ResBlock2, self).__init__()
|
235 |
+
self.convs = nn.ModuleList([
|
236 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
237 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
238 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
239 |
+
padding=get_padding(kernel_size, dilation[1])))
|
240 |
+
])
|
241 |
+
self.convs.apply(init_weights)
|
242 |
+
|
243 |
+
def forward(self, x, x_mask=None):
|
244 |
+
for c in self.convs:
|
245 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
246 |
+
if x_mask is not None:
|
247 |
+
xt = xt * x_mask
|
248 |
+
xt = c(xt)
|
249 |
+
x = xt + x
|
250 |
+
if x_mask is not None:
|
251 |
+
x = x * x_mask
|
252 |
+
return x
|
253 |
+
|
254 |
+
def remove_weight_norm(self):
|
255 |
+
for l in self.convs:
|
256 |
+
remove_weight_norm(l)
|
257 |
+
|
258 |
+
|
259 |
+
class Log(nn.Module):
|
260 |
+
def forward(self, x, x_mask, reverse=False, **kwargs):
|
261 |
+
if not reverse:
|
262 |
+
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
|
263 |
+
logdet = torch.sum(-y, [1, 2])
|
264 |
+
return y, logdet
|
265 |
+
else:
|
266 |
+
x = torch.exp(x) * x_mask
|
267 |
+
return x
|
268 |
+
|
269 |
+
|
270 |
+
class Flip(nn.Module):
|
271 |
+
def forward(self, x, *args, reverse=False, **kwargs):
|
272 |
+
x = torch.flip(x, [1])
|
273 |
+
if not reverse:
|
274 |
+
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
|
275 |
+
return x, logdet
|
276 |
+
else:
|
277 |
+
return x
|
278 |
+
|
279 |
+
|
280 |
+
class ElementwiseAffine(nn.Module):
|
281 |
+
def __init__(self, channels):
|
282 |
+
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = self.m + torch.exp(self.logs) * x
            y = y * x_mask
            logdet = torch.sum(self.logs * x_mask, [1, 2])
            return y, logdet
        else:
            x = (x - self.m) * torch.exp(-self.logs) * x_mask
            return x


class ResidualCouplingLayer(nn.Module):
    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 p_dropout=0,
                 gin_channels=0,
                 mean_only=False):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout,
                      gin_channels=gin_channels)
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


to_2tuple = _ntuple(2)


class FFN_Conv(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            kernel=5,
            p_dropout=0.1
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)

        self.fc1 = nn.Conv1d(in_features, hidden_features, kernel_size=kernel, stride=1, padding=(kernel-1)//2, bias=bias[0])
        self.act = act_layer()
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Conv1d(hidden_features, out_features, kernel_size=1, bias=bias[1])
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.fc1(x.transpose(1, 2))
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x * x_mask) * x_mask
        x = self.drop(x)
        return x.transpose(1, 2)


class DiTConVBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, kernel=9, p_dropout=0.1, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = FFN_Conv(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, kernel=kernel, p_dropout=p_dropout)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, x_mask):
        x = x * x_mask
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x) * x_mask, shift_msa, scale_msa)) * x_mask
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), x_mask.transpose(1, 2))
        return x


class ResidualCouplingLayer_Transformer_simple(nn.Module):
    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 p_dropout=0.1,
                 mean_only=False):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)

        self.enc_block = torch.nn.ModuleList([
            DiTConVBlock(hidden_channels, 2, mlp_ratio=4.0, kernel=5, p_dropout=p_dropout) for _ in range(n_layers)
        ])

        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)

        self.initialize_weights()

        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, (nn.Conv1d, nn.Linear)):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)
        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.enc_block:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask

        # h = self.enc(h, x_mask, g=g)
        h = h.transpose(1, 2)
        x_mask = x_mask.transpose(1, 2)

        for blk in self.enc_block:
            h = blk(h, g, x_mask)

        x_mask = x_mask.transpose(1, 2)
        h = h.transpose(1, 2)

        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x


class ConvFlow(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
        self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0)
        h = self.convs(h, x_mask, g=g)
        h = self.proj(h) * x_mask

        b, c, t = x0.shape
        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]

        unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_derivatives = h[..., 2 * self.num_bins:]

        x1, logabsdet = piecewise_rational_quadratic_transform(x1,
                                                               unnormalized_widths,
                                                               unnormalized_heights,
                                                               unnormalized_derivatives,
                                                               inverse=reverse,
                                                               tails='linear',
                                                               tail_bound=self.tail_bound
                                                               )

        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if not reverse:
            return x, logdet
        else:
            return x
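For reference, a minimal round-trip sketch (not part of the uploaded files; it assumes modules.py and its WN dependency are importable from the repo root): the affine coupling layer is invertible, so a forward pass followed by a reverse pass should recover the input.

import torch

layer = ResidualCouplingLayer(channels=192, hidden_channels=192, kernel_size=5,
                              dilation_rate=1, n_layers=4, mean_only=True)
x = torch.randn(2, 192, 50)                   # [batch, channels, time]
x_mask = torch.ones(2, 1, 50)                 # all frames valid

z, logdet = layer(x, x_mask)                  # forward: transformed tensor + log-determinant
x_rec = layer(z, x_mask, reverse=True)        # reverse: undoes the transform
print(torch.allclose(x, x_rec, atol=1e-5))    # expected: True (up to numerical error)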
requirements.txt
ADDED
@@ -0,0 +1,15 @@
AMFM_decompy==1.0.11
Cython==3.0.3
einops==0.7.0
joblib==1.3.2
matplotlib==3.8.1
numpy==1.26.1
pesq==0.0.4
phonemizer==3.2.1
scipy==1.11.3
timm==0.6.13
torch==1.13.1+cu117
torchaudio==0.13.1+cu117
tqdm==4.65.0
transformers==4.34.0
Unidecode==1.3.7
results/reference_1.wav
ADDED
Binary file (528 kB)

results/reference_2.wav
ADDED
Binary file (505 kB)

results/reference_3.wav
ADDED
Binary file (378 kB)

results/reference_4.wav
ADDED
Binary file (699 kB)
speechsr24k/G_340000.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:94bc84b6f8bdb375e2027bc7a3222730d29f9d4c042fd9edce0024806d5c4320
size 1715101
speechsr24k/config.json
ADDED
@@ -0,0 +1,49 @@
{
  "train": {
    "log_interval": 200,
    "eval_interval": 10000,
    "save_interval": 10000,
    "seed": 1234,
    "epochs": 20000,
    "learning_rate": 1e-4,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 32,
    "fp16_run": false,
    "lr_decay": 0.999,
    "segment_size": 9600,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45
  },
  "data": {
    "train_filelist_path": "filelists/train_24k_bigvgan_sr.txt",
    "test_filelist_path": "filelists/test_24k_bigvgan_sr.txt",
    "text_cleaners": ["english_cleaners2"],
    "max_wav_value": 32768.0,
    "sampling_rate": 24000,
    "filter_length": 960,
    "hop_length": 240,
    "win_length": 960,
    "n_mel_channels": 100,
    "mel_fmin": 0,
    "mel_fmax": 12000,
    "add_blank": true,
    "n_speakers": 0,
    "cleaned_text": true,
    "aug_rate": 1.0,
    "top_db": 20
  },
  "model": {
    "resblock": "0",
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "upsample_rates": [3],
    "upsample_initial_channel": 32,
    "upsample_kernel_sizes": [3],
    "use_spectral_norm": false
  }
}
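For orientation, a quick check (illustrative sketch, not part of the upload): these data settings put the 24 kHz model at a 10 ms hop, i.e. 100 spectrogram frames per second.

import json

with open("speechsr24k/config.json") as f:      # path as used in this repo
    cfg = json.load(f)

hop_seconds = cfg["data"]["hop_length"] / cfg["data"]["sampling_rate"]
print(hop_seconds, 1 / hop_seconds)             # 0.01 s per frame -> 100 frames per second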
speechsr24k/speechsr.py
ADDED
@@ -0,0 +1,253 @@
import torch
from torch import nn
from torch.nn import functional as F
import modules

from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from commons import init_weights, get_padding
from torch.cuda.amp import autocast
import torchaudio
from einops import rearrange

from alias_free_torch import *
import activations


class AMPBlock0(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
        super(AMPBlock0, self).__init__()

        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2]))),
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
        ])
        self.convs2.apply(init_weights)

        self.num_layers = len(self.convs1) + len(self.convs2)  # total number of conv layers

        self.activations = nn.ModuleList([
            Activation1d(
                activation=activations.SnakeBeta(channels, alpha_logscale=True))
            for _ in range(self.num_layers)
        ])

    def forward(self, x):
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3))
        resblock = AMPBlock0

        self.resblocks = nn.ModuleList()
        for i in range(1):
            ch = upsample_initial_channel // (2 ** (i))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d, activation="snakebeta"))

        activation_post = activations.SnakeBeta(ch, alpha_logscale=True)
        self.activation_post = Activation1d(activation=activation_post)

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):

            x = F.interpolate(x, int(x.shape[-1] * 1.5), mode='linear')
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorR(torch.nn.Module):
    def __init__(self, resolution, use_spectral_norm=False):
        super(DiscriminatorR, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        n_fft, hop_length, win_length = resolution
        self.spec_transform = torchaudio.transforms.Spectrogram(
            n_fft=n_fft, hop_length=hop_length, win_length=win_length, window_fn=torch.hann_window,
            normalized=True, center=False, pad_mode=None, power=None)

        self.convs = nn.ModuleList([
            norm_f(nn.Conv2d(2, 32, (3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(2, 1), padding=(2, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(4, 1), padding=(4, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
        ])
        self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(self, y):
        fmap = []

        x = self.spec_transform(y)  # [B, 2, Freq, Frames, 2]
        x = torch.cat([x.real, x.imag], dim=1)
        x = rearrange(x, 'b c w t -> b c t w')

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]
        resolutions = [[2048, 512, 2048], [1024, 256, 1024], [512, 128, 512], [256, 64, 256], [128, 32, 128]]

        discs = [DiscriminatorR(resolutions[i], use_spectral_norm=use_spectral_norm) for i in range(len(resolutions))]
        discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(self,
                 spec_channels,
                 segment_size,
                 resblock,
                 resblock_kernel_sizes,
                 resblock_dilation_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 upsample_kernel_sizes,
                 **kwargs):

        super().__init__()
        self.spec_channels = spec_channels
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size

        self.dec = Generator(1, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes)

    def forward(self, x):

        y = self.dec(x)
        return y

    @torch.no_grad()
    def infer(self, x, max_len=None):

        o = self.dec(x[:, :, :max_len])
        return o
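A hedged usage sketch (not part of the upload; the checkpoint key layout and import paths are assumptions, and the repo's own inference scripts remain the authoritative entry point): SynthesizerTrn is a thin wrapper around the BigVGAN-style Generator, so super-resolving a 16 kHz waveform to 24 kHz is a single forward pass.

import json
import torch
from speechsr24k.speechsr import SynthesizerTrn   # assumes the repo root is on PYTHONPATH

with open("speechsr24k/config.json") as f:
    h = json.load(f)

model = SynthesizerTrn(
    h["data"]["filter_length"] // 2 + 1,
    h["train"]["segment_size"] // h["data"]["hop_length"],
    **h["model"])
state = torch.load("speechsr24k/G_340000.pth", map_location="cpu")
model.load_state_dict(state.get("model", state))   # assumption about the checkpoint layout
model.eval()

audio_16k = torch.randn(1, 1, 16000)               # one second of placeholder 16 kHz audio
with torch.no_grad():
    audio_24k = model.infer(audio_16k)             # x1.5 interpolation inside the generator
print(audio_24k.shape)                             # expected: torch.Size([1, 1, 24000])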
speechsr48k/G_100000.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:62c70874ac4efeb4dc9c8aa9dc0a611a951e1c36292abeb4c406d7fb91e0eefc
size 1715101
speechsr48k/config.json
ADDED
@@ -0,0 +1,48 @@
{
  "train": {
    "log_interval": 200,
    "eval_interval": 10000,
    "save_interval": 10000,
    "seed": 1234,
    "epochs": 20000,
    "learning_rate": 1e-4,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 32,
    "fp16_run": false,
    "lr_decay": 0.995,
    "segment_size": 9600,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45
  },
  "data": {
    "train_filelist_path": "filelists/train_48k_vctk_trim_bigvgan_sr.txt",
    "test_filelist_path": "filelists/test_48k_vctk_trim_bigvgan_sr.txt",
    "text_cleaners": ["english_cleaners2"],
    "max_wav_value": 32768.0,
    "sampling_rate": 48000,
    "filter_length": 1280,
    "hop_length": 320,
    "win_length": 1280,
    "n_mel_channels": 128,
    "mel_fmin": 0,
    "mel_fmax": 24000,
    "add_blank": true,
    "n_speakers": 0,
    "cleaned_text": true,
    "aug_rate": 1.0,
    "top_db": 20
  },
  "model": {
    "resblock": "0",
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "upsample_rates": [3],
    "upsample_initial_channel": 32,
    "upsample_kernel_sizes": [3],
    "use_spectral_norm": false
  }
}
speechsr48k/speechsr.py
ADDED
@@ -0,0 +1,252 @@
import torch
from torch import nn
from torch.nn import functional as F
import modules

from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from commons import init_weights, get_padding
from torch.cuda.amp import autocast
import torchaudio
from einops import rearrange

from alias_free_torch import *
import activations


class AMPBlock0(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
        super(AMPBlock0, self).__init__()

        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2]))),
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
        ])
        self.convs2.apply(init_weights)

        self.num_layers = len(self.convs1) + len(self.convs2)  # total number of conv layers

        self.activations = nn.ModuleList([
            Activation1d(
                activation=activations.SnakeBeta(channels, alpha_logscale=True))
            for _ in range(self.num_layers)
        ])

    def forward(self, x):
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3))
        resblock = AMPBlock0

        self.resblocks = nn.ModuleList()
        for i in range(1):
            ch = upsample_initial_channel // (2 ** (i))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d, activation="snakebeta"))

        activation_post = activations.SnakeBeta(ch, alpha_logscale=True)
        self.activation_post = Activation1d(activation=activation_post)

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):

            x = F.interpolate(x, int(x.shape[-1] * 3), mode='linear')
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.resblocks:
            l.remove_weight_norm()

class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

class DiscriminatorR(torch.nn.Module):
    def __init__(self, resolution, use_spectral_norm=False):
        super(DiscriminatorR, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        n_fft, hop_length, win_length = resolution
        self.spec_transform = torchaudio.transforms.Spectrogram(
            n_fft=n_fft, hop_length=hop_length, win_length=win_length, window_fn=torch.hann_window,
            normalized=True, center=False, pad_mode=None, power=None)

        self.convs = nn.ModuleList([
            norm_f(nn.Conv2d(2, 32, (3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(2, 1), padding=(2, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(4, 1), padding=(4, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
        ])
        self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(self, y):
        fmap = []

        x = self.spec_transform(y)  # [B, 2, Freq, Frames, 2]
        x = torch.cat([x.real, x.imag], dim=1)
        x = rearrange(x, 'b c w t -> b c t w')

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]
        resolutions = [[4096, 1024, 4096], [2048, 512, 2048], [1024, 256, 1024], [512, 128, 512], [256, 64, 256], [128, 32, 128]]

        discs = [DiscriminatorR(resolutions[i], use_spectral_norm=use_spectral_norm) for i in range(len(resolutions))]
        discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(self,
                 spec_channels,
                 segment_size,
                 resblock,
                 resblock_kernel_sizes,
                 resblock_dilation_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 upsample_kernel_sizes,
                 **kwargs):

        super().__init__()
        self.spec_channels = spec_channels
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size

        self.dec = Generator(1, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes)

    def forward(self, x):

        y = self.dec(x)
        return y

    @torch.no_grad()
    def infer(self, x, max_len=None):

        o = self.dec(x[:, :, :max_len])
        return o
styleencoder.py
ADDED
@@ -0,0 +1,91 @@
import attentions
from torch import nn
import torch
from torch.nn import functional as F


class Mish(nn.Module):
    def __init__(self):
        super(Mish, self).__init__()

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class Conv1dGLU(nn.Module):
    '''
    Conv1d + GLU(Gated Linear Unit) with residual connection.
    For GLU refer to https://arxiv.org/abs/1612.08083 paper.
    '''

    def __init__(self, in_channels, out_channels, kernel_size, dropout):
        super(Conv1dGLU, self).__init__()
        self.out_channels = out_channels
        self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
        x = x1 * torch.sigmoid(x2)
        x = residual + self.dropout(x)
        return x


class StyleEncoder(torch.nn.Module):
    def __init__(self, in_dim=513, hidden_dim=128, out_dim=256):

        super().__init__()

        self.in_dim = in_dim  # Linear 513, wav2vec 2.0 1024
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.kernel_size = 5
        self.n_head = 2
        self.dropout = 0.1

        self.spectral = nn.Sequential(
            nn.Conv1d(self.in_dim, self.hidden_dim, 1),
            Mish(),
            nn.Dropout(self.dropout),
            nn.Conv1d(self.hidden_dim, self.hidden_dim, 1),
            Mish(),
            nn.Dropout(self.dropout)
        )

        self.temporal = nn.Sequential(
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
        )

        self.slf_attn = attentions.MultiHeadAttention(self.hidden_dim, self.hidden_dim, self.n_head, p_dropout=self.dropout, proximal_bias=False, proximal_init=True)
        self.atten_drop = nn.Dropout(self.dropout)
        self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1)

    def forward(self, x, mask=None):

        # spectral
        x = self.spectral(x) * mask
        # temporal
        x = self.temporal(x) * mask

        # self-attention
        attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1)
        y = self.slf_attn(x, x, attn_mask=attn_mask)
        x = x + self.atten_drop(y)

        # fc
        x = self.fc(x)

        # temporal average pooling
        w = self.temporal_avg_pool(x, mask=mask)

        return w

    def temporal_avg_pool(self, x, mask=None):
        if mask is None:
            out = torch.mean(x, dim=2)
        else:
            len_ = mask.sum(dim=2)
            x = x.sum(dim=2)
            out = torch.div(x, len_)
        return out
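Minimal usage sketch (illustrative, not part of the upload; the input dimensions below are assumptions): the encoder pools a variable-length spectral sequence into one style vector per utterance.

import torch
from styleencoder import StyleEncoder   # assumes the repo root is on PYTHONPATH

enc = StyleEncoder(in_dim=513, hidden_dim=128, out_dim=256)
feats = torch.randn(2, 513, 120)   # [batch, in_dim, frames], e.g. linear-spectrogram frames
mask = torch.ones(2, 1, 120)       # 1 for valid frames, 0 for padding
style = enc(feats, mask)           # masked average pooling over time -> [batch, out_dim]
print(style.shape)                 # torch.Size([2, 256])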