BadriNarayanan committed on
Commit 9fc8907
1 Parent(s): 9f07e86

initial commit

.gitignore ADDED
@@ -0,0 +1,173 @@
1
+ # Custom
2
+ .vscode/
3
+ tests/
4
+ runs/
5
+ data/
6
+ ckpts/
7
+ wandb/
8
+ results/
9
+
10
+
11
+
12
+ # Byte-compiled / optimized / DLL files
13
+ __pycache__/
14
+ *.py[cod]
15
+ *$py.class
16
+
17
+ # C extensions
18
+ *.so
19
+
20
+ # Distribution / packaging
21
+ .Python
22
+ build/
23
+ develop-eggs/
24
+ dist/
25
+ downloads/
26
+ eggs/
27
+ .eggs/
28
+ lib/
29
+ lib64/
30
+ parts/
31
+ sdist/
32
+ var/
33
+ wheels/
34
+ share/python-wheels/
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+ MANIFEST
39
+
40
+ # PyInstaller
41
+ # Usually these files are written by a python script from a template
42
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
43
+ *.manifest
44
+ *.spec
45
+
46
+ # Installer logs
47
+ pip-log.txt
48
+ pip-delete-this-directory.txt
49
+
50
+ # Unit test / coverage reports
51
+ htmlcov/
52
+ .tox/
53
+ .nox/
54
+ .coverage
55
+ .coverage.*
56
+ .cache
57
+ nosetests.xml
58
+ coverage.xml
59
+ *.cover
60
+ *.py,cover
61
+ .hypothesis/
62
+ .pytest_cache/
63
+ cover/
64
+
65
+ # Translations
66
+ *.mo
67
+ *.pot
68
+
69
+ # Django stuff:
70
+ *.log
71
+ local_settings.py
72
+ db.sqlite3
73
+ db.sqlite3-journal
74
+
75
+ # Flask stuff:
76
+ instance/
77
+ .webassets-cache
78
+
79
+ # Scrapy stuff:
80
+ .scrapy
81
+
82
+ # Sphinx documentation
83
+ docs/_build/
84
+
85
+ # PyBuilder
86
+ .pybuilder/
87
+ target/
88
+
89
+ # Jupyter Notebook
90
+ .ipynb_checkpoints
91
+
92
+ # IPython
93
+ profile_default/
94
+ ipython_config.py
95
+
96
+ # pyenv
97
+ # For a library or package, you might want to ignore these files since the code is
98
+ # intended to run in multiple environments; otherwise, check them in:
99
+ # .python-version
100
+
101
+ # pipenv
102
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
104
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
105
+ # install all needed dependencies.
106
+ #Pipfile.lock
107
+
108
+ # poetry
109
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
110
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
111
+ # commonly ignored for libraries.
112
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
113
+ #poetry.lock
114
+
115
+ # pdm
116
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
117
+ #pdm.lock
118
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
119
+ # in version control.
120
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
121
+ .pdm.toml
122
+ .pdm-python
123
+ .pdm-build/
124
+
125
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
126
+ __pypackages__/
127
+
128
+ # Celery stuff
129
+ celerybeat-schedule
130
+ celerybeat.pid
131
+
132
+ # SageMath parsed files
133
+ *.sage.py
134
+
135
+ # Environments
136
+ .env
137
+ .venv
138
+ env/
139
+ venv/
140
+ ENV/
141
+ env.bak/
142
+ venv.bak/
143
+
144
+ # Spyder project settings
145
+ .spyderproject
146
+ .spyproject
147
+
148
+ # Rope project settings
149
+ .ropeproject
150
+
151
+ # mkdocs documentation
152
+ /site
153
+
154
+ # mypy
155
+ .mypy_cache/
156
+ .dmypy.json
157
+ dmypy.json
158
+
159
+ # Pyre type checker
160
+ .pyre/
161
+
162
+ # pytype static type analyzer
163
+ .pytype/
164
+
165
+ # Cython debug symbols
166
+ cython_debug/
167
+
168
+ # PyCharm
169
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
170
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
171
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
172
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
173
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Yushen CHEN
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,186 @@
1
- ---
2
- title: Gradio Text To Speech App
3
- emoji: ⚡
4
- colorFrom: gray
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.1.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: This TTS app uses E2-TTS model for precise voice cloning
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
2
+
3
+ [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
4
+ [![demo](https://img.shields.io/badge/GitHub-Demo%20page-blue.svg)](https://swivid.github.io/F5-TTS/)
5
+ [![space](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
6
+
7
+ **F5-TTS**: Diffusion Transformer with ConvNeXt V2, trained faster and with faster inference.
8
+
9
+ **E2 TTS**: Flat-UNet Transformer, closest reproduction.
10
+
11
+ **Sway Sampling**: Inference-time flow-step sampling strategy that greatly improves performance.
12
+
13
+ ## Installation
14
+
15
+ Clone the repository:
16
+
17
+ ```bash
18
+ git clone https://github.com/SWivid/F5-TTS.git
19
+ cd F5-TTS
20
+ ```
21
+
22
+ Install packages:
23
+
24
+ ```bash
25
+ pip install -r requirements.txt
26
+ ```
27
+
28
+ Install torch matching your CUDA version, e.g.:
29
+
30
+ ```bash
31
+ pip install torch==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
32
+ pip install torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
33
+ ```
34
+
35
+ ## Prepare Dataset
36
+
37
+ Example data processing scripts are provided for Emilia and WenetSpeech4TTS; you may tailor your own, along with a matching Dataset class in `model/dataset.py` (a sketch follows the commands below).
38
+
39
+ ```bash
40
+ # prepare a custom dataset to suit your needs
+ # download the corresponding dataset first, and fill in the paths in the scripts
42
+
43
+ # Prepare the Emilia dataset
44
+ python scripts/prepare_emilia.py
45
+
46
+ # Prepare the Wenetspeech4TTS dataset
47
+ python scripts/prepare_wenetspeech4tts.py
48
+ ```
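
As a rough illustration of what a tailored dataset could look like, here is a minimal sketch. The actual interface expected by the Trainer is defined in `model/dataset.py`, which is not included in this diff, so the class name and the returned fields below are hypothetical.

```python
# Hypothetical sketch only: the real Dataset contract lives in model/dataset.py
# (not shown in this commit); the returned fields here are illustrative.
from torch.utils.data import Dataset
import torchaudio

TARGET_SR = 24000  # matches target_sample_rate used elsewhere in this repo

class MyCustomDataset(Dataset):
    def __init__(self, metadata):
        # metadata: list of (audio_path, transcript) pairs you prepared yourself
        self.metadata = metadata

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        audio_path, text = self.metadata[idx]
        wav, sr = torchaudio.load(audio_path)
        if wav.shape[0] > 1:                      # mix down to mono
            wav = wav.mean(dim=0, keepdim=True)
        if sr != TARGET_SR:
            wav = torchaudio.transforms.Resample(sr, TARGET_SR)(wav)
        return {"audio": wav.squeeze(0), "text": text}
```
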
49
+
50
+ ## Training
51
+
52
+ Once your datasets are prepared, you can start the training process.
53
+
54
+ ```bash
55
+ # set up the accelerate config, e.g. multi-GPU DDP, fp16
+ # the config is saved to: ~/.cache/huggingface/accelerate/default_config.yaml
57
+ accelerate config
58
+ accelerate launch test_train.py
59
+ ```
60
+
61
+ ## Inference
62
+
63
+ To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS).
64
+
65
+ ### Single Inference
66
+
67
+ You can test single inference with the following command. Before running it, adjust the config to your needs.
68
+
69
+ ```bash
70
+ # modify the config to your needs, e.g.:
+ # fix_duration (total length of prompt + generated speech, currently supports up to 30s)
+ # nfe_step (larger values take more time but yield a more precise ODE solution)
+ # ode_method (switch to 'midpoint' for better compatibility with small nfe_step;
+ #   'midpoint' is a 2nd-order ODE solver, so it is slower than the 1st-order 'euler')
+ # (a sketch of the default duration estimate follows this block)
75
+ python test_infer_single.py
76
+ ```
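
For orientation, the snippet below mirrors how `gradio_app.py` in this commit estimates output duration when `fix_duration` is left unset (a byte-length ratio between reference and generated text); `test_infer_single.py` itself is not part of this diff, so treat the constants as assumptions carried over from the app.

```python
# Condensed from the duration estimate in infer_batch() in gradio_app.py; inputs are illustrative.
import re

hop_length, speed = 256, 1.0
zh_pause_punc = r"。,、;:?!"

def estimate_duration(ref_audio_samples, ref_text, gen_text):
    ref_audio_len = ref_audio_samples // hop_length  # mel frames in the reference prompt
    ref_text_len = len(ref_text.encode("utf-8")) + 3 * len(re.findall(zh_pause_punc, ref_text))
    gen_text_len = len(gen_text.encode("utf-8")) + 3 * len(re.findall(zh_pause_punc, gen_text))
    # total frames = prompt frames + frames scaled by the text-length ratio
    return ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

# e.g. a 5-second reference at 24 kHz
print(estimate_duration(5 * 24000, "Some five second reference text.", "A short sentence to generate."))
```
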
77
+ ### Speech Editing
78
+
79
+ To test speech editing capabilities, use the following command.
80
+
81
+ ```bash
82
+ python test_infer_single_edit.py
83
+ ```
84
+
85
+ ### Gradio App
86
+
87
+ You can launch a Gradio app (web interface) to run inference from a GUI.
88
+
89
+ First, make sure you have the dependencies installed (`pip install -r requirements.txt`). Then, install the Gradio app dependencies:
90
+
91
+ ```bash
92
+ pip install -r requirements_gradio.txt
93
+ ```
94
+
95
+ After installing the dependencies, launch the app (it loads checkpoints from Hugging Face; you may set `ckpt_path` to a local file in `gradio_app.py`):
96
+
97
+ ```bash
98
+ python gradio_app.py
99
+ ```
100
+
101
+ You can specify the port/host:
102
+
103
+ ```bash
104
+ python gradio_app.py --port 7860 --host 0.0.0.0
105
+ ```
106
+
107
+ Or launch a share link:
108
+
109
+ ```bash
110
+ python gradio_app.py --share
111
+ ```
112
+
113
+ ## Evaluation
114
+
115
+ ### Prepare Test Datasets
116
+
117
+ 1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
118
+ 2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
119
+ 3. Unzip the downloaded datasets and place them in the data/ directory.
120
+ 4. Update the path for the test-clean data in `test_infer_batch.py`.
+ 5. Our filtered LibriSpeech-PC 4-10s subset is already under data/ in this repo.
122
+
123
+ ### Batch Inference for Test Set
124
+
125
+ To run batch inference for evaluations, execute the following commands:
126
+
127
+ ```bash
128
+ # batch inference for evaluations
129
+ accelerate config # if not set before
130
+ bash test_infer_batch.sh
131
+ ```
132
+
133
+ ### Download Evaluation Model Checkpoints
134
+
135
+ 1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
136
+ 2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
137
+ 3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
138
+
139
+ ### Objective Evaluation
140
+
141
+ **Some Notes**
142
+
143
+ For faster-whisper with CUDA 11:
144
+
145
+ ```bash
146
+ pip install --force-reinstall ctranslate2==3.24.0
147
+ ```
148
+
149
+ (Recommended) To avoid possible ASR failures, such as abnormal repetitions in output:
150
+
151
+ ```bash
152
+ pip install faster-whisper==0.10.1
153
+ ```
154
+
155
+ Update the paths to your batch-inference results, then run the WER / SIM evaluations:
156
+ ```bash
157
+ # Evaluation for Seed-TTS test set
158
+ python scripts/eval_seedtts_testset.py
159
+
160
+ # Evaluation for LibriSpeech-PC test-clean (cross-sentence)
161
+ python scripts/eval_librispeech_test_clean.py
162
+ ```
163
+
164
+ ## Acknowledgements
165
+
166
+ - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
167
+ - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
168
+ - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
169
+ - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
170
+ - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
171
+ - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
172
+ - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
173
+ - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
174
+
175
+ ## Citation
176
+ ```
177
+ @article{chen-etal-2024-f5tts,
178
+ title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching},
179
+ author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen},
180
+ journal={arXiv preprint arXiv:2410.06885},
181
+ year={2024},
182
+ }
183
+ ```
184
+ ## License
185
+
186
+ Our code is released under the MIT License.
gradio_app.py ADDED
@@ -0,0 +1,1049 @@
1
+ # import os
2
+ # import re
3
+ # import torch
4
+ # import torchaudio
5
+ # import gradio as gr
6
+ # import numpy as np
7
+ # import tempfile
8
+ # from einops import rearrange
9
+ # from vocos import Vocos
10
+ # from pydub import AudioSegment
11
+ # from model import CFM, UNetT, DiT, MMDiT
12
+ # from cached_path import cached_path
13
+ # from model.utils import (
14
+ # load_checkpoint,
15
+ # get_tokenizer,
16
+ # convert_char_to_pinyin,
17
+ # save_spectrogram,
18
+ # )
19
+ # from transformers import pipeline
20
+ # import librosa
21
+ # import click
22
+
23
+ # device = (
24
+ # "cuda"
25
+ # if torch.cuda.is_available()
26
+ # else "mps" if torch.backends.mps.is_available() else "cpu"
27
+ # )
28
+
29
+ # print(f"Using {device} device")
30
+
31
+ # pipe = pipeline(
32
+ # "automatic-speech-recognition",
33
+ # model="openai/whisper-large-v3-turbo",
34
+ # torch_dtype=torch.float16,
35
+ # device=device,
36
+ # )
37
+
38
+ # # --------------------- Settings -------------------- #
39
+
40
+ # target_sample_rate = 24000
41
+ # n_mel_channels = 100
42
+ # hop_length = 256
43
+ # target_rms = 0.1
44
+ # nfe_step = 32 # 16, 32
45
+ # cfg_strength = 2.0
46
+ # ode_method = "euler"
47
+ # sway_sampling_coef = -1.0
48
+ # speed = 1.0
49
+ # # fix_duration = 27 # None or float (duration in seconds)
50
+ # fix_duration = None
51
+
52
+
53
+ # def load_model(exp_name, model_cls, model_cfg, ckpt_step):
54
+ # ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
55
+ # # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
56
+ # vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
57
+ # model = CFM(
58
+ # transformer=model_cls(
59
+ # **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
60
+ # ),
61
+ # mel_spec_kwargs=dict(
62
+ # target_sample_rate=target_sample_rate,
63
+ # n_mel_channels=n_mel_channels,
64
+ # hop_length=hop_length,
65
+ # ),
66
+ # odeint_kwargs=dict(
67
+ # method=ode_method,
68
+ # ),
69
+ # vocab_char_map=vocab_char_map,
70
+ # ).to(device)
71
+
72
+ # model = load_checkpoint(model, ckpt_path, device, use_ema = True)
73
+
74
+ # return model
75
+
76
+
77
+ # # load models
78
+ # F5TTS_model_cfg = dict(
79
+ # dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
80
+ # )
81
+ # E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
82
+
83
+ # F5TTS_ema_model = load_model(
84
+ # "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
85
+ # )
86
+ # E2TTS_ema_model = load_model(
87
+ # "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
88
+ # )
89
+
90
+
91
+ # def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
92
+ # print(gen_text)
93
+ # if len(gen_text) > 200:
94
+ # raise gr.Error("Please keep your text under 200 chars.")
95
+ # gr.Info("Converting audio...")
96
+ # with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
97
+ # aseg = AudioSegment.from_file(ref_audio_orig)
98
+ # audio_duration = len(aseg)
99
+ # if audio_duration > 15000:
100
+ # gr.Warning("Audio is over 15s, clipping to only first 15s.")
101
+ # aseg = aseg[:15000]
102
+ # aseg.export(f.name, format="wav")
103
+ # ref_audio = f.name
104
+ # if exp_name == "F5-TTS":
105
+ # ema_model = F5TTS_ema_model
106
+ # elif exp_name == "E2-TTS":
107
+ # ema_model = E2TTS_ema_model
108
+
109
+ # if not ref_text.strip():
110
+ # gr.Info("No reference text provided, transcribing reference audio...")
111
+ # ref_text = outputs = pipe(
112
+ # ref_audio,
113
+ # chunk_length_s=30,
114
+ # batch_size=128,
115
+ # generate_kwargs={"task": "transcribe"},
116
+ # return_timestamps=False,
117
+ # )["text"].strip()
118
+ # gr.Info("Finished transcription")
119
+ # else:
120
+ # gr.Info("Using custom reference text...")
121
+ # audio, sr = torchaudio.load(ref_audio)
122
+ # if audio.shape[0] > 1:
123
+ # audio = torch.mean(audio, dim=0, keepdim=True)
124
+
125
+ # rms = torch.sqrt(torch.mean(torch.square(audio)))
126
+ # if rms < target_rms:
127
+ # audio = audio * target_rms / rms
128
+ # if sr != target_sample_rate:
129
+ # resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
130
+ # audio = resampler(audio)
131
+ # audio = audio.to(device)
132
+
133
+ # # Prepare the text
134
+ # text_list = [ref_text + gen_text]
135
+ # final_text_list = convert_char_to_pinyin(text_list)
136
+
137
+ # # Calculate duration
138
+ # ref_audio_len = audio.shape[-1] // hop_length
139
+ # # if fix_duration is not None:
140
+ # # duration = int(fix_duration * target_sample_rate / hop_length)
141
+ # # else:
142
+ # zh_pause_punc = r"。,、;:?!"
143
+ # ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
144
+ # gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
145
+ # duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
146
+
147
+ # # inference
148
+ # gr.Info(f"Generating audio using {exp_name}")
149
+ # with torch.inference_mode():
150
+ # generated, _ = ema_model.sample(
151
+ # cond=audio,
152
+ # text=final_text_list,
153
+ # duration=duration,
154
+ # steps=nfe_step,
155
+ # cfg_strength=cfg_strength,
156
+ # sway_sampling_coef=sway_sampling_coef,
157
+ # )
158
+
159
+ # generated = generated[:, ref_audio_len:, :]
160
+ # generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
161
+ # gr.Info("Running vocoder")
162
+ # vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
163
+ # generated_wave = vocos.decode(generated_mel_spec.cpu())
164
+ # if rms < target_rms:
165
+ # generated_wave = generated_wave * rms / target_rms
166
+
167
+ # # wav -> numpy
168
+ # generated_wave = generated_wave.squeeze().cpu().numpy()
169
+
170
+ # if remove_silence:
171
+ # gr.Info("Removing audio silences... This may take a moment")
172
+ # non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
173
+ # non_silent_wave = np.array([])
174
+ # for interval in non_silent_intervals:
175
+ # start, end = interval
176
+ # non_silent_wave = np.concatenate(
177
+ # [non_silent_wave, generated_wave[start:end]]
178
+ # )
179
+ # generated_wave = non_silent_wave
180
+
181
+ # # spectrogram
182
+ # with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
183
+ # spectrogram_path = tmp_spectrogram.name
184
+ # save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
185
+
186
+ # return (target_sample_rate, generated_wave), spectrogram_path
187
+
188
+
189
+ # with gr.Blocks() as app:
190
+ # gr.Markdown(
191
+ # """
192
+ # # Antriksh AI
193
+
194
+ # """
195
+ # )
196
+
197
+ # # Image
198
+ # gr.Image(value="C:\\Users\\USER\\OneDrive\\Documents\\logo.jpg", width=300, height= 150 )
199
+
200
+ # ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
201
+ # gen_text_input = gr.Textbox(label="Text to Generate (max 200 chars.)", lines=4)
202
+ # model_choice = gr.Radio(
203
+ # choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
204
+ # )
205
+ # generate_btn = gr.Button("Synthesize", variant="primary")
206
+ # with gr.Accordion("Advanced Settings", open=False):
207
+ # ref_text_input = gr.Textbox(
208
+ # label="Reference Text",
209
+ # info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
210
+ # lines=2,
211
+ # )
212
+ # remove_silence = gr.Checkbox(
213
+ # label="Remove Silences",
214
+ # info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
215
+ # value=True,
216
+ # )
217
+
218
+ # audio_output = gr.Audio(label="Synthesized Audio")
219
+ # spectrogram_output = gr.Image(label="Spectrogram")
220
+
221
+ # generate_btn.click(
222
+ # infer,
223
+ # inputs=[
224
+ # ref_audio_input,
225
+ # ref_text_input,
226
+ # gen_text_input,
227
+ # model_choice,
228
+ # remove_silence,
229
+ # ],
230
+ # outputs=[audio_output, spectrogram_output],
231
+ # )
232
+
233
+
234
+ # @click.command()
235
+ # @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
236
+ # @click.option("--host", "-H", default=None, help="Host to run the app on")
237
+ # @click.option(
238
+ # "--share",
239
+ # "-s",
240
+ # default=True,
241
+ # is_flag=True,
242
+ # help="Share the app via Gradio share link",
243
+ # )
244
+ # @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
245
+ # def main(port, host, share, api):
246
+ # global app
247
+ # print(f"Starting app...")
248
+ # app.queue(api_open=api).launch(
249
+ # server_name=host, server_port=port, share=True, show_api=api
250
+ # )
251
+
252
+
253
+ # if __name__ == "__main__":
254
+ # main()
255
+
256
+ import re
257
+ import torch
258
+ import torchaudio
259
+ import gradio as gr
260
+ import numpy as np
261
+ import tempfile
262
+ from einops import rearrange
263
+ from vocos import Vocos
264
+ from pydub import AudioSegment, silence
265
+ from model import CFM, UNetT, DiT, MMDiT
266
+ from cached_path import cached_path
267
+ from model.utils import (
268
+ load_checkpoint,
269
+ get_tokenizer,
270
+ convert_char_to_pinyin,
271
+ save_spectrogram,
272
+ )
273
+ from transformers import pipeline
274
+ import click
275
+ import soundfile as sf
276
+
277
+ try:
278
+ import spaces
279
+ USING_SPACES = True
280
+ except ImportError:
281
+ USING_SPACES = False
282
+
283
+ def gpu_decorator(func):
284
+ if USING_SPACES:
285
+ return spaces.GPU(func)
286
+ else:
287
+ return func
288
+
289
+ device = (
290
+ "cuda"
291
+ if torch.cuda.is_available()
292
+ else "mps" if torch.backends.mps.is_available() else "cpu"
293
+ )
294
+
295
+ print(f"Using {device} device")
296
+
297
+ pipe = pipeline(
298
+ "automatic-speech-recognition",
299
+ model="openai/whisper-large-v3-turbo",
300
+ torch_dtype=torch.float16,
301
+ device=device,
302
+ )
303
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
304
+
305
+ # --------------------- Settings -------------------- #
306
+
307
+ target_sample_rate = 24000
308
+ n_mel_channels = 100
309
+ hop_length = 256
310
+ target_rms = 0.1
311
+ nfe_step = 32 # 16, 32
312
+ cfg_strength = 2.0
313
+ ode_method = "euler"
314
+ sway_sampling_coef = -1.0
315
+ speed = 1.0
316
+ fix_duration = None
317
+
318
+
319
+ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
320
+ ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
321
+ # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
322
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
323
+ model = CFM(
324
+ transformer=model_cls(
325
+ **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
326
+ ),
327
+ mel_spec_kwargs=dict(
328
+ target_sample_rate=target_sample_rate,
329
+ n_mel_channels=n_mel_channels,
330
+ hop_length=hop_length,
331
+ ),
332
+ odeint_kwargs=dict(
333
+ method=ode_method,
334
+ ),
335
+ vocab_char_map=vocab_char_map,
336
+ ).to(device)
337
+
338
+ model = load_checkpoint(model, ckpt_path, device, use_ema = True)
339
+
340
+ return model
341
+
342
+
343
+ # load models
344
+ F5TTS_model_cfg = dict(
345
+ dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
346
+ )
347
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
348
+
349
+ F5TTS_ema_model = load_model(
350
+ "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
351
+ )
352
+ E2TTS_ema_model = load_model(
353
+ "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
354
+ )
355
+
356
+ def chunk_text(text, max_chars=135):
357
+ """
358
+ Splits the input text into chunks, each with a maximum number of characters.
359
+
360
+ Args:
361
+ text (str): The text to be split.
362
+ max_chars (int): The maximum number of characters per chunk.
363
+
364
+ Returns:
365
+ List[str]: A list of text chunks.
366
+ """
367
+ chunks = []
368
+ current_chunk = ""
369
+ # Split the text into sentences based on punctuation followed by whitespace
370
+ sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
371
+
372
+ for sentence in sentences:
373
+ if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
374
+ current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
375
+ else:
376
+ if current_chunk:
377
+ chunks.append(current_chunk.strip())
378
+ current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
379
+
380
+ if current_chunk:
381
+ chunks.append(current_chunk.strip())
382
+
383
+ return chunks
384
+
385
+ @gpu_decorator
386
+ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration=0.15, progress=gr.Progress()):
387
+ if exp_name == "F5-TTS":
388
+ ema_model = F5TTS_ema_model
389
+ elif exp_name == "E2-TTS":
390
+ ema_model = E2TTS_ema_model
391
+
392
+ audio, sr = ref_audio
393
+ if audio.shape[0] > 1:
394
+ audio = torch.mean(audio, dim=0, keepdim=True)
395
+
396
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
397
+ if rms < target_rms:
398
+ audio = audio * target_rms / rms
399
+ if sr != target_sample_rate:
400
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
401
+ audio = resampler(audio)
402
+ audio = audio.to(device)
403
+
404
+ generated_waves = []
405
+ spectrograms = []
406
+
407
+ for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
408
+ # Prepare the text
409
+ if len(ref_text[-1].encode('utf-8')) == 1:
410
+ ref_text = ref_text + " "
411
+ text_list = [ref_text + gen_text]
412
+ final_text_list = convert_char_to_pinyin(text_list)
413
+
414
+ # Calculate duration
415
+ ref_audio_len = audio.shape[-1] // hop_length
416
+ zh_pause_punc = r"。,、;:?!"
417
+ ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
418
+ gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
419
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
420
+
421
+ # inference
422
+ with torch.inference_mode():
423
+ generated, _ = ema_model.sample(
424
+ cond=audio,
425
+ text=final_text_list,
426
+ duration=duration,
427
+ steps=nfe_step,
428
+ cfg_strength=cfg_strength,
429
+ sway_sampling_coef=sway_sampling_coef,
430
+ )
431
+
432
+ generated = generated[:, ref_audio_len:, :]
433
+ generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
434
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
435
+ if rms < target_rms:
436
+ generated_wave = generated_wave * rms / target_rms
437
+
438
+ # wav -> numpy
439
+ generated_wave = generated_wave.squeeze().cpu().numpy()
440
+
441
+ generated_waves.append(generated_wave)
442
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
443
+
444
+ # Combine all generated waves with cross-fading
445
+ if cross_fade_duration <= 0:
446
+ # Simply concatenate
447
+ final_wave = np.concatenate(generated_waves)
448
+ else:
449
+ final_wave = generated_waves[0]
450
+ for i in range(1, len(generated_waves)):
451
+ prev_wave = final_wave
452
+ next_wave = generated_waves[i]
453
+
454
+ # Calculate cross-fade samples, ensuring it does not exceed wave lengths
455
+ cross_fade_samples = int(cross_fade_duration * target_sample_rate)
456
+ cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
457
+
458
+ if cross_fade_samples <= 0:
459
+ # No overlap possible, concatenate
460
+ final_wave = np.concatenate([prev_wave, next_wave])
461
+ continue
462
+
463
+ # Overlapping parts
464
+ prev_overlap = prev_wave[-cross_fade_samples:]
465
+ next_overlap = next_wave[:cross_fade_samples]
466
+
467
+ # Fade out and fade in
468
+ fade_out = np.linspace(1, 0, cross_fade_samples)
469
+ fade_in = np.linspace(0, 1, cross_fade_samples)
470
+
471
+ # Cross-faded overlap
472
+ cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
473
+
474
+ # Combine
475
+ new_wave = np.concatenate([
476
+ prev_wave[:-cross_fade_samples],
477
+ cross_faded_overlap,
478
+ next_wave[cross_fade_samples:]
479
+ ])
480
+
481
+ final_wave = new_wave
482
+
483
+ # Remove silence
484
+ if remove_silence:
485
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
486
+ sf.write(f.name, final_wave, target_sample_rate)
487
+ aseg = AudioSegment.from_file(f.name)
488
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
489
+ non_silent_wave = AudioSegment.silent(duration=0)
490
+ for non_silent_seg in non_silent_segs:
491
+ non_silent_wave += non_silent_seg
492
+ aseg = non_silent_wave
493
+ aseg.export(f.name, format="wav")
494
+ final_wave, _ = torchaudio.load(f.name)
495
+ final_wave = final_wave.squeeze().cpu().numpy()
496
+
497
+ # Create a combined spectrogram
498
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
499
+
500
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
501
+ spectrogram_path = tmp_spectrogram.name
502
+ save_spectrogram(combined_spectrogram, spectrogram_path)
503
+
504
+ return (target_sample_rate, final_wave), spectrogram_path
505
+
506
+ @gpu_decorator
507
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, cross_fade_duration=0.15):
508
+
509
+ print(gen_text)
510
+
511
+ gr.Info("Converting audio...")
512
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
513
+ aseg = AudioSegment.from_file(ref_audio_orig)
514
+
515
+ non_silent_segs = silence.split_on_silence(
516
+ aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000
517
+ )
518
+ non_silent_wave = AudioSegment.silent(duration=0)
519
+ for non_silent_seg in non_silent_segs:
520
+ non_silent_wave += non_silent_seg
521
+ aseg = non_silent_wave
522
+
523
+ audio_duration = len(aseg)
524
+ if audio_duration > 15000:
525
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
526
+ aseg = aseg[:15000]
527
+ aseg.export(f.name, format="wav")
528
+ ref_audio = f.name
529
+
530
+ if not ref_text.strip():
531
+ gr.Info("No reference text provided, transcribing reference audio...")
532
+ ref_text = pipe(
533
+ ref_audio,
534
+ chunk_length_s=30,
535
+ batch_size=128,
536
+ generate_kwargs={"task": "transcribe"},
537
+ return_timestamps=False,
538
+ )["text"].strip()
539
+ gr.Info("Finished transcription")
540
+ else:
541
+ gr.Info("Using custom reference text...")
542
+
543
+ # Ensure ref_text ends with ". "
544
+ if not ref_text.endswith(". "):
545
+ if ref_text.endswith("."):
546
+ ref_text += " "
547
+ else:
548
+ ref_text += ". "
549
+
550
+ audio, sr = torchaudio.load(ref_audio)
551
+
552
+ # Use the new chunk_text function to split gen_text
553
+ max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
554
+ gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
555
+ print('ref_text', ref_text)
556
+ for i, batch_text in enumerate(gen_text_batches):
557
+ print(f'gen_text {i}', batch_text)
558
+
559
+ gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
560
+ return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence, cross_fade_duration)
561
+
562
+
563
+ @gpu_decorator
564
+ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
565
+ # Split the script into speaker blocks
566
+ speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
567
+ speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
568
+
569
+ generated_audio_segments = []
570
+
571
+ for i in range(0, len(speaker_blocks), 2):
572
+ speaker = speaker_blocks[i]
573
+ text = speaker_blocks[i+1].strip()
574
+
575
+ # Determine which speaker is talking
576
+ if speaker == speaker1_name:
577
+ ref_audio = ref_audio1
578
+ ref_text = ref_text1
579
+ elif speaker == speaker2_name:
580
+ ref_audio = ref_audio2
581
+ ref_text = ref_text2
582
+ else:
583
+ continue # Skip if the speaker is neither speaker1 nor speaker2
584
+
585
+ # Generate audio for this block
586
+ audio, _ = infer(ref_audio, ref_text, text, exp_name, remove_silence)
587
+
588
+ # Convert the generated audio to a numpy array
589
+ sr, audio_data = audio
590
+
591
+ # Save the audio data as a WAV file
592
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
593
+ sf.write(temp_file.name, audio_data, sr)
594
+ audio_segment = AudioSegment.from_wav(temp_file.name)
595
+
596
+ generated_audio_segments.append(audio_segment)
597
+
598
+ # Add a short pause between speakers
599
+ pause = AudioSegment.silent(duration=500) # 500ms pause
600
+ generated_audio_segments.append(pause)
601
+
602
+ # Concatenate all audio segments
603
+ final_podcast = sum(generated_audio_segments)
604
+
605
+ # Export the final podcast
606
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
607
+ podcast_path = temp_file.name
608
+ final_podcast.export(podcast_path, format="wav")
609
+
610
+ return podcast_path
611
+
612
+ def parse_speechtypes_text(gen_text):
613
+ # Pattern to find (Emotion)
614
+ pattern = r'\((.*?)\)'
615
+
616
+ # Split the text by the pattern
617
+ tokens = re.split(pattern, gen_text)
618
+
619
+ segments = []
620
+
621
+ current_emotion = 'Regular'
622
+
623
+ for i in range(len(tokens)):
624
+ if i % 2 == 0:
625
+ # This is text
626
+ text = tokens[i].strip()
627
+ if text:
628
+ segments.append({'emotion': current_emotion, 'text': text})
629
+ else:
630
+ # This is emotion
631
+ emotion = tokens[i].strip()
632
+ current_emotion = emotion
633
+
634
+ return segments
635
+
636
+ def update_speed(new_speed):
637
+ global speed
638
+ speed = new_speed
639
+ return f"Speed set to: {speed}"
640
+
641
+ with gr.Blocks() as app_credits:
642
+ gr.Markdown("""
643
+ # Credits
644
+
645
+ * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
646
+ * [RootingInLoad](https://github.com/RootingInLoad) for the podcast generation
647
+ * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation
648
+ """)
649
+ with gr.Blocks() as app_tts:
650
+ gr.Markdown("# Batched TTS")
651
+ ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
652
+ gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
653
+ model_choice = gr.Radio(
654
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
655
+ )
656
+ generate_btn = gr.Button("Synthesize", variant="primary")
657
+ with gr.Accordion("Advanced Settings", open=False):
658
+ ref_text_input = gr.Textbox(
659
+ label="Reference Text",
660
+ info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
661
+ lines=2,
662
+ )
663
+ remove_silence = gr.Checkbox(
664
+ label="Remove Silences",
665
+ info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
666
+ value=False,
667
+ )
668
+ speed_slider = gr.Slider(
669
+ label="Speed",
670
+ minimum=0.3,
671
+ maximum=2.0,
672
+ value=speed,
673
+ step=0.1,
674
+ info="Adjust the speed of the audio.",
675
+ )
676
+ cross_fade_duration_slider = gr.Slider(
677
+ label="Cross-Fade Duration (s)",
678
+ minimum=0.0,
679
+ maximum=1.0,
680
+ value=0.15,
681
+ step=0.01,
682
+ info="Set the duration of the cross-fade between audio clips.",
683
+ )
684
+ speed_slider.change(update_speed, inputs=speed_slider)
685
+
686
+ audio_output = gr.Audio(label="Synthesized Audio")
687
+ spectrogram_output = gr.Image(label="Spectrogram")
688
+
689
+ generate_btn.click(
690
+ infer,
691
+ inputs=[
692
+ ref_audio_input,
693
+ ref_text_input,
694
+ gen_text_input,
695
+ model_choice,
696
+ remove_silence,
697
+ cross_fade_duration_slider,
698
+ ],
699
+ outputs=[audio_output, spectrogram_output],
700
+ )
701
+
702
+ with gr.Blocks() as app_podcast:
703
+ gr.Markdown("# Podcast Generation")
704
+ speaker1_name = gr.Textbox(label="Speaker 1 Name")
705
+ ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
706
+ ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
707
+
708
+ speaker2_name = gr.Textbox(label="Speaker 2 Name")
709
+ ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
710
+ ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
711
+
712
+ script_input = gr.Textbox(label="Podcast Script", lines=10,
713
+ placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")
714
+
715
+ podcast_model_choice = gr.Radio(
716
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
717
+ )
718
+ podcast_remove_silence = gr.Checkbox(
719
+ label="Remove Silences",
720
+ value=True,
721
+ )
722
+ generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
723
+ podcast_output = gr.Audio(label="Generated Podcast")
724
+
725
+ def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
726
+ return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)
727
+
728
+ generate_podcast_btn.click(
729
+ podcast_generation,
730
+ inputs=[
731
+ script_input,
732
+ speaker1_name,
733
+ ref_audio_input1,
734
+ ref_text_input1,
735
+ speaker2_name,
736
+ ref_audio_input2,
737
+ ref_text_input2,
738
+ podcast_model_choice,
739
+ podcast_remove_silence,
740
+ ],
741
+ outputs=podcast_output,
742
+ )
743
+
744
+ def parse_emotional_text(gen_text):
745
+ # Pattern to find (Emotion)
746
+ pattern = r'\((.*?)\)'
747
+
748
+ # Split the text by the pattern
749
+ tokens = re.split(pattern, gen_text)
750
+
751
+ segments = []
752
+
753
+ current_emotion = 'Regular'
754
+
755
+ for i in range(len(tokens)):
756
+ if i % 2 == 0:
757
+ # This is text
758
+ text = tokens[i].strip()
759
+ if text:
760
+ segments.append({'emotion': current_emotion, 'text': text})
761
+ else:
762
+ # This is emotion
763
+ emotion = tokens[i].strip()
764
+ current_emotion = emotion
765
+
766
+ return segments
767
+
768
+ with gr.Blocks() as app_emotional:
769
+ # New section for emotional generation
770
+ gr.Markdown(
771
+ """
772
+ # Multiple Speech-Type Generation
773
+
774
+ This section allows you to upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the "Add Speech Type" button. Enter your text in the format shown below, and the system will generate speech using the appropriate emotions. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
775
+
776
+ **Example Input:**
777
+
778
+ (Regular) Hello, I'd like to order a sandwich please. (Surprised) What do you mean you're out of bread? (Sad) I really wanted a sandwich though... (Angry) You know what, darn you and your little shop, you suck! (Whisper) I'll just go back home and cry now. (Shouting) Why me?!
779
+ """
780
+ )
781
+
782
+ gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")
783
+
784
+ # Regular speech type (mandatory)
785
+ with gr.Row():
786
+ regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
787
+ regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
788
+ regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
789
+
790
+ # Additional speech types (up to 99 more)
791
+ max_speech_types = 100
792
+ speech_type_names = []
793
+ speech_type_audios = []
794
+ speech_type_ref_texts = []
795
+ speech_type_delete_btns = []
796
+
797
+ for i in range(max_speech_types - 1):
798
+ with gr.Row():
799
+ name_input = gr.Textbox(label='Speech Type Name', visible=False)
800
+ audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
801
+ ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
802
+ delete_btn = gr.Button("Delete", variant="secondary", visible=False)
803
+ speech_type_names.append(name_input)
804
+ speech_type_audios.append(audio_input)
805
+ speech_type_ref_texts.append(ref_text_input)
806
+ speech_type_delete_btns.append(delete_btn)
807
+
808
+ # Button to add speech type
809
+ add_speech_type_btn = gr.Button("Add Speech Type")
810
+
811
+ # Keep track of current number of speech types
812
+ speech_type_count = gr.State(value=0)
813
+
814
+ # Function to add a speech type
815
+ def add_speech_type_fn(speech_type_count):
816
+ if speech_type_count < max_speech_types - 1:
817
+ speech_type_count += 1
818
+ # Prepare updates for the components
819
+ name_updates = []
820
+ audio_updates = []
821
+ ref_text_updates = []
822
+ delete_btn_updates = []
823
+ for i in range(max_speech_types - 1):
824
+ if i < speech_type_count:
825
+ name_updates.append(gr.update(visible=True))
826
+ audio_updates.append(gr.update(visible=True))
827
+ ref_text_updates.append(gr.update(visible=True))
828
+ delete_btn_updates.append(gr.update(visible=True))
829
+ else:
830
+ name_updates.append(gr.update())
831
+ audio_updates.append(gr.update())
832
+ ref_text_updates.append(gr.update())
833
+ delete_btn_updates.append(gr.update())
834
+ else:
835
+ # Optionally, show a warning
836
+ # gr.Warning("Maximum number of speech types reached.")
837
+ name_updates = [gr.update() for _ in range(max_speech_types - 1)]
838
+ audio_updates = [gr.update() for _ in range(max_speech_types - 1)]
839
+ ref_text_updates = [gr.update() for _ in range(max_speech_types - 1)]
840
+ delete_btn_updates = [gr.update() for _ in range(max_speech_types - 1)]
841
+ return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
842
+
843
+ add_speech_type_btn.click(
844
+ add_speech_type_fn,
845
+ inputs=speech_type_count,
846
+ outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
847
+ )
848
+
849
+ # Function to delete a speech type
850
+ def make_delete_speech_type_fn(index):
851
+ def delete_speech_type_fn(speech_type_count):
852
+ # Prepare updates
853
+ name_updates = []
854
+ audio_updates = []
855
+ ref_text_updates = []
856
+ delete_btn_updates = []
857
+
858
+ for i in range(max_speech_types - 1):
859
+ if i == index:
860
+ name_updates.append(gr.update(visible=False, value=''))
861
+ audio_updates.append(gr.update(visible=False, value=None))
862
+ ref_text_updates.append(gr.update(visible=False, value=''))
863
+ delete_btn_updates.append(gr.update(visible=False))
864
+ else:
865
+ name_updates.append(gr.update())
866
+ audio_updates.append(gr.update())
867
+ ref_text_updates.append(gr.update())
868
+ delete_btn_updates.append(gr.update())
869
+
870
+ speech_type_count = max(0, speech_type_count - 1)
871
+
872
+ return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
873
+
874
+ return delete_speech_type_fn
875
+
876
+ for i, delete_btn in enumerate(speech_type_delete_btns):
877
+ delete_fn = make_delete_speech_type_fn(i)
878
+ delete_btn.click(
879
+ delete_fn,
880
+ inputs=speech_type_count,
881
+ outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
882
+ )
883
+
884
+ # Text input for the prompt
885
+ gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
886
+
887
+ # Model choice
888
+ model_choice_emotional = gr.Radio(
889
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
890
+ )
891
+
892
+ with gr.Accordion("Advanced Settings", open=False):
893
+ remove_silence_emotional = gr.Checkbox(
894
+ label="Remove Silences",
895
+ value=True,
896
+ )
897
+
898
+ # Generate button
899
+ generate_emotional_btn = gr.Button("Generate Emotional Speech", variant="primary")
900
+
901
+ # Output audio
902
+ audio_output_emotional = gr.Audio(label="Synthesized Audio")
903
+ @gpu_decorator
904
+ def generate_emotional_speech(
905
+ regular_audio,
906
+ regular_ref_text,
907
+ gen_text,
908
+ *args,
909
+ ):
910
+ num_additional_speech_types = max_speech_types - 1
911
+ speech_type_names_list = args[:num_additional_speech_types]
912
+ speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
913
+ speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
914
+ model_choice = args[3 * num_additional_speech_types]
915
+ remove_silence = args[3 * num_additional_speech_types + 1]
916
+
917
+ # Collect the speech types and their audios into a dict
918
+ speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}
919
+
920
+ for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
921
+ if name_input and audio_input:
922
+ speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}
923
+
924
+ # Parse the gen_text into segments
925
+ segments = parse_speechtypes_text(gen_text)
926
+
927
+ # For each segment, generate speech
928
+ generated_audio_segments = []
929
+ current_emotion = 'Regular'
930
+
931
+ for segment in segments:
932
+ emotion = segment['emotion']
933
+ text = segment['text']
934
+
935
+ if emotion in speech_types:
936
+ current_emotion = emotion
937
+ else:
938
+ # If emotion not available, default to Regular
939
+ current_emotion = 'Regular'
940
+
941
+ ref_audio = speech_types[current_emotion]['audio']
942
+ ref_text = speech_types[current_emotion].get('ref_text', '')
943
+
944
+ # Generate speech for this segment
945
+ audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, 0)
946
+ sr, audio_data = audio
947
+
948
+ generated_audio_segments.append(audio_data)
949
+
950
+ # Concatenate all audio segments
951
+ if generated_audio_segments:
952
+ final_audio_data = np.concatenate(generated_audio_segments)
953
+ return (sr, final_audio_data)
954
+ else:
955
+ gr.Warning("No audio generated.")
956
+ return None
957
+
958
+ generate_emotional_btn.click(
959
+ generate_emotional_speech,
960
+ inputs=[
961
+ regular_audio,
962
+ regular_ref_text,
963
+ gen_text_input_emotional,
964
+ ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
965
+ model_choice_emotional,
966
+ remove_silence_emotional,
967
+ ],
968
+ outputs=audio_output_emotional,
969
+ )
970
+
971
+ # Validation function to disable Generate button if speech types are missing
972
+ def validate_speech_types(
973
+ gen_text,
974
+ regular_name,
975
+ *args
976
+ ):
977
+ num_additional_speech_types = max_speech_types - 1
978
+ speech_type_names_list = args[:num_additional_speech_types]
979
+
980
+ # Collect the speech types names
981
+ speech_types_available = set()
982
+ if regular_name:
983
+ speech_types_available.add(regular_name)
984
+ for name_input in speech_type_names_list:
985
+ if name_input:
986
+ speech_types_available.add(name_input)
987
+
988
+ # Parse the gen_text to get the speech types used
989
+ segments = parse_emotional_text(gen_text)
990
+ speech_types_in_text = set(segment['emotion'] for segment in segments)
991
+
992
+ # Check if all speech types in text are available
993
+ missing_speech_types = speech_types_in_text - speech_types_available
994
+
995
+ if missing_speech_types:
996
+ # Disable the generate button
997
+ return gr.update(interactive=False)
998
+ else:
999
+ # Enable the generate button
1000
+ return gr.update(interactive=True)
1001
+
1002
+ gen_text_input_emotional.change(
1003
+ validate_speech_types,
1004
+ inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
1005
+ outputs=generate_emotional_btn
1006
+ )
1007
+ with gr.Blocks() as app:
1008
+ gr.Markdown(
1009
+ """
1010
+ # Antriksh AI
1011
+ """
1012
+ )
1013
+
1014
+ # Add the image here
1015
+ gr.Image(
1016
+ value="C:\\Users\\USER\\Downloads\\logo-removebg-preview.png",
1017
+ label="AI System Logo",
1018
+ show_label=False,
1019
+ width=300,
1020
+ height=150
1021
+ )
1022
+
1023
+ gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
1024
+
1025
+
1026
+ @click.command()
1027
+ @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
1028
+ @click.option("--host", "-H", default=None, help="Host to run the app on")
1029
+ @click.option(
1030
+ "--share",
1031
+ "-s",
1032
+ default=False,
1033
+ is_flag=True,
1034
+ help="Share the app via Gradio share link",
1035
+ )
1036
+ @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
1037
+ def main(port, host, share, api):
1038
+ global app
1039
+ print(f"Starting app...")
1040
+ app.queue(api_open=api).launch(
1041
+ server_name=host, server_port=port, share=share, show_api=api
1042
+ )
1043
+
1044
+
1045
+ if __name__ == "__main__":
1046
+ if not USING_SPACES:
1047
+ main()
1048
+ else:
1049
+ app.queue().launch()
model/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from model.cfm import CFM
2
+
3
+ from model.backbones.unett import UNetT
4
+ from model.backbones.dit import DiT
5
+ from model.backbones.mmdit import MMDiT
6
+
7
+ from model.trainer import Trainer
model/backbones/README.md ADDED
@@ -0,0 +1,20 @@
1
+ ## Backbones quick introduction
2
+
3
+
4
+ ### unett.py
5
+ - flat unet transformer
6
+ - structure same as in e2-tts & voicebox paper except using rotary pos emb
7
+ - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
8
+
9
+ ### dit.py
10
+ - adaln-zero dit
11
+ - embedded timestep as condition
12
+ - concatenated noised_input + masked_cond + embedded_text, linear proj in
13
+ - possible abs pos emb & convnextv2 blocks for embedded text before concat
14
+ - possible long skip connection (first layer to last layer); an instantiation sketch follows these notes
15
+
16
+ ### mmdit.py
17
+ - sd3 structure
18
+ - timestep as condition
19
+ - left stream: text embedded, with an abs pos emb applied
+ - right stream: masked_cond & noised_input concatenated, with the same conv pos emb as unett
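
To make the notes above concrete, here is a minimal sketch (not part of this commit) that instantiates the DiT backbone, defined in `model/backbones/dit.py` below, with the `F5TTS_Base` config used in `gradio_app.py`. It assumes the repository and its dependencies (`x_transformers`, `einops`, the blocks in `model.modules`) are installed, and it omits `text_num_embeds`, which the app sets from its tokenizer.

```python
# Sketch only: forward pass through the DiT backbone with random inputs.
import torch
from model import DiT

model = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

x = torch.randn(1, 128, 100)                  # noised mel input (b, n, mel_dim=100)
cond = torch.randn(1, 128, 100)               # masked conditioning mel
text = torch.zeros(1, 32, dtype=torch.long)   # padded text token ids
time = torch.rand(1)                          # flow-matching time step per batch item

out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
print(out.shape)  # torch.Size([1, 128, 100]) -- predicted mel frames
```
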
model/backbones/dit.py ADDED
@@ -0,0 +1,158 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+ import torch.nn.functional as F
15
+
16
+ from einops import repeat
17
+
18
+ from x_transformers.x_transformers import RotaryEmbedding
19
+
20
+ from model.modules import (
21
+ TimestepEmbedding,
22
+ ConvNeXtV2Block,
23
+ ConvPositionEmbedding,
24
+ DiTBlock,
25
+ AdaLayerNormZero_Final,
26
+ precompute_freqs_cis, get_pos_embed_indices,
27
+ )
28
+
29
+
30
+ # Text embedding
31
+
32
+ class TextEmbedding(nn.Module):
33
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
34
+ super().__init__()
35
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
+
37
+ if conv_layers > 0:
38
+ self.extra_modeling = True
39
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
40
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
41
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
42
+ else:
43
+ self.extra_modeling = False
44
+
45
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
46
+ batch, text_len = text.shape[0], text.shape[1]
47
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
48
+ text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
49
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
50
+
51
+ if drop_text: # cfg for text
52
+ text = torch.zeros_like(text)
53
+
54
+ text = self.text_embed(text) # b n -> b n d
55
+
56
+ # possible extra modeling
57
+ if self.extra_modeling:
58
+ # sinus pos emb
59
+ batch_start = torch.zeros((batch,), dtype=torch.long)
60
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
61
+ text_pos_embed = self.freqs_cis[pos_idx]
62
+ text = text + text_pos_embed
63
+
64
+ # convnextv2 blocks
65
+ text = self.text_blocks(text)
66
+
67
+ return text
68
+
69
+
70
+ # noised input audio and context mixing embedding
71
+
72
+ class InputEmbedding(nn.Module):
73
+ def __init__(self, mel_dim, text_dim, out_dim):
74
+ super().__init__()
75
+ self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
76
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
77
+
78
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
79
+ if drop_audio_cond: # cfg for cond audio
80
+ cond = torch.zeros_like(cond)
81
+
82
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
83
+ x = self.conv_pos_embed(x) + x
84
+ return x
85
+
86
+
87
+ # Transformer backbone using DiT blocks
88
+
89
+ class DiT(nn.Module):
90
+ def __init__(self, *,
91
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
92
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
93
+ long_skip_connection = False,
94
+ ):
95
+ super().__init__()
96
+
97
+ self.time_embed = TimestepEmbedding(dim)
98
+ if text_dim is None:
99
+ text_dim = mel_dim
100
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
101
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
102
+
103
+ self.rotary_embed = RotaryEmbedding(dim_head)
104
+
105
+ self.dim = dim
106
+ self.depth = depth
107
+
108
+ self.transformer_blocks = nn.ModuleList(
109
+ [
110
+ DiTBlock(
111
+ dim = dim,
112
+ heads = heads,
113
+ dim_head = dim_head,
114
+ ff_mult = ff_mult,
115
+ dropout = dropout
116
+ )
117
+ for _ in range(depth)
118
+ ]
119
+ )
120
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None
121
+
122
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
123
+ self.proj_out = nn.Linear(dim, mel_dim)
124
+
125
+ def forward(
126
+ self,
127
+ x: float['b n d'], # noised input audio
128
+ cond: float['b n d'], # masked cond audio
129
+ text: int['b nt'], # text
130
+ time: float['b'] | float[''], # time step
131
+ drop_audio_cond, # cfg for cond audio
132
+ drop_text, # cfg for text
133
+ mask: bool['b n'] | None = None,
134
+ ):
135
+ batch, seq_len = x.shape[0], x.shape[1]
136
+ if time.ndim == 0:
137
+ time = repeat(time, ' -> b', b = batch)
138
+
139
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
140
+ t = self.time_embed(time)
141
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
142
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
143
+
144
+ rope = self.rotary_embed.forward_from_seq_len(seq_len)
145
+
146
+ if self.long_skip_connection is not None:
147
+ residual = x
148
+
149
+ for block in self.transformer_blocks:
150
+ x = block(x, t, mask = mask, rope = rope)
151
+
152
+ if self.long_skip_connection is not None:
153
+ x = self.long_skip_connection(torch.cat((x, residual), dim = -1))
154
+
155
+ x = self.norm_out(x, t)
156
+ output = self.proj_out(x)
157
+
158
+ return output
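A minimal sketch of a DiT forward pass follows, showing the tensor shapes this backbone expects; all sizes are illustrative, and UNetT and MMDiT share the same forward signature.

```python
# Hypothetical DiT forward pass with placeholder sizes.
import torch
from model.backbones.dit import DiT

model = DiT(dim=512, depth=8, heads=8, mel_dim=100, text_num_embeds=256, conv_layers=4)

x = torch.randn(2, 300, 100)            # noised mel input, (b, n, mel_dim)
cond = torch.randn(2, 300, 100)         # masked reference mel, same shape
text = torch.randint(0, 256, (2, 120))  # character token ids, (b, nt); real batches pad with -1
time = torch.rand(2)                    # flow-matching time step per sample, in [0, 1]

pred = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
assert pred.shape == (2, 300, 100)      # predicted flow, same shape as the mel input
```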
model/backbones/mmdit.py ADDED
@@ -0,0 +1,136 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from einops import repeat
16
+
17
+ from x_transformers.x_transformers import RotaryEmbedding
18
+
19
+ from model.modules import (
20
+ TimestepEmbedding,
21
+ ConvPositionEmbedding,
22
+ MMDiTBlock,
23
+ AdaLayerNormZero_Final,
24
+ precompute_freqs_cis, get_pos_embed_indices,
25
+ )
26
+
27
+
28
+ # text embedding
29
+
30
+ class TextEmbedding(nn.Module):
31
+ def __init__(self, out_dim, text_num_embeds):
32
+ super().__init__()
33
+ self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34
+
35
+ self.precompute_max_pos = 1024
36
+ self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
+
38
+ def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
39
+ text = text + 1
40
+ if drop_text:
41
+ text = torch.zeros_like(text)
42
+ text = self.text_embed(text)
43
+
44
+ # sinus pos emb
45
+ batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
46
+ batch_text_len = text.shape[1]
47
+ pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
48
+ text_pos_embed = self.freqs_cis[pos_idx]
49
+
50
+ text = text + text_pos_embed
51
+
52
+ return text
53
+
54
+
55
+ # noised input & masked cond audio embedding
56
+
57
+ class AudioEmbedding(nn.Module):
58
+ def __init__(self, in_dim, out_dim):
59
+ super().__init__()
60
+ self.linear = nn.Linear(2 * in_dim, out_dim)
61
+ self.conv_pos_embed = ConvPositionEmbedding(out_dim)
62
+
63
+ def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
64
+ if drop_audio_cond:
65
+ cond = torch.zeros_like(cond)
66
+ x = torch.cat((x, cond), dim = -1)
67
+ x = self.linear(x)
68
+ x = self.conv_pos_embed(x) + x
69
+ return x
70
+
71
+
72
+ # Transformer backbone using MM-DiT blocks
73
+
74
+ class MMDiT(nn.Module):
75
+ def __init__(self, *,
76
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
77
+ text_num_embeds = 256, mel_dim = 100,
78
+ ):
79
+ super().__init__()
80
+
81
+ self.time_embed = TimestepEmbedding(dim)
82
+ self.text_embed = TextEmbedding(dim, text_num_embeds)
83
+ self.audio_embed = AudioEmbedding(mel_dim, dim)
84
+
85
+ self.rotary_embed = RotaryEmbedding(dim_head)
86
+
87
+ self.dim = dim
88
+ self.depth = depth
89
+
90
+ self.transformer_blocks = nn.ModuleList(
91
+ [
92
+ MMDiTBlock(
93
+ dim = dim,
94
+ heads = heads,
95
+ dim_head = dim_head,
96
+ dropout = dropout,
97
+ ff_mult = ff_mult,
98
+ context_pre_only = i == depth - 1,
99
+ )
100
+ for i in range(depth)
101
+ ]
102
+ )
103
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
104
+ self.proj_out = nn.Linear(dim, mel_dim)
105
+
106
+ def forward(
107
+ self,
108
+ x: float['b n d'], # noised input audio
109
+ cond: float['b n d'], # masked cond audio
110
+ text: int['b nt'], # text
111
+ time: float['b'] | float[''], # time step
112
+ drop_audio_cond, # cfg for cond audio
113
+ drop_text, # cfg for text
114
+ mask: bool['b n'] | None = None,
115
+ ):
116
+ batch = x.shape[0]
117
+ if time.ndim == 0:
118
+ time = repeat(time, ' -> b', b = batch)
119
+
120
+ # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
121
+ t = self.time_embed(time)
122
+ c = self.text_embed(text, drop_text = drop_text)
123
+ x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)
124
+
125
+ seq_len = x.shape[1]
126
+ text_len = text.shape[1]
127
+ rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
128
+ rope_text = self.rotary_embed.forward_from_seq_len(text_len)
129
+
130
+ for block in self.transformer_blocks:
131
+ c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)
132
+
133
+ x = self.norm_out(x, t)
134
+ output = self.proj_out(x)
135
+
136
+ return output
model/backbones/unett.py ADDED
@@ -0,0 +1,201 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Literal
12
+
13
+ import torch
14
+ from torch import nn
15
+ import torch.nn.functional as F
16
+
17
+ from einops import repeat, pack, unpack
18
+
19
+ from x_transformers import RMSNorm
20
+ from x_transformers.x_transformers import RotaryEmbedding
21
+
22
+ from model.modules import (
23
+ TimestepEmbedding,
24
+ ConvNeXtV2Block,
25
+ ConvPositionEmbedding,
26
+ Attention,
27
+ AttnProcessor,
28
+ FeedForward,
29
+ precompute_freqs_cis, get_pos_embed_indices,
30
+ )
31
+
32
+
33
+ # Text embedding
34
+
35
+ class TextEmbedding(nn.Module):
36
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
37
+ super().__init__()
38
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
+
40
+ if conv_layers > 0:
41
+ self.extra_modeling = True
42
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
45
+ else:
46
+ self.extra_modeling = False
47
+
48
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
49
+ batch, text_len = text.shape[0], text.shape[1]
50
+ text = text + 1 # use 0 as filler token; batch is padded with -1 in preprocessing, see list_str_to_idx()
51
+ text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
52
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
53
+
54
+ if drop_text: # cfg for text
55
+ text = torch.zeros_like(text)
56
+
57
+ text = self.text_embed(text) # b n -> b n d
58
+
59
+ # possible extra modeling
60
+ if self.extra_modeling:
61
+ # sinus pos emb
62
+ batch_start = torch.zeros((batch,), dtype=torch.long)
63
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
64
+ text_pos_embed = self.freqs_cis[pos_idx]
65
+ text = text + text_pos_embed
66
+
67
+ # convnextv2 blocks
68
+ text = self.text_blocks(text)
69
+
70
+ return text
71
+
72
+
73
+ # noised input audio and context mixing embedding
74
+
75
+ class InputEmbedding(nn.Module):
76
+ def __init__(self, mel_dim, text_dim, out_dim):
77
+ super().__init__()
78
+ self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
80
+
81
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
82
+ if drop_audio_cond: # cfg for cond audio
83
+ cond = torch.zeros_like(cond)
84
+
85
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
86
+ x = self.conv_pos_embed(x) + x
87
+ return x
88
+
89
+
90
+ # Flat UNet Transformer backbone
91
+
92
+ class UNetT(nn.Module):
93
+ def __init__(self, *,
94
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
95
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
96
+ skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
97
+ ):
98
+ super().__init__()
99
+ assert depth % 2 == 0, "UNet-Transformer's depth should be even."
100
+
101
+ self.time_embed = TimestepEmbedding(dim)
102
+ if text_dim is None:
103
+ text_dim = mel_dim
104
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
105
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
106
+
107
+ self.rotary_embed = RotaryEmbedding(dim_head)
108
+
109
+ # transformer layers & skip connections
110
+
111
+ self.dim = dim
112
+ self.skip_connect_type = skip_connect_type
113
+ needs_skip_proj = skip_connect_type == 'concat'
114
+
115
+ self.depth = depth
116
+ self.layers = nn.ModuleList([])
117
+
118
+ for idx in range(depth):
119
+ is_later_half = idx >= (depth // 2)
120
+
121
+ attn_norm = RMSNorm(dim)
122
+ attn = Attention(
123
+ processor = AttnProcessor(),
124
+ dim = dim,
125
+ heads = heads,
126
+ dim_head = dim_head,
127
+ dropout = dropout,
128
+ )
129
+
130
+ ff_norm = RMSNorm(dim)
131
+ ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
132
+
133
+ skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
134
+
135
+ self.layers.append(nn.ModuleList([
136
+ skip_proj,
137
+ attn_norm,
138
+ attn,
139
+ ff_norm,
140
+ ff,
141
+ ]))
142
+
143
+ self.norm_out = RMSNorm(dim)
144
+ self.proj_out = nn.Linear(dim, mel_dim)
145
+
146
+ def forward(
147
+ self,
148
+ x: float['b n d'], # noised input audio
149
+ cond: float['b n d'], # masked cond audio
150
+ text: int['b nt'], # text
151
+ time: float['b'] | float[''], # time step
152
+ drop_audio_cond, # cfg for cond audio
153
+ drop_text, # cfg for text
154
+ mask: bool['b n'] | None = None,
155
+ ):
156
+ batch, seq_len = x.shape[0], x.shape[1]
157
+ if time.ndim == 0:
158
+ time = repeat(time, ' -> b', b = batch)
159
+
160
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
161
+ t = self.time_embed(time)
162
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
163
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
164
+
165
+ # postfix time t to input x, [b n d] -> [b n+1 d]
166
+ x, ps = pack((t, x), 'b * d')
167
+ if mask is not None:
168
+ mask = F.pad(mask, (1, 0), value=1)
169
+
170
+ rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
171
+
172
+ # flat unet transformer
173
+ skip_connect_type = self.skip_connect_type
174
+ skips = []
175
+ for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
176
+ layer = idx + 1
177
+
178
+ # skip connection logic
179
+ is_first_half = layer <= (self.depth // 2)
180
+ is_later_half = not is_first_half
181
+
182
+ if is_first_half:
183
+ skips.append(x)
184
+
185
+ if is_later_half:
186
+ skip = skips.pop()
187
+ if skip_connect_type == 'concat':
188
+ x = torch.cat((x, skip), dim = -1)
189
+ x = maybe_skip_proj(x)
190
+ elif skip_connect_type == 'add':
191
+ x = x + skip
192
+
193
+ # attention and feedforward blocks
194
+ x = attn(attn_norm(x), rope = rope, mask = mask) + x
195
+ x = ff(ff_norm(x)) + x
196
+
197
+ assert len(skips) == 0
198
+
199
+ _, x = unpack(self.norm_out(x), ps, 'b * d')
200
+
201
+ return self.proj_out(x)
model/cfm.py ADDED
@@ -0,0 +1,279 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Callable
12
+ from random import random
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ from torch.nn.utils.rnn import pad_sequence
18
+
19
+ from torchdiffeq import odeint
20
+
21
+ from einops import rearrange
22
+
23
+ from model.modules import MelSpec
24
+
25
+ from model.utils import (
26
+ default, exists,
27
+ list_str_to_idx, list_str_to_tensor,
28
+ lens_to_mask, mask_from_frac_lengths,
29
+ )
30
+
31
+
32
+ class CFM(nn.Module):
33
+ def __init__(
34
+ self,
35
+ transformer: nn.Module,
36
+ sigma = 0.,
37
+ odeint_kwargs: dict = dict(
38
+ # atol = 1e-5,
39
+ # rtol = 1e-5,
40
+ method = 'euler' # 'midpoint'
41
+ ),
42
+ audio_drop_prob = 0.3,
43
+ cond_drop_prob = 0.2,
44
+ num_channels = None,
45
+ mel_spec_module: nn.Module | None = None,
46
+ mel_spec_kwargs: dict = dict(),
47
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.),
48
+ vocab_char_map: dict[str, int] | None = None
49
+ ):
50
+ super().__init__()
51
+
52
+ self.frac_lengths_mask = frac_lengths_mask
53
+
54
+ # mel spec
55
+ self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
56
+ num_channels = default(num_channels, self.mel_spec.n_mel_channels)
57
+ self.num_channels = num_channels
58
+
59
+ # classifier-free guidance
60
+ self.audio_drop_prob = audio_drop_prob
61
+ self.cond_drop_prob = cond_drop_prob
62
+
63
+ # transformer
64
+ self.transformer = transformer
65
+ dim = transformer.dim
66
+ self.dim = dim
67
+
68
+ # conditional flow related
69
+ self.sigma = sigma
70
+
71
+ # sampling related
72
+ self.odeint_kwargs = odeint_kwargs
73
+
74
+ # vocab map for tokenization
75
+ self.vocab_char_map = vocab_char_map
76
+
77
+ @property
78
+ def device(self):
79
+ return next(self.parameters()).device
80
+
81
+ @torch.no_grad()
82
+ def sample(
83
+ self,
84
+ cond: float['b n d'] | float['b nw'],
85
+ text: int['b nt'] | list[str],
86
+ duration: int | int['b'],
87
+ *,
88
+ lens: int['b'] | None = None,
89
+ steps = 32,
90
+ cfg_strength = 1.,
91
+ sway_sampling_coef = None,
92
+ seed: int | None = None,
93
+ max_duration = 4096,
94
+ vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
95
+ no_ref_audio = False,
96
+ duplicate_test = False,
97
+ t_inter = 0.1,
98
+ edit_mask = None,
99
+ ):
100
+ self.eval()
101
+
102
+ # raw wave
103
+
104
+ if cond.ndim == 2:
105
+ cond = self.mel_spec(cond)
106
+ cond = rearrange(cond, 'b d n -> b n d')
107
+ assert cond.shape[-1] == self.num_channels
108
+
109
+ batch, cond_seq_len, device = *cond.shape[:2], cond.device
110
+ if not exists(lens):
111
+ lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
112
+
113
+ # text
114
+
115
+ if isinstance(text, list):
116
+ if exists(self.vocab_char_map):
117
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
118
+ else:
119
+ text = list_str_to_tensor(text).to(device)
120
+ assert text.shape[0] == batch
121
+
122
+ if exists(text):
123
+ text_lens = (text != -1).sum(dim = -1)
124
+ lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
125
+
126
+ # duration
127
+
128
+ cond_mask = lens_to_mask(lens)
129
+ if edit_mask is not None:
130
+ cond_mask = cond_mask & edit_mask
131
+
132
+ if isinstance(duration, int):
133
+ duration = torch.full((batch,), duration, device = device, dtype = torch.long)
134
+
135
+ duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
136
+ duration = duration.clamp(max = max_duration)
137
+ max_duration = duration.amax()
138
+
139
+ # duplicate test corner for inner time step observation
140
+ if duplicate_test:
141
+ test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
142
+
143
+ cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
144
+ cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
145
+ cond_mask = rearrange(cond_mask, '... -> ... 1')
146
+ step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
147
+
148
+ if batch > 1:
149
+ mask = lens_to_mask(duration)
150
+ else: # save memory and speed up, as single inference needs no mask currently
151
+ mask = None
152
+
153
+ # test for no ref audio
154
+ if no_ref_audio:
155
+ cond = torch.zeros_like(cond)
156
+
157
+ # neural ode
158
+
159
+ def fn(t, x):
160
+ # at each step, conditioning is fixed
161
+ # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
162
+
163
+ # predict flow
164
+ pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
165
+ if cfg_strength < 1e-5:
166
+ return pred
167
+
168
+ null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
169
+ return pred + (pred - null_pred) * cfg_strength
170
+
171
+ # noise input
172
+ # to make sure batch inference result is same with different batch size, and for sure single inference
173
+ # still some difference maybe due to convolutional layers
174
+ y0 = []
175
+ for dur in duration:
176
+ if exists(seed):
177
+ torch.manual_seed(seed)
178
+ y0.append(torch.randn(dur, self.num_channels, device = self.device))
179
+ y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
180
+
181
+ t_start = 0
182
+
183
+ # duplicate test corner for inner time step observation
184
+ if duplicate_test:
185
+ t_start = t_inter
186
+ y0 = (1 - t_start) * y0 + t_start * test_cond
187
+ steps = int(steps * (1 - t_start))
188
+
189
+ t = torch.linspace(t_start, 1, steps, device = self.device)
190
+ if sway_sampling_coef is not None:
191
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
192
+
193
+ trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
194
+
195
+ sampled = trajectory[-1]
196
+ out = sampled
197
+ out = torch.where(cond_mask, cond, out)
198
+
199
+ if exists(vocoder):
200
+ out = rearrange(out, 'b n d -> b d n')
201
+ out = vocoder(out)
202
+
203
+ return out, trajectory
204
+
205
+ def forward(
206
+ self,
207
+ inp: float['b n d'] | float['b nw'], # mel or raw wave
208
+ text: int['b nt'] | list[str],
209
+ *,
210
+ lens: int['b'] | None = None,
211
+ noise_scheduler: str | None = None,
212
+ ):
213
+ # handle raw wave
214
+ if inp.ndim == 2:
215
+ inp = self.mel_spec(inp)
216
+ inp = rearrange(inp, 'b d n -> b n d')
217
+ assert inp.shape[-1] == self.num_channels
218
+
219
+ batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
220
+
221
+ # handle text as string
222
+ if isinstance(text, list):
223
+ if exists(self.vocab_char_map):
224
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
225
+ else:
226
+ text = list_str_to_tensor(text).to(device)
227
+ assert text.shape[0] == batch
228
+
229
+ # lens and mask
230
+ if not exists(lens):
231
+ lens = torch.full((batch,), seq_len, device = device)
232
+
233
+ mask = lens_to_mask(lens, length = seq_len) # useless here, as collate_fn will pad to max length in batch
234
+
235
+ # get a random span to mask out for training conditionally
236
+ frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
237
+ rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
238
+
239
+ if exists(mask):
240
+ rand_span_mask &= mask
241
+
242
+ # mel is x1
243
+ x1 = inp
244
+
245
+ # x0 is gaussian noise
246
+ x0 = torch.randn_like(x1)
247
+
248
+ # time step
249
+ time = torch.rand((batch,), dtype = dtype, device = self.device)
250
+ # TODO. noise_scheduler
251
+
252
+ # sample xt (φ_t(x) in the paper)
253
+ t = rearrange(time, 'b -> b 1 1')
254
+ φ = (1 - t) * x0 + t * x1
255
+ flow = x1 - x0
256
+
257
+ # only predict what is within the random mask span for infilling
258
+ cond = torch.where(
259
+ rand_span_mask[..., None],
260
+ torch.zeros_like(x1), x1
261
+ )
262
+
263
+ # transformer and cfg training with a drop rate
264
+ drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
265
+ if random() < self.cond_drop_prob: # p_uncond in voicebox paper
266
+ drop_audio_cond = True
267
+ drop_text = True
268
+ else:
269
+ drop_text = False
270
+
271
+ # if you want to rigorously mask out padding, record it in collate_fn in dataset.py and pass it in here
272
+ # adding a mask uses more memory, so also scale down the batch sampler threshold for long sequences
273
+ pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
274
+
275
+ # flow matching loss
276
+ loss = F.mse_loss(pred, flow, reduction = 'none')
277
+ loss = loss[rand_span_mask]
278
+
279
+ return loss.mean(), cond, pred
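To tie the pieces together, here is a minimal, hypothetical sketch of wrapping a backbone in CFM and calling sample(), based on the signatures in this file; the backbone size, guidance strength, sway coefficient, duration, and reference audio/text are placeholder stand-ins, not released settings.

```python
# Hypothetical end-to-end sampling sketch (untrained weights, placeholder values).
import torch
from model import CFM, DiT

model = CFM(
    transformer=DiT(dim=1024, depth=22, heads=16, text_dim=512, conv_layers=4),
    mel_spec_kwargs=dict(target_sample_rate=24_000, n_mel_channels=100, hop_length=256),
    vocab_char_map=None,  # or a dict mapping characters to token ids
)

ref_audio = torch.randn(1, 24_000 * 5)   # stand-in for 5 s of reference audio at 24 kHz
text = ["reference transcript. text to generate."]

out, trajectory = model.sample(
    cond=ref_audio,
    text=text,
    duration=1000,            # target length in mel frames (placeholder)
    steps=32,
    cfg_strength=2.0,
    sway_sampling_coef=-1.0,
)
# out: (1, 1000, 100) mel frames, unless a vocoder callable is passed in
```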
model/dataset.py ADDED
@@ -0,0 +1,241 @@
1
+ import json
2
+ import random
3
+ from tqdm import tqdm
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import Dataset, Sampler
8
+ import torchaudio
9
+ from datasets import load_dataset, load_from_disk
10
+ from datasets import Dataset as Dataset_
11
+ from typing import Union
12
+
13
+ from einops import rearrange
14
+
15
+ from model.modules import MelSpec
16
+
17
+
18
+ class HFDataset(Dataset):
19
+ def __init__(
20
+ self,
21
+ hf_dataset: Dataset,
22
+ target_sample_rate = 24_000,
23
+ n_mel_channels = 100,
24
+ hop_length = 256,
25
+ ):
26
+ self.data = hf_dataset
27
+ self.target_sample_rate = target_sample_rate
28
+ self.hop_length = hop_length
29
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
30
+
31
+ def get_frame_len(self, index):
32
+ row = self.data[index]
33
+ audio = row['audio']['array']
34
+ sample_rate = row['audio']['sampling_rate']
35
+ return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
36
+
37
+ def __len__(self):
38
+ return len(self.data)
39
+
40
+ def __getitem__(self, index):
41
+ row = self.data[index]
42
+ audio = row['audio']['array']
43
+
44
+ # logger.info(f"Audio shape: {audio.shape}")
45
+
46
+ sample_rate = row['audio']['sampling_rate']
47
+ duration = audio.shape[-1] / sample_rate
48
+
49
+ if duration > 30 or duration < 0.3:
50
+ return self.__getitem__((index + 1) % len(self.data))
51
+
52
+ audio_tensor = torch.from_numpy(audio).float()
53
+
54
+ if sample_rate != self.target_sample_rate:
55
+ resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
56
+ audio_tensor = resampler(audio_tensor)
57
+
58
+ audio_tensor = rearrange(audio_tensor, 't -> 1 t')
59
+
60
+ mel_spec = self.mel_spectrogram(audio_tensor)
61
+
62
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
63
+
64
+ text = row['text']
65
+
66
+ return dict(
67
+ mel_spec = mel_spec,
68
+ text = text,
69
+ )
70
+
71
+ class CustomDataset(Dataset):
72
+ def __init__(
73
+ self,
74
+ custom_dataset: Dataset,
75
+ durations = None,
76
+ target_sample_rate = 24_000,
77
+ hop_length = 256,
78
+ n_mel_channels = 100,
79
+ preprocessed_mel = False,
80
+ ):
81
+ self.data = custom_dataset
82
+ self.durations = durations
83
+ self.target_sample_rate = target_sample_rate
84
+ self.hop_length = hop_length
85
+ self.preprocessed_mel = preprocessed_mel
86
+ if not preprocessed_mel:
87
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
88
+
89
+ def get_frame_len(self, index):
90
+ if self.durations is not None: # make sure the separately provided durations are correct, otherwise you will almost certainly OOM
91
+ return self.durations[index] * self.target_sample_rate / self.hop_length
92
+ return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
93
+
94
+ def __len__(self):
95
+ return len(self.data)
96
+
97
+ def __getitem__(self, index):
98
+ row = self.data[index]
99
+ audio_path = row["audio_path"]
100
+ text = row["text"]
101
+ duration = row["duration"]
102
+
103
+ if self.preprocessed_mel:
104
+ mel_spec = torch.tensor(row["mel_spec"])
105
+
106
+ else:
107
+ audio, source_sample_rate = torchaudio.load(audio_path)
108
+
109
+ if duration > 30 or duration < 0.3:
110
+ return self.__getitem__((index + 1) % len(self.data))
111
+
112
+ if source_sample_rate != self.target_sample_rate:
113
+ resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
114
+ audio = resampler(audio)
115
+
116
+ mel_spec = self.mel_spectrogram(audio)
117
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
118
+
119
+ return dict(
120
+ mel_spec = mel_spec,
121
+ text = text,
122
+ )
123
+
124
+
125
+ # Dynamic Batch Sampler
126
+
127
+ class DynamicBatchSampler(Sampler[list[int]]):
128
+ """ Extension of Sampler that will do the following:
129
+ 1. Change the batch size (essentially number of sequences)
130
+ in a batch to ensure that the total number of frames are less
131
+ than a certain threshold.
132
+ 2. Make sure the padding efficiency in the batch is high.
133
+ """
134
+
135
+ def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
136
+ self.sampler = sampler
137
+ self.frames_threshold = frames_threshold
138
+ self.max_samples = max_samples
139
+
140
+ indices, batches = [], []
141
+ data_source = self.sampler.data_source
142
+
143
+ for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):
144
+ indices.append((idx, data_source.get_frame_len(idx)))
145
+ indices.sort(key=lambda elem : elem[1])
146
+
147
+ batch = []
148
+ batch_frames = 0
149
+ for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
150
+ if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
151
+ batch.append(idx)
152
+ batch_frames += frame_len
153
+ else:
154
+ if len(batch) > 0:
155
+ batches.append(batch)
156
+ if frame_len <= self.frames_threshold:
157
+ batch = [idx]
158
+ batch_frames = frame_len
159
+ else:
160
+ batch = []
161
+ batch_frames = 0
162
+
163
+ if not drop_last and len(batch) > 0:
164
+ batches.append(batch)
165
+
166
+ del indices
167
+
168
+ # to get different batches between epochs, just set a seed and log it in the checkpoint
170
+ # because during multi-gpu training, although the per-gpu batches do not change between epochs, the formed overall minibatch differs
170
+ # e.g. for epoch n, use (random_seed + n)
171
+ random.seed(random_seed)
172
+ random.shuffle(batches)
173
+
174
+ self.batches = batches
175
+
176
+ def __iter__(self):
177
+ return iter(self.batches)
178
+
179
+ def __len__(self):
180
+ return len(self.batches)
181
+
182
+
183
+ # Load dataset
184
+ def load_dataset(
185
+ dataset_name: str,
186
+ tokenizer: str,
187
+ dataset_type: str = "CustomDataset",
188
+ audio_type: str = "raw",
189
+ mel_spec_kwargs: dict = dict()
190
+ ) -> Union[CustomDataset, HFDataset]:
191
+
192
+ print("Loading dataset ...")
193
+
194
+ if dataset_type == "CustomDataset":
195
+ if audio_type == "raw":
196
+ try:
197
+ train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
198
+ except:
199
+ train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
200
+ preprocessed_mel = False
201
+ elif audio_type == "mel":
202
+ train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
203
+ preprocessed_mel = True
204
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
205
+ data_dict = json.load(f)
206
+ durations = data_dict["duration"]
207
+ train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
208
+
209
+ elif dataset_type == "HFDataset":
210
+ print("Should manually modify the path of huggingface dataset to your need.\n" +
211
+ "May also the corresponding script cuz different dataset may have different format.")
212
+ pre, post = dataset_name.split("_")
213
+ train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
214
+
215
+ return train_dataset
216
+
217
+
218
+ # collation
219
+
220
+ def collate_fn(batch):
221
+ mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
222
+ mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
223
+ max_mel_length = mel_lengths.amax()
224
+
225
+ padded_mel_specs = []
226
+ for spec in mel_specs: # TODO. maybe records mask for attention here
227
+ padding = (0, max_mel_length - spec.size(-1))
228
+ padded_spec = F.pad(spec, padding, value = 0)
229
+ padded_mel_specs.append(padded_spec)
230
+
231
+ mel_specs = torch.stack(padded_mel_specs)
232
+
233
+ text = [item['text'] for item in batch]
234
+ text_lengths = torch.LongTensor([len(item) for item in text])
235
+
236
+ return dict(
237
+ mel = mel_specs,
238
+ mel_lengths = mel_lengths,
239
+ text = text,
240
+ text_lengths = text_lengths,
241
+ )
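A minimal, hypothetical sketch of wiring these pieces into a DataLoader follows; the dataset name, tokenizer tag, and frames_threshold are placeholders, not the values used for any particular run.

```python
# Hypothetical training DataLoader using the dynamic batch sampler and collate_fn above.
from torch.utils.data import DataLoader, SequentialSampler
from model.dataset import DynamicBatchSampler, collate_fn, load_dataset

train_dataset = load_dataset("Emilia_ZH_EN", tokenizer="pinyin")   # placeholder names
sampler = SequentialSampler(train_dataset)
batch_sampler = DynamicBatchSampler(
    sampler,
    frames_threshold=38400,   # max total mel frames per batch (placeholder)
    max_samples=64,
    random_seed=666,
)
train_loader = DataLoader(
    train_dataset,
    batch_sampler=batch_sampler,
    collate_fn=collate_fn,
    num_workers=4,
)
batch = next(iter(train_loader))
mel, mel_lengths, texts = batch["mel"], batch["mel_lengths"], batch["text"]
```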
model/ecapa_tdnn.py ADDED
@@ -0,0 +1,268 @@
1
+ # just for speaker similarity evaluation, third-party code
2
+
3
+ # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
4
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
5
+
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ ''' Res2Conv1d + BatchNorm1d + ReLU
13
+ '''
14
+
15
+ class Res2Conv1dReluBn(nn.Module):
16
+ '''
17
+ in_channels == out_channels == channels
18
+ '''
19
+
20
+ def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
21
+ super().__init__()
22
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
23
+ self.scale = scale
24
+ self.width = channels // scale
25
+ self.nums = scale if scale == 1 else scale - 1
26
+
27
+ self.convs = []
28
+ self.bns = []
29
+ for i in range(self.nums):
30
+ self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
31
+ self.bns.append(nn.BatchNorm1d(self.width))
32
+ self.convs = nn.ModuleList(self.convs)
33
+ self.bns = nn.ModuleList(self.bns)
34
+
35
+ def forward(self, x):
36
+ out = []
37
+ spx = torch.split(x, self.width, 1)
38
+ for i in range(self.nums):
39
+ if i == 0:
40
+ sp = spx[i]
41
+ else:
42
+ sp = sp + spx[i]
43
+ # Order: conv -> relu -> bn
44
+ sp = self.convs[i](sp)
45
+ sp = self.bns[i](F.relu(sp))
46
+ out.append(sp)
47
+ if self.scale != 1:
48
+ out.append(spx[self.nums])
49
+ out = torch.cat(out, dim=1)
50
+
51
+ return out
52
+
53
+
54
+ ''' Conv1d + BatchNorm1d + ReLU
55
+ '''
56
+
57
+ class Conv1dReluBn(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
59
+ super().__init__()
60
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
61
+ self.bn = nn.BatchNorm1d(out_channels)
62
+
63
+ def forward(self, x):
64
+ return self.bn(F.relu(self.conv(x)))
65
+
66
+
67
+ ''' The SE connection of 1D case.
68
+ '''
69
+
70
+ class SE_Connect(nn.Module):
71
+ def __init__(self, channels, se_bottleneck_dim=128):
72
+ super().__init__()
73
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
74
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
75
+
76
+ def forward(self, x):
77
+ out = x.mean(dim=2)
78
+ out = F.relu(self.linear1(out))
79
+ out = torch.sigmoid(self.linear2(out))
80
+ out = x * out.unsqueeze(2)
81
+
82
+ return out
83
+
84
+
85
+ ''' SE-Res2Block of the ECAPA-TDNN architecture.
86
+ '''
87
+
88
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
89
+ # return nn.Sequential(
90
+ # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
91
+ # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
92
+ # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
93
+ # SE_Connect(channels)
94
+ # )
95
+
96
+ class SE_Res2Block(nn.Module):
97
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
98
+ super().__init__()
99
+ self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
100
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
101
+ self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
102
+ self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
103
+
104
+ self.shortcut = None
105
+ if in_channels != out_channels:
106
+ self.shortcut = nn.Conv1d(
107
+ in_channels=in_channels,
108
+ out_channels=out_channels,
109
+ kernel_size=1,
110
+ )
111
+
112
+ def forward(self, x):
113
+ residual = x
114
+ if self.shortcut:
115
+ residual = self.shortcut(x)
116
+
117
+ x = self.Conv1dReluBn1(x)
118
+ x = self.Res2Conv1dReluBn(x)
119
+ x = self.Conv1dReluBn2(x)
120
+ x = self.SE_Connect(x)
121
+
122
+ return x + residual
123
+
124
+
125
+ ''' Attentive weighted mean and standard deviation pooling.
126
+ '''
127
+
128
+ class AttentiveStatsPool(nn.Module):
129
+ def __init__(self, in_dim, attention_channels=128, global_context_att=False):
130
+ super().__init__()
131
+ self.global_context_att = global_context_att
132
+
133
+ # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
134
+ if global_context_att:
135
+ self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
136
+ else:
137
+ self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
138
+ self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
139
+
140
+ def forward(self, x):
141
+
142
+ if self.global_context_att:
143
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
144
+ context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
145
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
146
+ else:
147
+ x_in = x
148
+
149
+ # DON'T use ReLU here! In experiments, I found it hard to get ReLU to converge.
150
+ alpha = torch.tanh(self.linear1(x_in))
151
+ # alpha = F.relu(self.linear1(x_in))
152
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
153
+ mean = torch.sum(alpha * x, dim=2)
154
+ residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
155
+ std = torch.sqrt(residuals.clamp(min=1e-9))
156
+ return torch.cat([mean, std], dim=1)
157
+
158
+
159
+ class ECAPA_TDNN(nn.Module):
160
+ def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
161
+ feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
162
+ super().__init__()
163
+
164
+ self.feat_type = feat_type
165
+ self.feature_selection = feature_selection
166
+ self.update_extract = update_extract
167
+ self.sr = sr
168
+
169
+ torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
170
+ try:
171
+ local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
172
+ self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
173
+ except:
174
+ self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
175
+
176
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
177
+ self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
178
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
179
+ self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
180
+
181
+ self.feat_num = self.get_feat_num()
182
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
183
+
184
+ if feat_type != 'fbank' and feat_type != 'mfcc':
185
+ freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
186
+ for name, param in self.feature_extract.named_parameters():
187
+ for freeze_val in freeze_list:
188
+ if freeze_val in name:
189
+ param.requires_grad = False
190
+ break
191
+
192
+ if not self.update_extract:
193
+ for param in self.feature_extract.parameters():
194
+ param.requires_grad = False
195
+
196
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
197
+ # self.channels = [channels] * 4 + [channels * 3]
198
+ self.channels = [channels] * 4 + [1536]
199
+
200
+ self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
201
+ self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
202
+ self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
203
+ self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
204
+
205
+ # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
206
+ cat_channels = channels * 3
207
+ self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
208
+ self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
209
+ self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
210
+ self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
211
+
212
+
213
+ def get_feat_num(self):
214
+ self.feature_extract.eval()
215
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
216
+ with torch.no_grad():
217
+ features = self.feature_extract(wav)
218
+ select_feature = features[self.feature_selection]
219
+ if isinstance(select_feature, (list, tuple)):
220
+ return len(select_feature)
221
+ else:
222
+ return 1
223
+
224
+ def get_feat(self, x):
225
+ if self.update_extract:
226
+ x = self.feature_extract([sample for sample in x])
227
+ else:
228
+ with torch.no_grad():
229
+ if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
230
+ x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
231
+ else:
232
+ x = self.feature_extract([sample for sample in x])
233
+
234
+ if self.feat_type == 'fbank':
235
+ x = x.log()
236
+
237
+ if self.feat_type != "fbank" and self.feat_type != "mfcc":
238
+ x = x[self.feature_selection]
239
+ if isinstance(x, (list, tuple)):
240
+ x = torch.stack(x, dim=0)
241
+ else:
242
+ x = x.unsqueeze(0)
243
+ norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
244
+ x = (norm_weights * x).sum(dim=0)
245
+ x = torch.transpose(x, 1, 2) + 1e-6
246
+
247
+ x = self.instance_norm(x)
248
+ return x
249
+
250
+ def forward(self, x):
251
+ x = self.get_feat(x)
252
+
253
+ out1 = self.layer1(x)
254
+ out2 = self.layer2(out1)
255
+ out3 = self.layer3(out2)
256
+ out4 = self.layer4(out3)
257
+
258
+ out = torch.cat([out2, out3, out4], dim=1)
259
+ out = F.relu(self.conv(out))
260
+ out = self.bn(self.pooling(out))
261
+ out = self.linear(out)
262
+
263
+ return out
264
+
265
+
266
+ def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
267
+ return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
268
+ feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
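As a usage note, a minimal sketch of speaker-similarity evaluation with this model follows; the checkpoint path and its 'model' key are assumptions based on the upstream UniSpeech verification checkpoints, and the waveforms are random stand-ins.

```python
# Hypothetical speaker-similarity check: embed two 16 kHz utterances, compare with cosine similarity.
import torch
import torch.nn.functional as F
from model.ecapa_tdnn import ECAPA_TDNN_SMALL

model = ECAPA_TDNN_SMALL(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
state = torch.load("ckpts/wavlm_large_finetune.pth", map_location="cpu")  # assumed path and format
model.load_state_dict(state['model'], strict=False)
model.eval()

wav1 = torch.randn(1, 16_000 * 3)   # stand-ins for two 3 s utterances at 16 kHz
wav2 = torch.randn(1, 16_000 * 3)
with torch.no_grad():
    emb1, emb2 = model(wav1), model(wav2)
similarity = F.cosine_similarity(emb1, emb2).item()
```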
model/modules.py ADDED
@@ -0,0 +1,575 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Optional
12
+ import math
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ import torchaudio
18
+
19
+ from einops import rearrange
20
+ from x_transformers.x_transformers import apply_rotary_pos_emb
21
+
22
+
23
+ # raw wav to mel spec
24
+
25
+ class MelSpec(nn.Module):
26
+ def __init__(
27
+ self,
28
+ filter_length = 1024,
29
+ hop_length = 256,
30
+ win_length = 1024,
31
+ n_mel_channels = 100,
32
+ target_sample_rate = 24_000,
33
+ normalize = False,
34
+ power = 1,
35
+ norm = None,
36
+ center = True,
37
+ ):
38
+ super().__init__()
39
+ self.n_mel_channels = n_mel_channels
40
+
41
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
42
+ sample_rate = target_sample_rate,
43
+ n_fft = filter_length,
44
+ win_length = win_length,
45
+ hop_length = hop_length,
46
+ n_mels = n_mel_channels,
47
+ power = power,
48
+ center = center,
49
+ normalized = normalize,
50
+ norm = norm,
51
+ )
52
+
53
+ self.register_buffer('dummy', torch.tensor(0), persistent = False)
54
+
55
+ def forward(self, inp):
56
+ if len(inp.shape) == 3:
57
+ inp = rearrange(inp, 'b 1 nw -> b nw')
58
+
59
+ assert len(inp.shape) == 2
60
+
61
+ if self.dummy.device != inp.device:
62
+ self.to(inp.device)
63
+
64
+ mel = self.mel_stft(inp)
65
+ mel = mel.clamp(min = 1e-5).log()
66
+ return mel
67
+
68
+
69
+ # sinusoidal position embedding
70
+
71
+ class SinusPositionEmbedding(nn.Module):
72
+ def __init__(self, dim):
73
+ super().__init__()
74
+ self.dim = dim
75
+
76
+ def forward(self, x, scale=1000):
77
+ device = x.device
78
+ half_dim = self.dim // 2
79
+ emb = math.log(10000) / (half_dim - 1)
80
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
81
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
82
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
83
+ return emb
84
+
85
+
86
+ # convolutional position embedding
87
+
88
+ class ConvPositionEmbedding(nn.Module):
89
+ def __init__(self, dim, kernel_size = 31, groups = 16):
90
+ super().__init__()
91
+ assert kernel_size % 2 != 0
92
+ self.conv1d = nn.Sequential(
93
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
94
+ nn.Mish(),
95
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
96
+ nn.Mish(),
97
+ )
98
+
99
+ def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
100
+ if mask is not None:
101
+ mask = mask[..., None]
102
+ x = x.masked_fill(~mask, 0.)
103
+
104
+ x = rearrange(x, 'b n d -> b d n')
105
+ x = self.conv1d(x)
106
+ out = rearrange(x, 'b d n -> b n d')
107
+
108
+ if mask is not None:
109
+ out = out.masked_fill(~mask, 0.)
110
+
111
+ return out
112
+
113
+
114
+ # rotary positional embedding related
115
+
116
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
117
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
118
+ # has some connection to NTK literature
119
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
120
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
121
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
122
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
123
+ t = torch.arange(end, device=freqs.device) # type: ignore
124
+ freqs = torch.outer(t, freqs).float() # type: ignore
125
+ freqs_cos = torch.cos(freqs) # real part
126
+ freqs_sin = torch.sin(freqs) # imaginary part
127
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
128
+
129
+ def get_pos_embed_indices(start, length, max_pos, scale=1.):
130
+ # length = length if isinstance(length, int) else length.max()
131
+ scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
132
+ pos = start.unsqueeze(1) + (
133
+ torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
134
+ scale.unsqueeze(1)).long()
135
+ # avoid extra long error.
136
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
137
+ return pos
138
+
139
+
140
+ # Global Response Normalization layer (Instance Normalization ?)
141
+
142
+ class GRN(nn.Module):
143
+ def __init__(self, dim):
144
+ super().__init__()
145
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
146
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
147
+
148
+ def forward(self, x):
149
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
150
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
151
+ return self.gamma * (x * Nx) + self.beta + x
152
+
153
+
154
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
155
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
156
+
157
+ class ConvNeXtV2Block(nn.Module):
158
+ def __init__(
159
+ self,
160
+ dim: int,
161
+ intermediate_dim: int,
162
+ dilation: int = 1,
163
+ ):
164
+ super().__init__()
165
+ padding = (dilation * (7 - 1)) // 2
166
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
167
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
168
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
169
+ self.act = nn.GELU()
170
+ self.grn = GRN(intermediate_dim)
171
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
172
+
173
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
174
+ residual = x
175
+ x = x.transpose(1, 2) # b n d -> b d n
176
+ x = self.dwconv(x)
177
+ x = x.transpose(1, 2) # b d n -> b n d
178
+ x = self.norm(x)
179
+ x = self.pwconv1(x)
180
+ x = self.act(x)
181
+ x = self.grn(x)
182
+ x = self.pwconv2(x)
183
+ return residual + x
184
+
185
+
186
+ # AdaLayerNormZero
187
+ # return with modulated x for attn input, and params for later mlp modulation
188
+
189
+ class AdaLayerNormZero(nn.Module):
190
+ def __init__(self, dim):
191
+ super().__init__()
192
+
193
+ self.silu = nn.SiLU()
194
+ self.linear = nn.Linear(dim, dim * 6)
195
+
196
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
197
+
198
+ def forward(self, x, emb = None):
199
+ emb = self.linear(self.silu(emb))
200
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
201
+
202
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
203
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
204
+
205
+
206
+ # AdaLayerNormZero for final layer
207
+ # return only with modulated x for attn input, cuz no more mlp modulation
208
+
209
+ class AdaLayerNormZero_Final(nn.Module):
210
+ def __init__(self, dim):
211
+ super().__init__()
212
+
213
+ self.silu = nn.SiLU()
214
+ self.linear = nn.Linear(dim, dim * 2)
215
+
216
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
217
+
218
+ def forward(self, x, emb):
219
+ emb = self.linear(self.silu(emb))
220
+ scale, shift = torch.chunk(emb, 2, dim=1)
221
+
222
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
223
+ return x
224
+
225
+
226
+ # FeedForward
227
+
228
+ class FeedForward(nn.Module):
229
+ def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
230
+ super().__init__()
231
+ inner_dim = int(dim * mult)
232
+ dim_out = dim_out if dim_out is not None else dim
233
+
234
+ activation = nn.GELU(approximate=approximate)
235
+ project_in = nn.Sequential(
236
+ nn.Linear(dim, inner_dim),
237
+ activation
238
+ )
239
+ self.ff = nn.Sequential(
240
+ project_in,
241
+ nn.Dropout(dropout),
242
+ nn.Linear(inner_dim, dim_out)
243
+ )
244
+
245
+ def forward(self, x):
246
+ return self.ff(x)
247
+
248
+
249
+ # Attention with possible joint part
250
+ # modified from diffusers/src/diffusers/models/attention_processor.py
251
+
252
+ class Attention(nn.Module):
253
+ def __init__(
254
+ self,
255
+ processor: JointAttnProcessor | AttnProcessor,
256
+ dim: int,
257
+ heads: int = 8,
258
+ dim_head: int = 64,
259
+ dropout: float = 0.0,
260
+ context_dim: Optional[int] = None, # if not None -> joint attention
261
+ context_pre_only = None,
262
+ ):
263
+ super().__init__()
264
+
265
+ if not hasattr(F, "scaled_dot_product_attention"):
266
+ raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
267
+
268
+ self.processor = processor
269
+
270
+ self.dim = dim
271
+ self.heads = heads
272
+ self.inner_dim = dim_head * heads
273
+ self.dropout = dropout
274
+
275
+ self.context_dim = context_dim
276
+ self.context_pre_only = context_pre_only
277
+
278
+ self.to_q = nn.Linear(dim, self.inner_dim)
279
+ self.to_k = nn.Linear(dim, self.inner_dim)
280
+ self.to_v = nn.Linear(dim, self.inner_dim)
281
+
282
+ if self.context_dim is not None:
283
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
284
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
285
+ if self.context_pre_only is not None:
286
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
287
+
288
+ self.to_out = nn.ModuleList([])
289
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
290
+ self.to_out.append(nn.Dropout(dropout))
291
+
292
+ if self.context_pre_only is not None and not self.context_pre_only:
293
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
294
+
295
+ def forward(
296
+ self,
297
+ x: float['b n d'], # noised input x
298
+ c: float['b n d'] = None, # context c
299
+ mask: bool['b n'] | None = None,
300
+ rope = None, # rotary position embedding for x
301
+ c_rope = None, # rotary position embedding for c
302
+ ) -> torch.Tensor:
303
+ if c is not None:
304
+ return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
305
+ else:
306
+ return self.processor(self, x, mask = mask, rope = rope)
307
+
308
+
309
+ # Attention processor
310
+
311
+ class AttnProcessor:
312
+ def __init__(self):
313
+ pass
314
+
315
+ def __call__(
316
+ self,
317
+ attn: Attention,
318
+ x: float['b n d'], # noised input x
319
+ mask: bool['b n'] | None = None,
320
+ rope = None, # rotary position embedding
321
+ ) -> torch.FloatTensor:
322
+
323
+ batch_size = x.shape[0]
324
+
325
+ # `sample` projections.
326
+ query = attn.to_q(x)
327
+ key = attn.to_k(x)
328
+ value = attn.to_v(x)
329
+
330
+ # apply rotary position embedding
331
+ if rope is not None:
332
+ freqs, xpos_scale = rope
333
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
334
+
335
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
336
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
337
+
338
+ # attention
339
+ inner_dim = key.shape[-1]
340
+ head_dim = inner_dim // attn.heads
341
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
342
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
343
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
344
+
345
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
346
+ if mask is not None:
347
+ attn_mask = mask
348
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
349
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
350
+ else:
351
+ attn_mask = None
352
+
353
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
354
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
355
+ x = x.to(query.dtype)
356
+
357
+ # linear proj
358
+ x = attn.to_out[0](x)
359
+ # dropout
360
+ x = attn.to_out[1](x)
361
+
362
+ if mask is not None:
363
+ mask = rearrange(mask, 'b n -> b n 1')
364
+ x = x.masked_fill(~mask, 0.)
365
+
366
+ return x
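# Illustrative usage sketch (all sizes are placeholders; torch / nn are assumed imported at the
# top of this module). It mirrors how DiTBlock below wires Attention with AttnProcessor:
# x is (batch, seq, dim), mask is (batch, seq) bool, and padded positions come back zeroed.
attn = Attention(processor = AttnProcessor(), dim = 512, heads = 8, dim_head = 64, dropout = 0.1)
x = torch.randn(2, 100, 512)
mask = torch.ones(2, 100, dtype = torch.bool)
mask[1, 80:] = False                      # second sample padded after 80 frames
out = attn(x = x, mask = mask)            # (2, 100, 512)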
367
+
368
+
369
+ # Joint Attention processor for MM-DiT
370
+ # modified from diffusers/src/diffusers/models/attention_processor.py
371
+
372
+ class JointAttnProcessor:
373
+ def __init__(self):
374
+ pass
375
+
376
+ def __call__(
377
+ self,
378
+ attn: Attention,
379
+ x: float['b n d'], # noised input x
380
+ c: float['b nt d'] = None, # context c, here text
381
+ mask: bool['b n'] | None = None,
382
+ rope = None, # rotary position embedding for x
383
+ c_rope = None, # rotary position embedding for c
384
+ ) -> torch.FloatTensor:
385
+ residual = x
386
+
387
+ batch_size = c.shape[0]
388
+
389
+ # `sample` projections.
390
+ query = attn.to_q(x)
391
+ key = attn.to_k(x)
392
+ value = attn.to_v(x)
393
+
394
+ # `context` projections.
395
+ c_query = attn.to_q_c(c)
396
+ c_key = attn.to_k_c(c)
397
+ c_value = attn.to_v_c(c)
398
+
399
+ # apply rope for context and noised input independently
400
+ if rope is not None:
401
+ freqs, xpos_scale = rope
402
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
403
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
404
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
405
+ if c_rope is not None:
406
+ freqs, xpos_scale = c_rope
407
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
408
+ c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
409
+ c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
410
+
411
+ # attention
412
+ query = torch.cat([query, c_query], dim=1)
413
+ key = torch.cat([key, c_key], dim=1)
414
+ value = torch.cat([value, c_value], dim=1)
415
+
416
+ inner_dim = key.shape[-1]
417
+ head_dim = inner_dim // attn.heads
418
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
419
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
420
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
421
+
422
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
423
+ if mask is not None:
424
+ attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
425
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
426
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
427
+ else:
428
+ attn_mask = None
429
+
430
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
431
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
432
+ x = x.to(query.dtype)
433
+
434
+ # Split the attention outputs.
435
+ x, c = (
436
+ x[:, :residual.shape[1]],
437
+ x[:, residual.shape[1]:],
438
+ )
439
+
440
+ # linear proj
441
+ x = attn.to_out[0](x)
442
+ # dropout
443
+ x = attn.to_out[1](x)
444
+ if not attn.context_pre_only:
445
+ c = attn.to_out_c(c)
446
+
447
+ if mask is not None:
448
+ mask = rearrange(mask, 'b n -> b n 1')
449
+ x = x.masked_fill(~mask, 0.)
450
+ # c = c.masked_fill(~mask, 0.) # no mask for c (text)
451
+
452
+ return x, c
453
+
454
+
455
+ # DiT Block
456
+
457
+ class DiTBlock(nn.Module):
458
+
459
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
460
+ super().__init__()
461
+
462
+ self.attn_norm = AdaLayerNormZero(dim)
463
+ self.attn = Attention(
464
+ processor = AttnProcessor(),
465
+ dim = dim,
466
+ heads = heads,
467
+ dim_head = dim_head,
468
+ dropout = dropout,
469
+ )
470
+
471
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
472
+ self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
473
+
474
+ def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
475
+ # pre-norm & modulation for attention input
476
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
477
+
478
+ # attention
479
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
480
+
481
+ # process attention output for input x
482
+ x = x + gate_msa.unsqueeze(1) * attn_output
483
+
484
+ norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
485
+ ff_output = self.ff(norm)
486
+ x = x + gate_mlp.unsqueeze(1) * ff_output
487
+
488
+ return x
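# Illustrative usage sketch (placeholder sizes): one DiT block step, where t is the
# (batch, dim) conditioning produced by the TimestepEmbedding module defined further below.
block = DiTBlock(dim = 512, heads = 8, dim_head = 64)
x = torch.randn(2, 100, 512)              # noised input sequence
t = torch.randn(2, 512)                   # time embedding, (batch, dim)
out = block(x, t)                         # (2, 100, 512)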
489
+
490
+
491
+ # MMDiT Block https://arxiv.org/abs/2403.03206
492
+
493
+ class MMDiTBlock(nn.Module):
494
+ r"""
495
+ modified from diffusers/src/diffusers/models/attention.py
496
+
497
+ notes.
498
+ _c: context related. text, cond, etc. (left part in sd3 fig2.b)
499
+ _x: noised input related. (right part)
500
+ context_pre_only: the last layer only does pre-norm + modulation for the context stream, since it has no feed-forward afterwards
501
+ """
502
+
503
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
504
+ super().__init__()
505
+
506
+ self.context_pre_only = context_pre_only
507
+
508
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
509
+ self.attn_norm_x = AdaLayerNormZero(dim)
510
+ self.attn = Attention(
511
+ processor = JointAttnProcessor(),
512
+ dim = dim,
513
+ heads = heads,
514
+ dim_head = dim_head,
515
+ dropout = dropout,
516
+ context_dim = dim,
517
+ context_pre_only = context_pre_only,
518
+ )
519
+
520
+ if not context_pre_only:
521
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
522
+ self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
523
+ else:
524
+ self.ff_norm_c = None
525
+ self.ff_c = None
526
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
527
+ self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
528
+
529
+ def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
530
+ # pre-norm & modulation for attention input
531
+ if self.context_pre_only:
532
+ norm_c = self.attn_norm_c(c, t)
533
+ else:
534
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
535
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
536
+
537
+ # attention
538
+ x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
539
+
540
+ # process attention output for context c
541
+ if self.context_pre_only:
542
+ c = None
543
+ else: # if not last layer
544
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
545
+
546
+ norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
547
+ c_ff_output = self.ff_c(norm_c)
548
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
549
+
550
+ # process attention output for input x
551
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
552
+
553
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
554
+ x_ff_output = self.ff_x(norm_x)
555
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
556
+
557
+ return c, x
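# Illustrative usage sketch (placeholder sizes): one joint block step, with c the text /
# context stream and x the noised mel stream, both already projected to the model dim.
block = MMDiTBlock(dim = 512, heads = 8, dim_head = 64, context_pre_only = False)
x = torch.randn(2, 100, 512)              # noised input stream
c = torch.randn(2, 20, 512)               # context (text) stream
t = torch.randn(2, 512)                   # time embedding
c, x = block(x, c, t)                     # for the final block, build with context_pre_only = True; c is then returned as None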
558
+
559
+
560
+ # time step conditioning embedding
561
+
562
+ class TimestepEmbedding(nn.Module):
563
+ def __init__(self, dim, freq_embed_dim=256):
564
+ super().__init__()
565
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
566
+ self.time_mlp = nn.Sequential(
567
+ nn.Linear(freq_embed_dim, dim),
568
+ nn.SiLU(),
569
+ nn.Linear(dim, dim)
570
+ )
571
+
572
+ def forward(self, timestep: float['b']):
573
+ time_hidden = self.time_embed(timestep)
574
+ time = self.time_mlp(time_hidden) # b d
575
+ return time
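# Illustrative usage sketch (placeholder sizes): map scalar flow-matching timesteps to the
# (batch, dim) conditioning vectors consumed by the blocks above.
time_embed = TimestepEmbedding(dim = 512)
t = torch.rand(4)                         # one scalar timestep per sample
t_emb = time_embed(t)                     # (4, 512)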
model/trainer.py ADDED
@@ -0,0 +1,245 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import gc
5
+ from tqdm import tqdm
6
+ import wandb
7
+
8
+ import torch
9
+ from torch.optim import AdamW
10
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler
11
+ from torch.optim.lr_scheduler import LinearLR, SequentialLR
12
+
13
+ from einops import rearrange
14
+
15
+ from accelerate import Accelerator
16
+ from accelerate.utils import DistributedDataParallelKwargs
17
+
18
+ from ema_pytorch import EMA
19
+
20
+ from model import CFM
21
+ from model.utils import exists, default
22
+ from model.dataset import DynamicBatchSampler, collate_fn
23
+
24
+
25
+ # trainer
26
+
27
+ class Trainer:
28
+ def __init__(
29
+ self,
30
+ model: CFM,
31
+ epochs,
32
+ learning_rate,
33
+ num_warmup_updates = 20000,
34
+ save_per_updates = 1000,
35
+ checkpoint_path = None,
36
+ batch_size = 32,
37
+ batch_size_type: str = "sample",
38
+ max_samples = 32,
39
+ grad_accumulation_steps = 1,
40
+ max_grad_norm = 1.0,
41
+ noise_scheduler: str | None = None,
42
+ duration_predictor: torch.nn.Module | None = None,
43
+ wandb_project = "test_e2-tts",
44
+ wandb_run_name = "test_run",
45
+ wandb_resume_id: str = None,
46
+ last_per_steps = None,
47
+ accelerate_kwargs: dict = dict(),
48
+ ema_kwargs: dict = dict()
49
+ ):
50
+
51
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
52
+
53
+ self.accelerator = Accelerator(
54
+ log_with = "wandb",
55
+ kwargs_handlers = [ddp_kwargs],
56
+ gradient_accumulation_steps = grad_accumulation_steps,
57
+ **accelerate_kwargs
58
+ )
59
+
60
+ if exists(wandb_resume_id):
61
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
62
+ else:
63
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
64
+ self.accelerator.init_trackers(
65
+ project_name = wandb_project,
66
+ init_kwargs=init_kwargs,
67
+ config={"epochs": epochs,
68
+ "learning_rate": learning_rate,
69
+ "num_warmup_updates": num_warmup_updates,
70
+ "batch_size": batch_size,
71
+ "batch_size_type": batch_size_type,
72
+ "max_samples": max_samples,
73
+ "grad_accumulation_steps": grad_accumulation_steps,
74
+ "max_grad_norm": max_grad_norm,
75
+ "gpus": self.accelerator.num_processes,
76
+ "noise_scheduler": noise_scheduler}
77
+ )
78
+
79
+ self.model = model
80
+
81
+ if self.is_main:
82
+ self.ema_model = EMA(
83
+ model,
84
+ include_online_model = False,
85
+ **ema_kwargs
86
+ )
87
+
88
+ self.ema_model.to(self.accelerator.device)
89
+
90
+ self.epochs = epochs
91
+ self.num_warmup_updates = num_warmup_updates
92
+ self.save_per_updates = save_per_updates
93
+ self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
94
+ self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
95
+
96
+ self.batch_size = batch_size
97
+ self.batch_size_type = batch_size_type
98
+ self.max_samples = max_samples
99
+ self.grad_accumulation_steps = grad_accumulation_steps
100
+ self.max_grad_norm = max_grad_norm
101
+
102
+ self.noise_scheduler = noise_scheduler
103
+
104
+ self.duration_predictor = duration_predictor
105
+
106
+ self.optimizer = AdamW(model.parameters(), lr=learning_rate)
107
+ self.model, self.optimizer = self.accelerator.prepare(
108
+ self.model, self.optimizer
109
+ )
110
+
111
+ @property
112
+ def is_main(self):
113
+ return self.accelerator.is_main_process
114
+
115
+ def save_checkpoint(self, step, last=False):
116
+ self.accelerator.wait_for_everyone()
117
+ if self.is_main:
118
+ checkpoint = dict(
119
+ model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
120
+ optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
121
+ ema_model_state_dict = self.ema_model.state_dict(),
122
+ scheduler_state_dict = self.scheduler.state_dict(),
123
+ step = step
124
+ )
125
+ if not os.path.exists(self.checkpoint_path):
126
+ os.makedirs(self.checkpoint_path)
127
+ if last:
128
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
129
+ print(f"Saved last checkpoint at step {step}")
130
+ else:
131
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
132
+
133
+ def load_checkpoint(self):
134
+ if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
135
+ return 0
136
+
137
+ self.accelerator.wait_for_everyone()
138
+ if "model_last.pt" in os.listdir(self.checkpoint_path):
139
+ latest_checkpoint = "model_last.pt"
140
+ else:
141
+ latest_checkpoint = sorted(os.listdir(self.checkpoint_path), key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
+ # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
+ checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
144
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
145
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
146
+
147
+ if self.is_main:
148
+ self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
149
+
150
+ if self.scheduler:
151
+ self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
152
+
153
+ step = checkpoint['step']
154
+ del checkpoint; gc.collect()
155
+ return step
156
+
157
+ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
158
+
159
+ if exists(resumable_with_seed):
160
+ generator = torch.Generator()
161
+ generator.manual_seed(resumable_with_seed)
162
+ else:
163
+ generator = None
164
+
165
+ if self.batch_size_type == "sample":
166
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
167
+ batch_size=self.batch_size, shuffle=True, generator=generator)
168
+ elif self.batch_size_type == "frame":
169
+ self.accelerator.even_batches = False
170
+ sampler = SequentialSampler(train_dataset)
171
+ batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
172
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True,
173
+ batch_sampler=batch_sampler)
174
+ else:
175
+ raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
176
+
177
+ # accelerator.prepare() dispatches batches to devices,
178
+ # so the dataloader length computed beforehand should account for the number of devices
179
+ warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp
180
+ # otherwise by default with split_batches=False, warmup steps change with num_processes
181
+ total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
182
+ decay_steps = total_steps - warmup_steps
183
+ warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
184
+ decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
185
+ self.scheduler = SequentialLR(self.optimizer,
186
+ schedulers=[warmup_scheduler, decay_scheduler],
187
+ milestones=[warmup_steps])
188
+ train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
189
+ start_step = self.load_checkpoint()
190
+ global_step = start_step
191
+
192
+ if exists(resumable_with_seed):
193
+ orig_epoch_step = len(train_dataloader)
194
+ skipped_epoch = int(start_step // orig_epoch_step)
195
+ skipped_batch = start_step % orig_epoch_step
196
+ skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
197
+ else:
198
+ skipped_epoch = 0
199
+
200
+ for epoch in range(skipped_epoch, self.epochs):
201
+ self.model.train()
202
+ if exists(resumable_with_seed) and epoch == skipped_epoch:
203
+ progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
204
+ initial=skipped_batch, total=orig_epoch_step)
205
+ else:
206
+ progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
207
+
208
+ for batch in progress_bar:
209
+ with self.accelerator.accumulate(self.model):
210
+ text_inputs = batch['text']
211
+ mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
212
+ mel_lengths = batch["mel_lengths"]
213
+
214
+ # TODO. add duration predictor training
215
+ if self.duration_predictor is not None and self.accelerator.is_local_main_process:
216
+ dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
217
+ self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
218
+
219
+ loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
220
+ self.accelerator.backward(loss)
221
+
222
+ if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
223
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
224
+
225
+ self.optimizer.step()
226
+ self.scheduler.step()
227
+ self.optimizer.zero_grad()
228
+
229
+ if self.is_main:
230
+ self.ema_model.update()
231
+
232
+ global_step += 1
233
+
234
+ if self.accelerator.is_local_main_process:
235
+ self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
236
+
237
+ progress_bar.set_postfix(step=str(global_step), loss=loss.item())
238
+
239
+ if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
240
+ self.save_checkpoint(global_step)
241
+
242
+ if global_step % self.last_per_steps == 0:
243
+ self.save_checkpoint(global_step, last=True)
244
+
245
+ self.accelerator.end_training()
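# Illustrative usage sketch (assumption-heavy, kept as comments): a CFM instance is assumed to
# be built elsewhere (its constructor is not shown in this file), and train_dataset is a Dataset
# prepared by the scripts under scripts/; every value below is a placeholder.
#   trainer = Trainer(
#       cfm_model,                          # a CFM instance
#       epochs = 10,
#       learning_rate = 1e-4,               # placeholder
#       num_warmup_updates = 20000,
#       save_per_updates = 10000,
#       checkpoint_path = "ckpts/my_run",   # placeholder path
#       batch_size = 38400,                 # frames per GPU when batch_size_type = "frame"
#       batch_size_type = "frame",
#       max_samples = 64,
#       grad_accumulation_steps = 1,
#   )
#   trainer.train(train_dataset, num_workers = 16, resumable_with_seed = 666)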
model/utils.py ADDED
@@ -0,0 +1,572 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import re
5
+ import math
6
+ import random
7
+ import string
8
+ from tqdm import tqdm
9
+ from collections import defaultdict
10
+
11
+ import matplotlib
12
+ matplotlib.use("Agg")
13
+ import matplotlib.pylab as plt
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from torch.nn.utils.rnn import pad_sequence
18
+ import torchaudio
19
+
20
+ import einx
21
+ from einops import rearrange, reduce
22
+
23
+ import jieba
24
+ from pypinyin import lazy_pinyin, Style
25
+ import zhconv
26
+ from zhon.hanzi import punctuation
27
+ from jiwer import compute_measures
28
+
29
+ from funasr import AutoModel
30
+ from faster_whisper import WhisperModel
31
+
32
+ from model.ecapa_tdnn import ECAPA_TDNN_SMALL
33
+ from model.modules import MelSpec
34
+
35
+
36
+ # seed everything
37
+
38
+ def seed_everything(seed = 0):
39
+ random.seed(seed)
40
+ os.environ['PYTHONHASHSEED'] = str(seed)
41
+ torch.manual_seed(seed)
42
+ torch.cuda.manual_seed(seed)
43
+ torch.cuda.manual_seed_all(seed)
44
+ torch.backends.cudnn.deterministic = True
45
+ torch.backends.cudnn.benchmark = False
46
+
47
+ # helpers
48
+
49
+ def exists(v):
50
+ return v is not None
51
+
52
+ def default(v, d):
53
+ return v if exists(v) else d
54
+
55
+ # tensor helpers
56
+
57
+ def lens_to_mask(
58
+ t: int['b'],
59
+ length: int | None = None
60
+ ) -> bool['b n']:
61
+
62
+ if not exists(length):
63
+ length = t.amax()
64
+
65
+ seq = torch.arange(length, device = t.device)
66
+ return einx.less('n, b -> b n', seq, t)
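# Example: lengths -> boolean padding mask, e.g.
#   lens_to_mask(torch.tensor([2, 4]))
#   -> tensor([[ True,  True, False, False],
#              [ True,  True,  True,  True]])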
67
+
68
+ def mask_from_start_end_indices(
69
+ seq_len: int['b'],
70
+ start: int['b'],
71
+ end: int['b']
72
+ ):
73
+ max_seq_len = seq_len.max().item()
74
+ seq = torch.arange(max_seq_len, device = start.device).long()
75
+ return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
76
+
77
+ def mask_from_frac_lengths(
78
+ seq_len: int['b'],
79
+ frac_lengths: float['b']
80
+ ):
81
+ lengths = (frac_lengths * seq_len).long()
82
+ max_start = seq_len - lengths
83
+
84
+ rand = torch.rand_like(frac_lengths)
85
+ start = (max_start * rand).long().clamp(min = 0)
86
+ end = start + lengths
87
+
88
+ return mask_from_start_end_indices(seq_len, start, end)
89
+
90
+ def maybe_masked_mean(
91
+ t: float['b n d'],
92
+ mask: bool['b n'] = None
93
+ ) -> float['b d']:
94
+
95
+ if not exists(mask):
96
+ return t.mean(dim = 1)
97
+
98
+ t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
99
+ num = reduce(t, 'b n d -> b d', 'sum')
100
+ den = reduce(mask.float(), 'b n -> b', 'sum')
101
+
102
+ return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
103
+
104
+
105
+ # simple utf-8 tokenizer, since paper went character based
106
+ def list_str_to_tensor(
107
+ text: list[str],
108
+ padding_value = -1
109
+ ) -> int['b nt']:
110
+ list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
111
+ text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
112
+ return text
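# Example: byte-level tokenization with -1 padding, e.g.
#   list_str_to_tensor(["Hi", "Hey"])
#   -> tensor([[ 72, 105,  -1],
#              [ 72, 101, 121]])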
113
+
114
+ # char tokenizer, based on custom dataset's extracted .txt file
115
+ def list_str_to_idx(
116
+ text: list[str] | list[list[str]],
117
+ vocab_char_map: dict[str, int], # {char: idx}
118
+ padding_value = -1
119
+ ) -> int['b nt']:
120
+ list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
121
+ text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
122
+ return text
123
+
124
+
125
+ # Get tokenizer
126
+
127
+ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
128
+ '''
129
+ tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
130
+ - "char" for char-wise tokenizer, need .txt vocab_file
131
+ - "byte" for utf-8 tokenizer
132
+ vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
133
+ - if use "char", derived from unfiltered character & symbol counts of custom dataset
134
+ - if use "byte", set to 256 (unicode byte range)
135
+ '''
136
+ if tokenizer in ["pinyin", "char"]:
137
+ with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
138
+ vocab_char_map = {}
139
+ for i, char in enumerate(f):
140
+ vocab_char_map[char[:-1]] = i
141
+ vocab_size = len(vocab_char_map)
142
+ assert vocab_char_map[" "] == 0, "make sure space is at index 0 in vocab.txt, since 0 is used for the unknown char"
143
+
144
+ elif tokenizer == "byte":
145
+ vocab_char_map = None
146
+ vocab_size = 256
147
+
148
+ return vocab_char_map, vocab_size
149
+
150
+
151
+ # convert char to pinyin
152
+
153
+ def convert_char_to_pinyin(text_list, polyphone = True):
154
+ final_text_list = []
155
+ god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
156
+ custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
157
+ for text in text_list:
158
+ char_list = []
159
+ text = text.translate(god_knows_why_en_testset_contains_zh_quote)
160
+ text = text.translate(custom_trans)
161
+ for seg in jieba.cut(text):
162
+ seg_byte_len = len(bytes(seg, 'UTF-8'))
163
+ if seg_byte_len == len(seg): # if pure alphabets and symbols
164
+ if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
165
+ char_list.append(" ")
166
+ char_list.extend(seg)
167
+ elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
168
+ seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
169
+ for c in seg:
170
+ if c not in "。,、;:?!《》【】—…":
171
+ char_list.append(" ")
172
+ char_list.append(c)
173
+ else: # if mixed chinese characters, alphabets and symbols
174
+ for c in seg:
175
+ if ord(c) < 256:
176
+ char_list.extend(c)
177
+ else:
178
+ if c not in "。,、;:?!《》【】—…":
179
+ char_list.append(" ")
180
+ char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
181
+ else: # if is zh punc
182
+ char_list.append(c)
183
+ final_text_list.append(char_list)
184
+
185
+ return final_text_list
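# Example (approximate; exact pinyin depends on pypinyin's tone-sandhi handling):
#   convert_char_to_pinyin(["你好 hello"])[0]
#   -> [' ', 'ni2', ' ', 'hao3', ' ', 'h', 'e', 'l', 'l', 'o']
# i.e. Chinese segments become tone3 pinyin tokens, each preceded by a space, while latin
# text is kept character-wise.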
186
+
187
+
188
+ # save spectrogram
189
+ def save_spectrogram(spectrogram, path):
190
+ plt.figure(figsize=(12, 4))
191
+ plt.imshow(spectrogram, origin='lower', aspect='auto')
192
+ plt.colorbar()
193
+ plt.savefig(path)
194
+ plt.close()
195
+
196
+
197
+ # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
198
+ def get_seedtts_testset_metainfo(metalst):
199
+ f = open(metalst); lines = f.readlines(); f.close()
200
+ metainfo = []
201
+ for line in lines:
202
+ if len(line.strip().split('|')) == 5:
203
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
204
+ elif len(line.strip().split('|')) == 4:
205
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
206
+ gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
207
+ if not os.path.isabs(prompt_wav):
208
+ prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
209
+ metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
210
+ return metainfo
211
+
212
+
213
+ # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
214
+ def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
215
+ f = open(metalst); lines = f.readlines(); f.close()
216
+ metainfo = []
217
+ for line in lines:
218
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
219
+
220
+ # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
221
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
222
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
223
+
224
+ # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
225
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
226
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
227
+
228
+ metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
229
+
230
+ return metainfo
231
+
232
+
233
+ # padded to max length mel batch
234
+ def padded_mel_batch(ref_mels):
235
+ max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
236
+ padded_ref_mels = []
237
+ for mel in ref_mels:
238
+ padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
239
+ padded_ref_mels.append(padded_ref_mel)
240
+ padded_ref_mels = torch.stack(padded_ref_mels)
241
+ padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
242
+ return padded_ref_mels
243
+
244
+
245
+ # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
246
+
247
+ def get_inference_prompt(
248
+ metainfo,
249
+ speed = 1., tokenizer = "pinyin", polyphone = True,
250
+ target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
251
+ use_truth_duration = False,
252
+ infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
253
+ ):
254
+ prompts_all = []
255
+
256
+ min_tokens = min_secs * target_sample_rate // hop_length
257
+ max_tokens = max_secs * target_sample_rate // hop_length
258
+
259
+ batch_accum = [0] * num_buckets
260
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
261
+ ([[] for _ in range(num_buckets)] for _ in range(6))
262
+
263
+ mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
264
+
265
+ for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
266
+
267
+ # Audio
268
+ ref_audio, ref_sr = torchaudio.load(prompt_wav)
269
+ ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
270
+ if ref_rms < target_rms:
271
+ ref_audio = ref_audio * target_rms / ref_rms
272
+ assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
273
+ if ref_sr != target_sample_rate:
274
+ resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
275
+ ref_audio = resampler(ref_audio)
276
+
277
+ # Text
278
+ text = [prompt_text + gt_text]
279
+ if tokenizer == "pinyin":
280
+ text_list = convert_char_to_pinyin(text, polyphone = polyphone)
281
+ else:
282
+ text_list = text
283
+
284
+ # Duration, mel frame length
285
+ ref_mel_len = ref_audio.shape[-1] // hop_length
286
+ if use_truth_duration:
287
+ gt_audio, gt_sr = torchaudio.load(gt_wav)
288
+ if gt_sr != target_sample_rate:
289
+ resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
290
+ gt_audio = resampler(gt_audio)
291
+ total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
292
+
293
+ # # test vocoder resynthesis
294
+ # ref_audio = gt_audio
295
+ else:
296
+ zh_pause_punc = r"。,、;:?!"
297
+ ref_text_len = len(prompt_text) + len(re.findall(zh_pause_punc, prompt_text))
298
+ gen_text_len = len(gt_text) + len(re.findall(zh_pause_punc, gt_text))
299
+ total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
300
+
301
+ # to mel spectrogram
302
+ ref_mel = mel_spectrogram(ref_audio)
303
+ ref_mel = rearrange(ref_mel, '1 d n -> d n')
304
+
305
+ # deal with batch
306
+ assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
307
+ assert min_tokens <= total_mel_len <= max_tokens, \
308
+ f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
309
+ bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
310
+
311
+ utts[bucket_i].append(utt)
312
+ ref_rms_list[bucket_i].append(ref_rms)
313
+ ref_mels[bucket_i].append(ref_mel)
314
+ ref_mel_lens[bucket_i].append(ref_mel_len)
315
+ total_mel_lens[bucket_i].append(total_mel_len)
316
+ final_text_list[bucket_i].extend(text_list)
317
+
318
+ batch_accum[bucket_i] += total_mel_len
319
+
320
+ if batch_accum[bucket_i] >= infer_batch_size:
321
+ # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
322
+ prompts_all.append((
323
+ utts[bucket_i],
324
+ ref_rms_list[bucket_i],
325
+ padded_mel_batch(ref_mels[bucket_i]),
326
+ ref_mel_lens[bucket_i],
327
+ total_mel_lens[bucket_i],
328
+ final_text_list[bucket_i]
329
+ ))
330
+ batch_accum[bucket_i] = 0
331
+ utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
332
+
333
+ # add residual
334
+ for bucket_i, bucket_frames in enumerate(batch_accum):
335
+ if bucket_frames > 0:
336
+ prompts_all.append((
337
+ utts[bucket_i],
338
+ ref_rms_list[bucket_i],
339
+ padded_mel_batch(ref_mels[bucket_i]),
340
+ ref_mel_lens[bucket_i],
341
+ total_mel_lens[bucket_i],
342
+ final_text_list[bucket_i]
343
+ ))
344
+ # shuffle, so the last workers are not left with only the easy (short) batches
345
+ random.seed(666)
346
+ random.shuffle(prompts_all)
347
+
348
+ return prompts_all
349
+
350
+
351
+ # get wav_res_ref_text of seed-tts test metalst
352
+ # https://github.com/BytedanceSpeech/seed-tts-eval
353
+
354
+ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
355
+ f = open(metalst)
356
+ lines = f.readlines()
357
+ f.close()
358
+
359
+ test_set_ = []
360
+ for line in tqdm(lines):
361
+ if len(line.strip().split('|')) == 5:
362
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
363
+ elif len(line.strip().split('|')) == 4:
364
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
365
+
366
+ if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
367
+ continue
368
+ gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
369
+ if not os.path.isabs(prompt_wav):
370
+ prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
371
+
372
+ test_set_.append((gen_wav, prompt_wav, gt_text))
373
+
374
+ num_jobs = len(gpus)
375
+ if num_jobs == 1:
376
+ return [(gpus[0], test_set_)]
377
+
378
+ wav_per_job = len(test_set_) // num_jobs + 1
379
+ test_set = []
380
+ for i in range(num_jobs):
381
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
382
+
383
+ return test_set
384
+
385
+
386
+ # get librispeech test-clean cross sentence test
387
+
388
+ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
389
+ f = open(metalst)
390
+ lines = f.readlines()
391
+ f.close()
392
+
393
+ test_set_ = []
394
+ for line in tqdm(lines):
395
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
396
+
397
+ if eval_ground_truth:
398
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
399
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
400
+ else:
401
+ if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
402
+ raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
403
+ gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
404
+
405
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
406
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
407
+
408
+ test_set_.append((gen_wav, ref_wav, gen_txt))
409
+
410
+ num_jobs = len(gpus)
411
+ if num_jobs == 1:
412
+ return [(gpus[0], test_set_)]
413
+
414
+ wav_per_job = len(test_set_) // num_jobs + 1
415
+ test_set = []
416
+ for i in range(num_jobs):
417
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
418
+
419
+ return test_set
420
+
421
+
422
+ # load asr model
423
+
424
+ def load_asr_model(lang, ckpt_dir = ""):
425
+ if lang == "zh":
426
+ model = AutoModel(
427
+ model = os.path.join(ckpt_dir, "paraformer-zh"),
428
+ # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
429
+ # punc_model = os.path.join(ckpt_dir, "ct-punc"),
430
+ # spk_model = os.path.join(ckpt_dir, "cam++"),
431
+ disable_update=True,
432
+ ) # following seed-tts setting
433
+ elif lang == "en":
434
+ model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
435
+ model = WhisperModel(model_size, device="cuda", compute_type="float16")
436
+ return model
437
+
438
+
439
+ # WER Evaluation, the way Seed-TTS does
440
+
441
+ def run_asr_wer(args):
442
+ rank, lang, test_set, ckpt_dir = args
443
+
444
+ if lang == "zh":
445
+ torch.cuda.set_device(rank)
446
+ elif lang == "en":
447
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
448
+ else:
449
+ raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
450
+
451
+ asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
452
+
453
+ punctuation_all = punctuation + string.punctuation
454
+ wers = []
455
+
456
+ for gen_wav, prompt_wav, truth in tqdm(test_set):
457
+ if lang == "zh":
458
+ res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
459
+ hypo = res[0]["text"]
460
+ hypo = zhconv.convert(hypo, 'zh-cn')
461
+ elif lang == "en":
462
+ segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
463
+ hypo = ''
464
+ for segment in segments:
465
+ hypo = hypo + ' ' + segment.text
466
+
467
+ # raw_truth = truth
468
+ # raw_hypo = hypo
469
+
470
+ for x in punctuation_all:
471
+ truth = truth.replace(x, '')
472
+ hypo = hypo.replace(x, '')
473
+
474
+ truth = truth.replace('  ', ' ')  # collapse double spaces
475
+ hypo = hypo.replace('  ', ' ')
476
+
477
+ if lang == "zh":
478
+ truth = " ".join([x for x in truth])
479
+ hypo = " ".join([x for x in hypo])
480
+ elif lang == "en":
481
+ truth = truth.lower()
482
+ hypo = hypo.lower()
483
+
484
+ measures = compute_measures(truth, hypo)
485
+ wer = measures["wer"]
486
+
487
+ # ref_list = truth.split(" ")
488
+ # subs = measures["substitutions"] / len(ref_list)
489
+ # dele = measures["deletions"] / len(ref_list)
490
+ # inse = measures["insertions"] / len(ref_list)
491
+
492
+ wers.append(wer)
493
+
494
+ return wers
495
+
496
+
497
+ # SIM Evaluation
498
+
499
+ def run_sim(args):
500
+ rank, test_set, ckpt_dir = args
501
+ device = f"cuda:{rank}"
502
+
503
+ model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
504
+ state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage)
505
+ model.load_state_dict(state_dict['model'], strict=False)
506
+
507
+ use_gpu = torch.cuda.is_available()
508
+ if use_gpu:
509
+ model = model.cuda(device)
510
+ model.eval()
511
+
512
+ sim_list = []
513
+ for wav1, wav2, truth in tqdm(test_set):
514
+
515
+ wav1, sr1 = torchaudio.load(wav1)
516
+ wav2, sr2 = torchaudio.load(wav2)
517
+
518
+ resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
519
+ resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
520
+ wav1 = resample1(wav1)
521
+ wav2 = resample2(wav2)
522
+
523
+ if use_gpu:
524
+ wav1 = wav1.cuda(device)
525
+ wav2 = wav2.cuda(device)
526
+ with torch.no_grad():
527
+ emb1 = model(wav1)
528
+ emb2 = model(wav2)
529
+
530
+ sim = F.cosine_similarity(emb1, emb2)[0].item()
531
+ # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
532
+ sim_list.append(sim)
533
+
534
+ return sim_list
535
+
536
+
537
+ # filter func for dirty data with many repetitions
538
+
539
+ def repetition_found(text, length = 2, tolerance = 10):
540
+ pattern_count = defaultdict(int)
541
+ for i in range(len(text) - length + 1):
542
+ pattern = text[i:i + length]
543
+ pattern_count[pattern] += 1
544
+ for pattern, count in pattern_count.items():
545
+ if count > tolerance:
546
+ return True
547
+ return False
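# Example:
#   repetition_found("ab" * 12, length = 2, tolerance = 10)     -> True  ("ab" occurs 12 times)
#   repetition_found("hello world", length = 2, tolerance = 10) -> False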
548
+
549
+
550
+ # load model checkpoint for inference
551
+
552
+ def load_checkpoint(model, ckpt_path, device, use_ema = True):
553
+ from ema_pytorch import EMA
554
+
555
+ ckpt_type = ckpt_path.split(".")[-1]
556
+ if ckpt_type == "safetensors":
557
+ from safetensors.torch import load_file
558
+ checkpoint = load_file(ckpt_path, device=device)
559
+ else:
560
+ checkpoint = torch.load(ckpt_path, map_location=device)
561
+
562
+ if use_ema:
563
+ ema_model = EMA(model, include_online_model = False).to(device)
564
+ if ckpt_type == "safetensors":
565
+ ema_model.load_state_dict(checkpoint)
566
+ else:
567
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
568
+ ema_model.copy_params_from_ema_to_model()
569
+ else:
570
+ model.load_state_dict(checkpoint['model_state_dict'])
571
+
572
+ return model
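# Illustrative usage sketch (placeholder path and device): restore EMA weights for inference.
#   model = load_checkpoint(model, "ckpts/my_run/model_last.pt", device = "cuda", use_ema = True)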
requirements.txt ADDED
@@ -0,0 +1,25 @@
1
+ accelerate>=0.33.0
2
+ datasets
3
+ einops>=0.8.0
4
+ einx>=0.3.0
5
+ ema_pytorch>=0.5.2
6
+ faster_whisper
7
+ funasr
8
+ jieba
9
+ jiwer
10
+ librosa
11
+ matplotlib
12
+ pypinyin
13
+ safetensors
14
+ # torch>=2.0
15
+ # torchaudio>=2.3.0
16
+ torchdiffeq
17
+ tqdm>=4.65.0
18
+ transformers
19
+ vocos
20
+ wandb
21
+ x_transformers>=1.31.14
22
+ zhconv
23
+ zhon
24
+ gradio
25
+ cached-path
requirements_gradio.txt ADDED
@@ -0,0 +1,4 @@
1
+ cached_path
2
+ click
3
+ gradio
4
+ pydub
scripts/count_max_epoch.py ADDED
@@ -0,0 +1,32 @@
1
+ '''ADAPTIVE BATCH SIZE'''
2
+ print('Adaptive batch size: using grouping batch sampler, with a fixed frames_per_gpu fed in')
3
+ print(' -> least padding, gather wavs with accumulated frames in a batch\n')
4
+
5
+ # data
6
+ total_hours = 95282
7
+ mel_hop_length = 256
8
+ mel_sampling_rate = 24000
9
+
10
+ # target
11
+ wanted_max_updates = 1000000
12
+
13
+ # train params
14
+ gpus = 8
15
+ frames_per_gpu = 38400 # 8 * 38400 = 307200
16
+ grad_accum = 1
17
+
18
+ # intermediate
19
+ mini_batch_frames = frames_per_gpu * grad_accum * gpus
20
+ mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
21
+ updates_per_epoch = total_hours / mini_batch_hours
22
+ steps_per_epoch = updates_per_epoch * grad_accum
23
+
24
+ # result
25
+ epochs = wanted_max_updates / updates_per_epoch
26
+ print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
27
+ print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
28
+ print(f" or approx. 0/{steps_per_epoch:.0f} steps")
29
+
30
+ # others
31
+ print(f"total {total_hours:.0f} hours")
32
+ print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
scripts/count_params_gflops.py ADDED
@@ -0,0 +1,35 @@
1
+ import sys, os
2
+ sys.path.append(os.getcwd())
3
+
4
+ from model import M2_TTS, UNetT, DiT, MMDiT
5
+
6
+ import torch
7
+ import thop
8
+
9
+
10
+ ''' ~155M '''
11
+ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
12
+ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
13
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
14
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
15
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
16
+ # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
17
+
18
+ ''' ~335M '''
19
+ # FLOPs: 622.1 G, Params: 333.2 M
20
+ # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
21
+ # FLOPs: 363.4 G, Params: 335.8 M
22
+ transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
23
+
24
+
25
+ model = M2_TTS(transformer=transformer)
26
+ target_sample_rate = 24000
27
+ n_mel_channels = 100
28
+ hop_length = 256
29
+ duration = 20
30
+ frame_length = int(duration * target_sample_rate / hop_length)
31
+ text_length = 150
32
+
33
+ flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
34
+ print(f"FLOPs: {flops / 1e9} G")
35
+ print(f"Params: {params / 1e6} M")
scripts/eval_librispeech_test_clean.py ADDED
@@ -0,0 +1,67 @@
1
+ # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
+
3
+ import sys, os
4
+ sys.path.append(os.getcwd())
5
+
6
+ import multiprocessing as mp
7
+ import numpy as np
8
+
9
+ from model.utils import (
10
+ get_librispeech_test,
11
+ run_asr_wer,
12
+ run_sim,
13
+ )
14
+
15
+
16
+ eval_task = "wer" # sim | wer
17
+ lang = "en"
18
+ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
19
+ librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
20
+ gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
21
+
22
+ gpus = [0,1,2,3,4,5,6,7]
23
+ test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
24
+
25
+ ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
26
+ ## leading to a low similarity for the ground truth in some cases.
27
+ # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth
28
+
29
+ local = False
30
+ if local: # use local custom checkpoint dir
31
+ asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
32
+ else:
33
+ asr_ckpt_dir = "" # auto download to cache dir
34
+
35
+ wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
36
+
37
+
38
+ # --------------------------- WER ---------------------------
39
+
40
+ if eval_task == "wer":
41
+ wers = []
42
+
43
+ with mp.Pool(processes=len(gpus)) as pool:
44
+ args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
45
+ results = pool.map(run_asr_wer, args)
46
+ for wers_ in results:
47
+ wers.extend(wers_)
48
+
49
+ wer = round(np.mean(wers)*100, 3)
50
+ print(f"\nTotal {len(wers)} samples")
51
+ print(f"WER : {wer}%")
52
+
53
+
54
+ # --------------------------- SIM ---------------------------
55
+
56
+ if eval_task == "sim":
57
+ sim_list = []
58
+
59
+ with mp.Pool(processes=len(gpus)) as pool:
60
+ args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
61
+ results = pool.map(run_sim, args)
62
+ for sim_ in results:
63
+ sim_list.extend(sim_)
64
+
65
+ sim = round(sum(sim_list)/len(sim_list), 3)
66
+ print(f"\nTotal {len(sim_list)} samples")
67
+ print(f"SIM : {sim}")
scripts/eval_seedtts_testset.py ADDED
@@ -0,0 +1,69 @@
1
+ # Evaluate with Seed-TTS testset
2
+
3
+ import sys, os
4
+ sys.path.append(os.getcwd())
5
+
6
+ import multiprocessing as mp
7
+ import numpy as np
8
+
9
+ from model.utils import (
10
+ get_seed_tts_test,
11
+ run_asr_wer,
12
+ run_sim,
13
+ )
14
+
15
+
16
+ eval_task = "wer" # sim | wer
17
+ lang = "zh" # zh | en
18
+ metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
19
+ # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
20
+ gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs
21
+
22
+
23
+ # NOTE: the paraformer-zh result will differ slightly with the number of GPUs, since the batch size changes
24
+ # zh 1.254 seems a result of 4 workers wer_seed_tts
25
+ gpus = [0,1,2,3,4,5,6,7]
26
+ test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
27
+
28
+ local = False
29
+ if local: # use local custom checkpoint dir
30
+ if lang == "zh":
31
+ asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
32
+ elif lang == "en":
33
+ asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
34
+ else:
35
+ asr_ckpt_dir = "" # auto download to cache dir
36
+
37
+ wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
38
+
39
+
40
+ # --------------------------- WER ---------------------------
41
+
42
+ if eval_task == "wer":
43
+ wers = []
44
+
45
+ with mp.Pool(processes=len(gpus)) as pool:
46
+ args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
47
+ results = pool.map(run_asr_wer, args)
48
+ for wers_ in results:
49
+ wers.extend(wers_)
50
+
51
+ wer = round(np.mean(wers)*100, 3)
52
+ print(f"\nTotal {len(wers)} samples")
53
+ print(f"WER : {wer}%")
54
+
55
+
56
+ # --------------------------- SIM ---------------------------
57
+
58
+ if eval_task == "sim":
59
+ sim_list = []
60
+
61
+ with mp.Pool(processes=len(gpus)) as pool:
62
+ args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
63
+ results = pool.map(run_sim, args)
64
+ for sim_ in results:
65
+ sim_list.extend(sim_)
66
+
67
+ sim = round(sum(sim_list)/len(sim_list), 3)
68
+ print(f"\nTotal {len(sim_list)} samples")
69
+ print(f"SIM : {sim}")
scripts/prepare_emilia.py ADDED
@@ -0,0 +1,143 @@
1
+ # Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
2
+ # if use updated new version, i.e. WebDataset, feel free to modify / draft your own script
3
+
4
+ # generate audio text map for Emilia ZH & EN
5
+ # evaluate for vocab size
6
+
7
+ import sys, os
8
+ sys.path.append(os.getcwd())
9
+
10
+ from pathlib import Path
11
+ import json
12
+ from tqdm import tqdm
13
+ from concurrent.futures import ProcessPoolExecutor
14
+
15
+ from datasets import Dataset
16
+ from datasets.arrow_writer import ArrowWriter
17
+
18
+ from model.utils import (
19
+ repetition_found,
20
+ convert_char_to_pinyin,
21
+ )
22
+
23
+
24
+ out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
25
+ zh_filters = ["い", "て"]
26
+ # seems synthesized audios, or heavily code-switched
27
+ out_en = {
28
+ "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
29
+
30
+ "EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
31
+ }
32
+ en_filters = ["ا", "い", "て"]
33
+
34
+
35
+ def deal_with_audio_dir(audio_dir):
36
+ audio_jsonl = audio_dir.with_suffix(".jsonl")
37
+ sub_result, durations = [], []
38
+ vocab_set = set()
39
+ bad_case_zh = 0
40
+ bad_case_en = 0
41
+ with open(audio_jsonl, "r") as f:
42
+ lines = f.readlines()
43
+ for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
44
+ obj = json.loads(line)
45
+ text = obj["text"]
46
+ if obj['language'] == "zh":
47
+ if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
48
+ bad_case_zh += 1
49
+ continue
50
+ else:
51
+ text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'})) # not "。" cuz much code-switched
52
+ if obj['language'] == "en":
53
+ if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
54
+ bad_case_en += 1
55
+ continue
56
+ if tokenizer == "pinyin":
57
+ text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
58
+ duration = obj["duration"]
59
+ sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
60
+ durations.append(duration)
61
+ vocab_set.update(list(text))
62
+ return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
63
+
64
+
65
+ def main():
66
+ assert tokenizer in ["pinyin", "char"]
67
+ result = []
68
+ duration_list = []
69
+ text_vocab_set = set()
70
+ total_bad_case_zh = 0
71
+ total_bad_case_en = 0
72
+
73
+ # process raw data
74
+ executor = ProcessPoolExecutor(max_workers=max_workers)
75
+ futures = []
76
+ for lang in langs:
77
+ dataset_path = Path(os.path.join(dataset_dir, lang))
78
+ [
79
+ futures.append(executor.submit(deal_with_audio_dir, audio_dir))
80
+ for audio_dir in dataset_path.iterdir()
81
+ if audio_dir.is_dir()
82
+ ]
83
+ for future in tqdm(futures, total=len(futures)):
84
+ sub_result, durations, vocab_set, bad_case_zh, bad_case_en = future.result()
85
+ result.extend(sub_result)
86
+ duration_list.extend(durations)
87
+ text_vocab_set.update(vocab_set)
88
+ total_bad_case_zh += bad_case_zh
89
+ total_bad_case_en += bad_case_en
90
+ executor.shutdown()
91
+
92
+ # save preprocessed dataset to disk
93
+ if not os.path.exists(f"data/{dataset_name}"):
94
+ os.makedirs(f"data/{dataset_name}")
95
+ print(f"\nSaving to data/{dataset_name} ...")
96
+ # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
97
+ # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
98
+ with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
99
+ for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
100
+ writer.write(line)
101
+
102
+ # dup a json separately saving duration in case for DynamicBatchSampler ease
103
+ with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
104
+ json.dump({"duration": duration_list}, f, ensure_ascii=False)
105
+
106
+ # vocab map, i.e. tokenizer
107
+ # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
108
+ # if tokenizer == "pinyin":
109
+ # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
110
+ with open(f"data/{dataset_name}/vocab.txt", "w") as f:
111
+ for vocab in sorted(text_vocab_set):
112
+ f.write(vocab + "\n")
113
+
114
+ print(f"\nFor {dataset_name}, sample count: {len(result)}")
115
+ print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
116
+ print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
117
+ if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
118
+ if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
119
+
120
+
121
+ if __name__ == "__main__":
122
+
123
+ max_workers = 32
124
+
125
+ tokenizer = "pinyin" # "pinyin" | "char"
126
+ polyphone = True
127
+
128
+ langs = ["ZH", "EN"]
129
+ dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
130
+ dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
131
+ print(f"\nPrepare for {dataset_name}\n")
132
+
133
+ main()
134
+
135
+ # Emilia ZH & EN
136
+ # samples count 37837916 (after removal)
137
+ # pinyin vocab size 2543 (polyphone)
138
+ # total duration 95281.87 (hours)
139
+ # bad zh asr cnt 230435 (samples)
140
+ # bad eh asr cnt 37217 (samples)
141
+
142
+ # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
143
+ # please be careful if using pretrained model, make sure the vocab.txt is same
scripts/prepare_wenetspeech4tts.py ADDED
@@ -0,0 +1,116 @@
1
+ # generate audio text map for WenetSpeech4TTS
2
+ # evaluate for vocab size
3
+
4
+ import sys, os
5
+ sys.path.append(os.getcwd())
6
+
7
+ import json
8
+ from tqdm import tqdm
9
+ from concurrent.futures import ProcessPoolExecutor
10
+
11
+ import torchaudio
12
+ from datasets import Dataset
13
+
14
+ from model.utils import convert_char_to_pinyin
15
+
16
+
17
+ def deal_with_sub_path_files(dataset_path, sub_path):
18
+ print(f"Dealing with: {sub_path}")
19
+
20
+ text_dir = os.path.join(dataset_path, sub_path, "txts")
21
+ audio_dir = os.path.join(dataset_path, sub_path, "wavs")
22
+ text_files = os.listdir(text_dir)
23
+
24
+ audio_paths, texts, durations = [], [], []
25
+ for text_file in tqdm(text_files):
26
+ with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
27
+ first_line = file.readline().split("\t")
28
+ audio_nm = first_line[0]
29
+ audio_path = os.path.join(audio_dir, audio_nm + ".wav")
30
+ text = first_line[1].strip()
31
+
32
+ audio_paths.append(audio_path)
33
+
34
+ if tokenizer == "pinyin":
35
+ texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
36
+ elif tokenizer == "char":
37
+ texts.append(text)
38
+
39
+ audio, sample_rate = torchaudio.load(audio_path)
40
+ durations.append(audio.shape[-1] / sample_rate)
41
+
42
+ return audio_paths, texts, durations
43
+
44
+
45
+ def main():
46
+ assert tokenizer in ["pinyin", "char"]
47
+
48
+ audio_path_list, text_list, duration_list = [], [], []
49
+
50
+ executor = ProcessPoolExecutor(max_workers=max_workers)
51
+ futures = []
52
+ for dataset_path in dataset_paths:
53
+ sub_items = os.listdir(dataset_path)
54
+ sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
55
+ for sub_path in sub_paths:
56
+ futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
57
+ for future in tqdm(futures, total=len(futures)):
58
+ audio_paths, texts, durations = future.result()
59
+ audio_path_list.extend(audio_paths)
60
+ text_list.extend(texts)
61
+ duration_list.extend(durations)
62
+ executor.shutdown()
63
+
64
+ if not os.path.exists("data"):
65
+ os.makedirs("data")
66
+
67
+ print(f"\nSaving to data/{dataset_name}_{tokenizer} ...")
68
+ dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
69
+ dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
70
+
71
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
72
+ json.dump({"duration": duration_list}, f, ensure_ascii=False) # also save durations to a separate json, for convenient use with DynamicBatchSampler
73
+
74
+ print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
75
+ text_vocab_set = set()
76
+ for text in tqdm(text_list):
77
+ text_vocab_set.update(list(text))
78
+
79
+ # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
80
+ if tokenizer == "pinyin":
81
+ text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
82
+
83
+ with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "w") as f:
84
+ for vocab in sorted(text_vocab_set):
85
+ f.write(vocab + "\n")
86
+ print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
87
+ print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
88
+
89
+
90
+ if __name__ == "__main__":
91
+
92
+ max_workers = 32
93
+
94
+ tokenizer = "pinyin" # "pinyin" | "char"
95
+ polyphone = True
96
+ dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
97
+
98
+ dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
99
+ dataset_paths = [
100
+ "<SOME_PATH>/WenetSpeech4TTS/Basic",
101
+ "<SOME_PATH>/WenetSpeech4TTS/Standard",
102
+ "<SOME_PATH>/WenetSpeech4TTS/Premium",
103
+ ][-dataset_choice:]
104
+ print(f"\nChoose Dataset: {dataset_name}\n")
105
+
106
+ main()
107
+
108
+ # Results (if adding alphabets with accents and symbols):
109
+ # WenetSpeech4TTS Basic Standard Premium
110
+ # samples count 3932473 1941220 407494
111
+ # pinyin vocab size 1349 1348 1344 (no polyphone)
112
+ # - - 1459 (polyphone)
113
+ # char vocab size 5264 5219 5042
114
+
115
+ # vocab size may differ slightly due to the jieba tokenizer and pypinyin (e.g. how polyphones are handled)
116
+ # if using a pretrained model, please make sure the vocab.txt is the same
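As the note above says, a vocab.txt generated here must match the one the pretrained checkpoint was trained with before fine-tuning on it; a minimal sketch of such a check (both paths are illustrative):

pretrained_vocab_path = "data/Emilia_ZH_EN_pinyin/vocab.txt"      # vocab used by the pretrained model
new_vocab_path = "data/WenetSpeech4TTS_Premium_pinyin/vocab.txt"  # vocab just generated

with open(pretrained_vocab_path, encoding="utf-8") as f:
    pretrained_vocab = f.read().splitlines()
with open(new_vocab_path, encoding="utf-8") as f:
    new_vocab = f.read().splitlines()

print("identical:", new_vocab == pretrained_vocab)
print("tokens missing from pretrained vocab:", set(new_vocab) - set(pretrained_vocab))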
test_infer_batch.py ADDED
@@ -0,0 +1,197 @@
1
+ import os
2
+ import time
3
+ import random
4
+ from tqdm import tqdm
5
+ import argparse
6
+
7
+ import torch
8
+ import torchaudio
9
+ from accelerate import Accelerator
10
+ from einops import rearrange
11
+ from vocos import Vocos
12
+
13
+ from model import CFM, UNetT, DiT
14
+ from model.utils import (
15
+ load_checkpoint,
16
+ get_tokenizer,
17
+ get_seedtts_testset_metainfo,
18
+ get_librispeech_test_clean_metainfo,
19
+ get_inference_prompt,
20
+ )
21
+
22
+ accelerator = Accelerator()
23
+ device = f"cuda:{accelerator.process_index}"
24
+
25
+
26
+ # --------------------- Dataset Settings -------------------- #
27
+
28
+ target_sample_rate = 24000
29
+ n_mel_channels = 100
30
+ hop_length = 256
31
+ target_rms = 0.1
32
+
33
+ tokenizer = "pinyin"
34
+
35
+
36
+ # ---------------------- infer setting ---------------------- #
37
+
38
+ parser = argparse.ArgumentParser(description="batch inference")
39
+
40
+ parser.add_argument('-s', '--seed', default=None, type=int)
41
+ parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
42
+ parser.add_argument('-n', '--expname', required=True)
43
+ parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
44
+
45
+ parser.add_argument('-nfe', '--nfestep', default=32, type=int)
46
+ parser.add_argument('-o', '--odemethod', default="euler")
47
+ parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
48
+
49
+ parser.add_argument('-t', '--testset', required=True)
50
+
51
+ args = parser.parse_args()
52
+
53
+
54
+ seed = args.seed
55
+ dataset_name = args.dataset
56
+ exp_name = args.expname
57
+ ckpt_step = args.ckptstep
58
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
59
+
60
+ nfe_step = args.nfestep
61
+ ode_method = args.odemethod
62
+ sway_sampling_coef = args.swaysampling
63
+
64
+ testset = args.testset
65
+
66
+
67
+ infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
68
+ cfg_strength = 2.
69
+ speed = 1.
70
+ use_truth_duration = False
71
+ no_ref_audio = False
72
+
73
+
74
+ if exp_name == "F5TTS_Base":
75
+ model_cls = DiT
76
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
77
+
78
+ elif exp_name == "E2TTS_Base":
79
+ model_cls = UNetT
80
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
81
+
82
+
83
+ if testset == "ls_pc_test_clean":
84
+ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
85
+ librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
86
+ metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
87
+
88
+ elif testset == "seedtts_test_zh":
89
+ metalst = "data/seedtts_testset/zh/meta.lst"
90
+ metainfo = get_seedtts_testset_metainfo(metalst)
91
+
92
+ elif testset == "seedtts_test_en":
93
+ metalst = "data/seedtts_testset/en/meta.lst"
94
+ metainfo = get_seedtts_testset_metainfo(metalst)
95
+
96
+
97
+ # path to save generated wavs
98
+ if seed is None: seed = random.randint(-10000, 10000)
99
+ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
100
+ f"seed{seed}_{ode_method}_nfe{nfe_step}" \
101
+ f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
102
+ f"_cfg{cfg_strength}_speed{speed}" \
103
+ f"{'_gt-dur' if use_truth_duration else ''}" \
104
+ f"{'_no-ref-audio' if no_ref_audio else ''}"
105
+
106
+
107
+ # -------------------------------------------------#
108
+
109
+ use_ema = True
110
+
111
+ prompts_all = get_inference_prompt(
112
+ metainfo,
113
+ speed = speed,
114
+ tokenizer = tokenizer,
115
+ target_sample_rate = target_sample_rate,
116
+ n_mel_channels = n_mel_channels,
117
+ hop_length = hop_length,
118
+ target_rms = target_rms,
119
+ use_truth_duration = use_truth_duration,
120
+ infer_batch_size = infer_batch_size,
121
+ )
122
+
123
+ # Vocoder model
124
+ local = False
125
+ if local:
126
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
127
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
128
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
129
+ vocos.load_state_dict(state_dict)
130
+ vocos.eval()
131
+ else:
132
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
133
+
134
+ # Tokenizer
135
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
136
+
137
+ # Model
138
+ model = CFM(
139
+ transformer = model_cls(
140
+ **model_cfg,
141
+ text_num_embeds = vocab_size,
142
+ mel_dim = n_mel_channels
143
+ ),
144
+ mel_spec_kwargs = dict(
145
+ target_sample_rate = target_sample_rate,
146
+ n_mel_channels = n_mel_channels,
147
+ hop_length = hop_length,
148
+ ),
149
+ odeint_kwargs = dict(
150
+ method = ode_method,
151
+ ),
152
+ vocab_char_map = vocab_char_map,
153
+ ).to(device)
154
+
155
+ model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
156
+
157
+ if not os.path.exists(output_dir) and accelerator.is_main_process:
158
+ os.makedirs(output_dir)
159
+
160
+ # start batch inference
161
+ accelerator.wait_for_everyone()
162
+ start = time.time()
163
+
164
+ with accelerator.split_between_processes(prompts_all) as prompts:
165
+
166
+ for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
167
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
168
+ ref_mels = ref_mels.to(device)
169
+ ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
170
+ total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
171
+
172
+ # Inference
173
+ with torch.inference_mode():
174
+ generated, _ = model.sample(
175
+ cond = ref_mels,
176
+ text = final_text_list,
177
+ duration = total_mel_lens,
178
+ lens = ref_mel_lens,
179
+ steps = nfe_step,
180
+ cfg_strength = cfg_strength,
181
+ sway_sampling_coef = sway_sampling_coef,
182
+ no_ref_audio = no_ref_audio,
183
+ seed = seed,
184
+ )
185
+ # Final result
186
+ for i, gen in enumerate(generated):
187
+ gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
188
+ gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
189
+ generated_wave = vocos.decode(gen_mel_spec.cpu())
190
+ if ref_rms_list[i] < target_rms:
191
+ generated_wave = generated_wave * ref_rms_list[i] / target_rms
192
+ torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
193
+
194
+ accelerator.wait_for_everyone()
195
+ if accelerator.is_main_process:
196
+ timediff = time.time() - start
197
+ print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
test_infer_batch.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/bin/bash
2
+
3
+ # e.g. F5-TTS, 16 NFE
4
+ accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
5
+ accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
6
+ accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
7
+
8
+ # e.g. Vanilla E2 TTS, 32 NFE
9
+ accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
10
+ accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
11
+ accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
12
+
13
+ # etc.
test_infer_single.py ADDED
@@ -0,0 +1,163 @@
1
+ import os
2
+ import re
3
+
4
+ import torch
5
+ import torchaudio
6
+ from einops import rearrange
7
+ from vocos import Vocos
8
+
9
+ from model import CFM, UNetT, DiT, MMDiT
10
+ from model.utils import (
11
+ load_checkpoint,
12
+ get_tokenizer,
13
+ convert_char_to_pinyin,
14
+ save_spectrogram,
15
+ )
16
+
17
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
18
+
19
+
20
+ # --------------------- Dataset Settings -------------------- #
21
+
22
+ target_sample_rate = 24000
23
+ n_mel_channels = 100
24
+ hop_length = 256
25
+ target_rms = 0.1
26
+
27
+ tokenizer = "pinyin"
28
+ dataset_name = "Emilia_ZH_EN"
29
+
30
+
31
+ # ---------------------- infer setting ---------------------- #
32
+
33
+ seed = None # int | None
34
+
35
+ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
36
+ ckpt_step = 1200000
37
+
38
+ nfe_step = 32 # 16, 32
39
+ cfg_strength = 2.
40
+ ode_method = 'euler' # euler | midpoint
41
+ sway_sampling_coef = -1.
42
+ speed = 0.8
43
+ fix_duration = 27 # None (linearly estimated from text length; if code-switched, consider fixing) | float (total length in seconds, including ref audio)
44
+
45
+ if exp_name == "F5TTS_Base":
46
+ model_cls = DiT
47
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
48
+
49
+ elif exp_name == "E2TTS_Base":
50
+ model_cls = UNetT
51
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
52
+
53
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
54
+ output_dir = "tests"
55
+
56
+ ref_audio = "tests/ref_audio/rashmika_input.wav"
57
+ ref_text = ""
58
+
59
+ # gen_text = "Hi everyone! This is Rashmika Mandanna, sending you my heartfelt Diwali wishes on behalf of Antriksh AI. May this festival of lights fill your life with love, joy, and togetherness. Let’s celebrate the spirit of Diwali by spreading happiness and peace wherever we go. At Antriksh AI, we’re excited to blend innovation with creativity, and this Diwali, we wish you endless light and positivity. Enjoy the festivities and cherish every moment! With love, Rashmika Mandanna & Antriksh AI"
60
+
61
+ gen_text = "Happy Birthday, Dhillip Kumar. Virat Kohli this side, all the best for your future endeavours! "
62
+
63
+
64
+ # ref_audio = "tests/ref_audio/test_zh_1_ref_short.wav"
65
+ # ref_text = "对,这就是我,万人敬仰的太乙真人。"
66
+ # gen_text = "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
67
+
68
+
69
+ # -------------------------------------------------#
70
+
71
+ use_ema = True
72
+
73
+ if not os.path.exists(output_dir):
74
+ os.makedirs(output_dir)
75
+
76
+ # Vocoder model
77
+ local = False
78
+ if local:
79
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
80
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
81
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
82
+ vocos.load_state_dict(state_dict)
83
+ vocos.eval()
84
+ else:
85
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
86
+
87
+ # Tokenizer
88
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
89
+
90
+ # Model
91
+ model = CFM(
92
+ transformer = model_cls(
93
+ **model_cfg,
94
+ text_num_embeds = vocab_size,
95
+ mel_dim = n_mel_channels
96
+ ),
97
+ mel_spec_kwargs = dict(
98
+ target_sample_rate = target_sample_rate,
99
+ n_mel_channels = n_mel_channels,
100
+ hop_length = hop_length,
101
+ ),
102
+ odeint_kwargs = dict(
103
+ method = ode_method,
104
+ ),
105
+ vocab_char_map = vocab_char_map,
106
+ ).to(device)
107
+
108
+ model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
109
+
110
+ # Audio
111
+ audio, sr = torchaudio.load(ref_audio)
112
+ if audio.shape[0] > 1:
113
+ audio = torch.mean(audio, dim=0, keepdim=True)
114
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
115
+ if rms < target_rms:
116
+ audio = audio * target_rms / rms
117
+ if sr != target_sample_rate:
118
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
119
+ audio = resampler(audio)
120
+ audio = audio.to(device)
121
+
122
+ # Text
123
+ text_list = [ref_text + gen_text]
124
+ if tokenizer == "pinyin":
125
+ final_text_list = convert_char_to_pinyin(text_list)
126
+ else:
127
+ final_text_list = [text_list]
128
+ print(f"text : {text_list}")
129
+ print(f"pinyin: {final_text_list}")
130
+
131
+ # Duration
132
+ ref_audio_len = audio.shape[-1] // hop_length
133
+ if fix_duration is not None:
134
+ duration = int(fix_duration * target_sample_rate / hop_length)
135
+ else: # simple linear scale calculation
136
+ zh_pause_punc = r"。,、;:?!"
137
+ ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
138
+ gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
139
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
140
+
141
+ # Inference
142
+ with torch.inference_mode():
143
+ generated, trajectory = model.sample(
144
+ cond = audio,
145
+ text = final_text_list,
146
+ duration = duration,
147
+ steps = nfe_step,
148
+ cfg_strength = cfg_strength,
149
+ sway_sampling_coef = sway_sampling_coef,
150
+ seed = seed,
151
+ )
152
+ print(f"Generated mel: {generated.shape}")
153
+
154
+ # Final result
155
+ generated = generated[:, ref_audio_len:, :]
156
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
157
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
158
+ if rms < target_rms:
159
+ generated_wave = generated_wave * rms / target_rms
160
+
161
+ save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single_dbday.png")
162
+ torchaudio.save(f"{output_dir}/test_single_dbday.wav", generated_wave, target_sample_rate)
163
+ print(f"Generated wav: {generated_wave.shape}")
test_infer_single_edit.py ADDED
@@ -0,0 +1,182 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torchaudio
6
+ from einops import rearrange
7
+ from vocos import Vocos
8
+
9
+ from model import CFM, UNetT, DiT, MMDiT
10
+ from model.utils import (
11
+ load_checkpoint,
12
+ get_tokenizer,
13
+ convert_char_to_pinyin,
14
+ save_spectrogram,
15
+ )
16
+
17
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
18
+
19
+
20
+ # --------------------- Dataset Settings -------------------- #
21
+
22
+ target_sample_rate = 24000
23
+ n_mel_channels = 100
24
+ hop_length = 256
25
+ target_rms = 0.1
26
+
27
+ tokenizer = "pinyin"
28
+ dataset_name = "Emilia_ZH_EN"
29
+
30
+
31
+ # ---------------------- infer setting ---------------------- #
32
+
33
+ seed = None # int | None
34
+
35
+ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
36
+ ckpt_step = 1200000
37
+
38
+ nfe_step = 32 # 16, 32
39
+ cfg_strength = 2.
40
+ ode_method = 'euler' # euler | midpoint
41
+ sway_sampling_coef = -1.
42
+ speed = 1.
43
+
44
+ if exp_name == "F5TTS_Base":
45
+ model_cls = DiT
46
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
47
+
48
+ elif exp_name == "E2TTS_Base":
49
+ model_cls = UNetT
50
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
51
+
52
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
53
+ output_dir = "tests"
54
+
55
+ # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
56
+ # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
57
+ # [write the origin_text into a file, e.g. tests/test_edit.txt]
58
+ # ctc-forced-aligner --audio_path "tests/ref_audio/test_en_1_ref_short.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
59
+ # [result will be saved at same path of audio file]
60
+ # [--language "zho" for Chinese, "eng" for English]
61
+ # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
62
+
63
+ audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
64
+ origin_text = "Some call me nature, others call me mother nature."
65
+ target_text = "Some call me optimist, others call me realist."
66
+ parts_to_edit = [[1.42, 2.44], [4.04, 4.9], ] # start_ends of "nature" & "mother nature", in seconds
67
+ fix_duration = [1.2, 1, ] # fix duration for "optimist" & "realist", in seconds
68
+
69
+ # audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
70
+ # origin_text = "对,这就是我,万人敬仰的太乙真人。"
71
+ # target_text = "对,那就是你,万人敬仰的太白金星。"
72
+ # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
73
+ # fix_duration = None # use origin text duration
74
+
75
+
76
+ # -------------------------------------------------#
77
+
78
+ use_ema = True
79
+
80
+ if not os.path.exists(output_dir):
81
+ os.makedirs(output_dir)
82
+
83
+ # Vocoder model
84
+ local = False
85
+ if local:
86
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
87
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
88
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
89
+ vocos.load_state_dict(state_dict)
90
+ vocos.eval()
91
+ else:
92
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
93
+
94
+ # Tokenizer
95
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
96
+
97
+ # Model
98
+ model = CFM(
99
+ transformer = model_cls(
100
+ **model_cfg,
101
+ text_num_embeds = vocab_size,
102
+ mel_dim = n_mel_channels
103
+ ),
104
+ mel_spec_kwargs = dict(
105
+ target_sample_rate = target_sample_rate,
106
+ n_mel_channels = n_mel_channels,
107
+ hop_length = hop_length,
108
+ ),
109
+ odeint_kwargs = dict(
110
+ method = ode_method,
111
+ ),
112
+ vocab_char_map = vocab_char_map,
113
+ ).to(device)
114
+
115
+ model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
116
+
117
+ # Audio
118
+ audio, sr = torchaudio.load(audio_to_edit)
119
+ if audio.shape[0] > 1:
120
+ audio = torch.mean(audio, dim=0, keepdim=True)
121
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
122
+ if rms < target_rms:
123
+ audio = audio * target_rms / rms
124
+ if sr != target_sample_rate:
125
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
126
+ audio = resampler(audio)
127
+ offset = 0
128
+ audio_ = torch.zeros(1, 0)
129
+ edit_mask = torch.zeros(1, 0, dtype=torch.bool)
130
+ for part in parts_to_edit:
131
+ start, end = part
132
+ part_dur = end - start if fix_duration is None else fix_duration.pop(0)
133
+ part_dur = part_dur * target_sample_rate
134
+ start = start * target_sample_rate
135
+ audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim = -1)
136
+ edit_mask = torch.cat((edit_mask,
137
+ torch.ones(1, round((start - offset) / hop_length), dtype = torch.bool),
138
+ torch.zeros(1, round(part_dur / hop_length), dtype = torch.bool)
139
+ ), dim = -1)
140
+ offset = end * target_sample_rate
141
+ # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
142
+ edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value = True)
143
+ audio = audio.to(device)
144
+ edit_mask = edit_mask.to(device)
145
+
146
+ # Text
147
+ text_list = [target_text]
148
+ if tokenizer == "pinyin":
149
+ final_text_list = convert_char_to_pinyin(text_list)
150
+ else:
151
+ final_text_list = [text_list]
152
+ print(f"text : {text_list}")
153
+ print(f"pinyin: {final_text_list}")
154
+
155
+ # Duration
156
+ ref_audio_len = 0
157
+ duration = audio.shape[-1] // hop_length
158
+
159
+ # Inference
160
+ with torch.inference_mode():
161
+ generated, trajectory = model.sample(
162
+ cond = audio,
163
+ text = final_text_list,
164
+ duration = duration,
165
+ steps = nfe_step,
166
+ cfg_strength = cfg_strength,
167
+ sway_sampling_coef = sway_sampling_coef,
168
+ seed = seed,
169
+ edit_mask = edit_mask,
170
+ )
171
+ print(f"Generated mel: {generated.shape}")
172
+
173
+ # Final result
174
+ generated = generated[:, ref_audio_len:, :]
175
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
176
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
177
+ if rms < target_rms:
178
+ generated_wave = generated_wave * rms / target_rms
179
+
180
+ save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single_edit.png")
181
+ torchaudio.save(f"{output_dir}/test_single_edit.wav", generated_wave, target_sample_rate)
182
+ print(f"Generated wav: {generated_wave.shape}")
test_train.py ADDED
@@ -0,0 +1,91 @@
1
+ from model import CFM, UNetT, DiT, MMDiT, Trainer
2
+ from model.utils import get_tokenizer
3
+ from model.dataset import load_dataset
4
+
5
+
6
+ # -------------------------- Dataset Settings --------------------------- #
7
+
8
+ target_sample_rate = 24000
9
+ n_mel_channels = 100
10
+ hop_length = 256
11
+
12
+ tokenizer = "pinyin"
13
+ dataset_name = "Emilia_ZH_EN"
14
+
15
+
16
+ # -------------------------- Training Settings -------------------------- #
17
+
18
+ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
19
+
20
+ learning_rate = 7.5e-5
21
+
22
+ batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
23
+ batch_size_type = "frame" # "frame" or "sample"
24
+ max_samples = 64 # max sequences per batch when using frame-wise batch_size; we use 32 for small models, 64 for base models
25
+ grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
26
+ max_grad_norm = 1.
27
+
28
+ epochs = 11 # use linear decay, thus epochs control the slope
29
+ num_warmup_updates = 20000 # warmup updates
30
+ save_per_updates = 50000 # save a checkpoint every this many updates
31
+ last_per_steps = 5000 # save last checkpoint per steps
32
+
33
+ # model params
34
+ if exp_name == "F5TTS_Base":
35
+ wandb_resume_id = None
36
+ model_cls = DiT
37
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
38
+ elif exp_name == "E2TTS_Base":
39
+ wandb_resume_id = None
40
+ model_cls = UNetT
41
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
42
+
43
+
44
+ # ----------------------------------------------------------------------- #
45
+
46
+ def main():
47
+
48
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
49
+
50
+ mel_spec_kwargs = dict(
51
+ target_sample_rate = target_sample_rate,
52
+ n_mel_channels = n_mel_channels,
53
+ hop_length = hop_length,
54
+ )
55
+
56
+ e2tts = CFM(
57
+ transformer = model_cls(
58
+ **model_cfg,
59
+ text_num_embeds = vocab_size,
60
+ mel_dim = n_mel_channels
61
+ ),
62
+ mel_spec_kwargs = mel_spec_kwargs,
63
+ vocab_char_map = vocab_char_map,
64
+ )
65
+
66
+ trainer = Trainer(
67
+ e2tts,
68
+ epochs,
69
+ learning_rate,
70
+ num_warmup_updates = num_warmup_updates,
71
+ save_per_updates = save_per_updates,
72
+ checkpoint_path = f'ckpts/{exp_name}',
73
+ batch_size = batch_size_per_gpu,
74
+ batch_size_type = batch_size_type,
75
+ max_samples = max_samples,
76
+ grad_accumulation_steps = grad_accumulation_steps,
77
+ max_grad_norm = max_grad_norm,
78
+ wandb_project = "CFM-TTS",
79
+ wandb_run_name = exp_name,
80
+ wandb_resume_id = wandb_resume_id,
81
+ last_per_steps = last_per_steps,
82
+ )
83
+
84
+ train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
85
+ trainer.train(train_dataset,
86
+ resumable_with_seed = 666 # seed for shuffling dataset
87
+ )
88
+
89
+
90
+ if __name__ == '__main__':
91
+ main()
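For a sense of scale of the frame-based batch size above: at 24 kHz with hop_length 256, 38400 frames are about 409.6 s of audio per GPU per update, roughly 0.91 hours across 8 GPUs:

target_sample_rate, hop_length = 24000, 256
batch_size_per_gpu, n_gpus = 38400, 8
seconds_per_gpu = batch_size_per_gpu * hop_length / target_sample_rate
print(f"{seconds_per_gpu:.1f} s per GPU, {seconds_per_gpu * n_gpus / 3600:.2f} h of audio per update")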