Hecheng0625 commited on
Commit
c968fc3
1 Parent(s): 8c92a11

Upload 409 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. README.md +169 -14
  3. models/__init__.py +0 -0
  4. models/base/__init__.py +7 -0
  5. models/base/base_dataset.py +464 -0
  6. models/base/base_inference.py +220 -0
  7. models/base/base_sampler.py +157 -0
  8. models/base/base_trainer.py +348 -0
  9. models/base/new_dataset.py +50 -0
  10. models/base/new_inference.py +253 -0
  11. models/base/new_trainer.py +727 -0
  12. models/codec/__init__.py +0 -0
  13. models/codec/amphion_codec/codec.py +427 -0
  14. models/codec/amphion_codec/quantize/__init__.py +11 -0
  15. models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  16. models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  17. models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  18. models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  19. models/codec/amphion_codec/vocos.py +881 -0
  20. models/codec/codec_dataset.py +264 -0
  21. models/codec/codec_inference.py +515 -0
  22. models/codec/codec_sampler.py +126 -0
  23. models/codec/codec_trainer.py +166 -0
  24. models/codec/facodec/__init__.py +0 -0
  25. models/codec/facodec/alias_free_torch/__init__.py +5 -0
  26. models/codec/facodec/alias_free_torch/act.py +29 -0
  27. models/codec/facodec/alias_free_torch/filter.py +96 -0
  28. models/codec/facodec/alias_free_torch/resample.py +57 -0
  29. models/codec/facodec/facodec_dataset.py +98 -0
  30. models/codec/facodec/facodec_inference.py +137 -0
  31. models/codec/facodec/facodec_trainer.py +776 -0
  32. models/codec/facodec/modules/JDC/__init__.py +1 -0
  33. models/codec/facodec/modules/JDC/bst.t7 +3 -0
  34. models/codec/facodec/modules/JDC/model.py +219 -0
  35. models/codec/facodec/modules/attentions.py +437 -0
  36. models/codec/facodec/modules/commons.py +331 -0
  37. models/codec/facodec/modules/gradient_reversal.py +35 -0
  38. models/codec/facodec/modules/layers.py +460 -0
  39. models/codec/facodec/modules/quantize.py +741 -0
  40. models/codec/facodec/modules/style_encoder.py +110 -0
  41. models/codec/facodec/modules/wavenet.py +224 -0
  42. models/codec/facodec/optimizer.py +104 -0
  43. models/codec/kmeans/repcodec_model.py +210 -0
  44. models/codec/kmeans/vocos.py +850 -0
  45. models/codec/ns3_codec/README.md +216 -0
  46. models/codec/ns3_codec/__init__.py +6 -0
  47. models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  48. models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  49. models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  50. models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
.gitattributes CHANGED
@@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  imgs/vocoder/gan/MSSBCQTD.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  imgs/vocoder/gan/MSSBCQTD.png filter=lfs diff=lfs merge=lfs -text
37
+ models/codec/facodec/modules/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text
38
+ models/tts/maskgct/g2p/sources/chinese_lexicon.txt filter=lfs diff=lfs merge=lfs -text
39
+ models/tts/maskgct/wav/prompt.wav filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,169 @@
1
- ---
2
- title: Maskgct
3
- emoji: 🚀
4
- colorFrom: yellow
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.1.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: MaskGCT TTS Demo
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Amphion: An Open-Source Audio, Music, and Speech Generation Toolkit
2
+
3
+ <div>
4
+ <a href="https://arxiv.org/abs/2312.09911"><img src="https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg"></a>
5
+ <a href="https://huggingface.co/amphion"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Amphion-pink"></a>
6
+ <a href="https://openxlab.org.cn/usercenter/Amphion"><img src="https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg"></a>
7
+ <a href="https://discord.com/invite/ZxxREr3Y"><img src="https://img.shields.io/badge/Discord-Join%20chat-blue.svg"></a>
8
+ <a href="egs/tts/README.md"><img src="https://img.shields.io/badge/README-TTS-blue"></a>
9
+ <a href="egs/svc/README.md"><img src="https://img.shields.io/badge/README-SVC-blue"></a>
10
+ <a href="egs/tta/README.md"><img src="https://img.shields.io/badge/README-TTA-blue"></a>
11
+ <a href="egs/vocoder/README.md"><img src="https://img.shields.io/badge/README-Vocoder-purple"></a>
12
+ <a href="egs/metrics/README.md"><img src="https://img.shields.io/badge/README-Evaluation-yellow"></a>
13
+ <a href="LICENSE"><img src="https://img.shields.io/badge/LICENSE-MIT-red"></a>
14
+ </div>
15
+ <br>
16
+
17
+ **Amphion (/æmˈfaɪən/) is a toolkit for Audio, Music, and Speech Generation.** Its purpose is to support reproducible research and help junior researchers and engineers get started in the field of audio, music, and speech generation research and development. Amphion offers a unique feature: **visualizations** of classic models or architectures. We believe that these visualizations are beneficial for junior researchers and engineers who wish to gain a better understanding of the model.
18
+
19
+ **The North-Star objective of Amphion is to offer a platform for studying the conversion of any inputs into audio.** Amphion is designed to support individual generation tasks, including but not limited to,
20
+
21
+ - **TTS**: Text to Speech (⛳ supported)
22
+ - **SVS**: Singing Voice Synthesis (👨‍💻 developing)
23
+ - **VC**: Voice Conversion (👨‍💻 developing)
24
+ - **SVC**: Singing Voice Conversion (⛳ supported)
25
+ - **TTA**: Text to Audio (⛳ supported)
26
+ - **TTM**: Text to Music (👨‍💻 developing)
27
+ - more…
28
+
29
+ In addition to the specific generation tasks, Amphion includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent metrics in generation tasks. Moreover, Amphion is dedicated to advancing audio generation in real-world applications, such as building **large-scale datasets** for speech synthesis.
30
+
31
+ ## 🚀 News
32
+ - **2024/10/19**: We release **MaskGCT**, a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision. MaskGCT is trained on Emilia dataset and achieves SOTA zero-shot TTS perfermance. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2409.00750) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/maskgct) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/maskgct) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](models/tts/maskgct/README.md)
33
+ - **2024/09/01**: [Amphion](https://arxiv.org/abs/2312.09911), [Emilia](https://arxiv.org/abs/2407.05361) and [DSFF-SVC](https://arxiv.org/abs/2310.11160) got accepted by IEEE SLT 2024! 🤗
34
+ - **2024/08/28**: Welcome to join Amphion's [Discord channel](https://discord.gg/drhW7ajqAG) to stay connected and engage with our community!
35
+ - **2024/08/20**: [SingVisio](https://arxiv.org/abs/2402.12660) got accepted by Computers & Graphics, [available here](https://www.sciencedirect.com/science/article/pii/S0097849324001936)! 🎉
36
+ - **2024/08/27**: *The Emilia dataset is now publicly available!* Discover the most extensive and diverse speech generation dataset with 101k hours of in-the-wild speech data now at [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia-Dataset) or [![OpenDataLab](https://img.shields.io/badge/OpenDataLab-Dataset-blue)](https://opendatalab.com/Amphion/Emilia)! 👑👑👑
37
+ - **2024/07/01**: Amphion now releases **Emilia**, the first open-source multilingual in-the-wild dataset for speech generation with over 101k hours of speech data, and the **Emilia-Pipe**, the first open-source preprocessing pipeline designed to transform in-the-wild speech data into high-quality training data with annotations for speech generation! [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2407.05361) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia) [![demo](https://img.shields.io/badge/WebPage-Demo-red)](https://emilia-dataset.github.io/Emilia-Demo-Page/) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](preprocessors/Emilia/README.md)
38
+ - **2024/06/17**: Amphion has a new release for its **VALL-E** model! It uses Llama as its underlying architecture and has better model performance, faster training speed, and more readable codes compared to our first version. [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/tts/VALLE_V2/README.md)
39
+ - **2024/03/12**: Amphion now support **NaturalSpeech3 FACodec** and release pretrained checkpoints. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2403.03100) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/naturalspeech3_facodec) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](models/codec/ns3_codec/README.md)
40
+ - **2024/02/22**: The first Amphion visualization tool, **SingVisio**, release. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2402.12660) [![openxlab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Amphion/SingVisio) [![Video](https://img.shields.io/badge/Video-Demo-orange)](https://drive.google.com/file/d/15097SGhQh-SwUNbdWDYNyWEP--YGLba5/view) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/visualization/SingVisio/README.md)
41
+ - **2023/12/18**: Amphion v0.1 release. [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2312.09911) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Amphion-pink)](https://huggingface.co/amphion) [![youtube](https://img.shields.io/badge/YouTube-Demo-red)](https://www.youtube.com/watch?v=1aw0HhcggvQ) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/pull/39)
42
+ - **2023/11/28**: Amphion alpha release. [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/pull/2)
43
+
44
+ ## ⭐ Key Features
45
+
46
+ ### TTS: Text to Speech
47
+
48
+ - Amphion achieves state-of-the-art performance compared to existing open-source repositories on text-to-speech (TTS) systems. It supports the following models or architectures:
49
+ - [FastSpeech2](https://arxiv.org/abs/2006.04558): A non-autoregressive TTS architecture that utilizes feed-forward Transformer blocks.
50
+ - [VITS](https://arxiv.org/abs/2106.06103): An end-to-end TTS architecture that utilizes conditional variational autoencoder with adversarial learning
51
+ - [VALL-E](https://arxiv.org/abs/2301.02111): A zero-shot TTS architecture that uses a neural codec language model with discrete codes.
52
+ - [NaturalSpeech2](https://arxiv.org/abs/2304.09116): An architecture for TTS that utilizes a latent diffusion model to generate natural-sounding voices.
53
+ - [Jets](Jets): An end-to-end TTS model that jointly trains FastSpeech2 and HiFi-GAN with an alignment module.
54
+ - [MaskGCT](https://arxiv.org/abs/2409.00750): a fully non-autoregressive TTS architecture that eliminates the need for explicit alignment information between text and speech supervision.
55
+
56
+ ### SVC: Singing Voice Conversion
57
+
58
+ - Ampion supports multiple content-based features from various pretrained models, including [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), and [ContentVec](https://github.com/auspicious3000/contentvec). Their specific roles in SVC has been investigated in our SLT 2024 paper. [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2310.11160) [![code](https://img.shields.io/badge/README-Code-red)](egs/svc/MultipleContentsSVC)
59
+ - Amphion implements several state-of-the-art model architectures, including diffusion-, transformer-, VAE- and flow-based models. The diffusion-based architecture uses [Bidirectional dilated CNN](https://openreview.net/pdf?id=a-xFK8Ymz5J) as a backend and supports several sampling algorithms such as [DDPM](https://arxiv.org/pdf/2006.11239.pdf), [DDIM](https://arxiv.org/pdf/2010.02502.pdf), and [PNDM](https://arxiv.org/pdf/2202.09778.pdf). Additionally, it supports single-step inference based on the [Consistency Model](https://openreview.net/pdf?id=FmqFfMTNnv).
60
+
61
+ ### TTA: Text to Audio
62
+
63
+ - Amphion supports the TTA with a latent diffusion model. It is designed like [AudioLDM](https://arxiv.org/abs/2301.12503), [Make-an-Audio](https://arxiv.org/abs/2301.12661), and [AUDIT](https://arxiv.org/abs/2304.00830). It is also the official implementation of the text-to-audio generation part of our NeurIPS 2023 paper. [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2304.00830) [![code](https://img.shields.io/badge/README-Code-red)](egs/tta/RECIPE.md)
64
+
65
+ ### Vocoder
66
+
67
+ - Amphion supports various widely-used neural vocoders, including:
68
+ - GAN-based vocoders: [MelGAN](https://arxiv.org/abs/1910.06711), [HiFi-GAN](https://arxiv.org/abs/2010.05646), [NSF-HiFiGAN](https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts), [BigVGAN](https://arxiv.org/abs/2206.04658), [APNet](https://arxiv.org/abs/2305.07952).
69
+ - Flow-based vocoders: [WaveGlow](https://arxiv.org/abs/1811.00002).
70
+ - Diffusion-based vocoders: [Diffwave](https://arxiv.org/abs/2009.09761).
71
+ - Auto-regressive based vocoders: [WaveNet](https://arxiv.org/abs/1609.03499), [WaveRNN](https://arxiv.org/abs/1802.08435v1).
72
+ - Amphion provides the official implementation of [Multi-Scale Constant-Q Transform Discriminator](https://arxiv.org/abs/2311.14957) (our ICASSP 2024 paper). It can be used to enhance any architecture GAN-based vocoders during training, and keep the inference stage (such as memory or speed) unchanged. [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2311.14957) [![code](https://img.shields.io/badge/README-Code-red)](egs/vocoder/gan/tfr_enhanced_hifigan)
73
+
74
+ ### Evaluation
75
+
76
+ Amphion provides a comprehensive objective evaluation of the generated audio. The evaluation metrics contain:
77
+
78
+ - **F0 Modeling**: F0 Pearson Coefficients, F0 Periodicity Root Mean Square Error, F0 Root Mean Square Error, Voiced/Unvoiced F1 Score, etc.
79
+ - **Energy Modeling**: Energy Root Mean Square Error, Energy Pearson Coefficients, etc.
80
+ - **Intelligibility**: Character/Word Error Rate, which can be calculated based on [Whisper](https://github.com/openai/whisper) and more.
81
+ - **Spectrogram Distortion**: Frechet Audio Distance (FAD), Mel Cepstral Distortion (MCD), Multi-Resolution STFT Distance (MSTFT), Perceptual Evaluation of Speech Quality (PESQ), Short Time Objective Intelligibility (STOI), etc.
82
+ - **Speaker Similarity**: Cosine similarity, which can be calculated based on [RawNet3](https://github.com/Jungjee/RawNet), [Resemblyzer](https://github.com/resemble-ai/Resemblyzer), [WeSpeaker](https://github.com/wenet-e2e/wespeaker), [WavLM](https://github.com/microsoft/unilm/tree/master/wavlm) and more.
83
+
84
+ ### Datasets
85
+
86
+ - Amphion unifies the data preprocess of the open-source datasets including [AudioCaps](https://audiocaps.github.io/), [LibriTTS](https://www.openslr.org/60/), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/), [M4Singer](https://github.com/M4Singer/M4Singer), [Opencpop](https://wenet.org.cn/opencpop/), [OpenSinger](https://github.com/Multi-Singer/Multi-Singer.github.io), [SVCC](http://vc-challenge.org/), [VCTK](https://datashare.ed.ac.uk/handle/10283/3443), and more. The supported dataset list can be seen [here](egs/datasets/README.md) (updating).
87
+ - Amphion (exclusively) supports the [**Emilia**](preprocessors/Emilia/README.md) dataset and its preprocessing pipeline **Emilia-Pipe** for in-the-wild speech data!
88
+
89
+ ### Visualization
90
+
91
+ Amphion provides visualization tools to interactively illustrate the internal processing mechanism of classic models. This provides an invaluable resource for educational purposes and for facilitating understandable research.
92
+
93
+ Currently, Amphion supports [SingVisio](egs/visualization/SingVisio/README.md), a visualization tool of the diffusion model for singing voice conversion. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2402.12660) [![openxlab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Amphion/SingVisio) [![Video](https://img.shields.io/badge/Video-Demo-orange)](https://drive.google.com/file/d/15097SGhQh-SwUNbdWDYNyWEP--YGLba5/view)
94
+
95
+
96
+ ## 📀 Installation
97
+
98
+ Amphion can be installed through either Setup Installer or Docker Image.
99
+
100
+ ### Setup Installer
101
+
102
+ ```bash
103
+ git clone https://github.com/open-mmlab/Amphion.git
104
+ cd Amphion
105
+
106
+ # Install Python Environment
107
+ conda create --name amphion python=3.9.15
108
+ conda activate amphion
109
+
110
+ # Install Python Packages Dependencies
111
+ sh env.sh
112
+ ```
113
+
114
+ ### Docker Image
115
+
116
+ 1. Install [Docker](https://docs.docker.com/get-docker/), [NVIDIA Driver](https://www.nvidia.com/download/index.aspx), [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), and [CUDA](https://developer.nvidia.com/cuda-downloads).
117
+
118
+ 2. Run the following commands:
119
+ ```bash
120
+ git clone https://github.com/open-mmlab/Amphion.git
121
+ cd Amphion
122
+
123
+ docker pull realamphion/amphion
124
+ docker run --runtime=nvidia --gpus all -it -v .:/app realamphion/amphion
125
+ ```
126
+ Mount dataset by argument `-v` is necessary when using Docker. Please refer to [Mount dataset in Docker container](egs/datasets/docker.md) and [Docker Docs](https://docs.docker.com/engine/reference/commandline/container_run/#volume) for more details.
127
+
128
+
129
+ ## 🐍 Usage in Python
130
+
131
+ We detail the instructions of different tasks in the following recipes:
132
+
133
+ - [Text to Speech (TTS)](egs/tts/README.md)
134
+ - [Singing Voice Conversion (SVC)](egs/svc/README.md)
135
+ - [Text to Audio (TTA)](egs/tta/README.md)
136
+ - [Vocoder](egs/vocoder/README.md)
137
+ - [Evaluation](egs/metrics/README.md)
138
+ - [Visualization](egs/visualization/README.md)
139
+
140
+ ## 👨‍💻 Contributing
141
+ We appreciate all contributions to improve Amphion. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
142
+
143
+ ## 🙏 Acknowledgement
144
+
145
+
146
+ - [ming024's FastSpeech2](https://github.com/ming024/FastSpeech2) and [jaywalnut310's VITS](https://github.com/jaywalnut310/vits) for model architecture code.
147
+ - [lifeiteng's VALL-E](https://github.com/lifeiteng/vall-e) for training pipeline and model architecture design.
148
+ - [SpeechTokenizer](https://github.com/ZhangXInFD/SpeechTokenizer) for semantic-distilled tokenizer design.
149
+ - [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), [ContentVec](https://github.com/auspicious3000/contentvec), and [RawNet3](https://github.com/Jungjee/RawNet) for pretrained models and inference code.
150
+ - [HiFi-GAN](https://github.com/jik876/hifi-gan) for GAN-based Vocoder's architecture design and training strategy.
151
+ - [Encodec](https://github.com/facebookresearch/encodec) for well-organized GAN Discriminator's architecture and basic blocks.
152
+ - [Latent Diffusion](https://github.com/CompVis/latent-diffusion) for model architecture design.
153
+ - [TensorFlowTTS](https://github.com/TensorSpeech/TensorFlowTTS) for preparing the MFA tools.
154
+
155
+
156
+ ## ©️ License
157
+
158
+ Amphion is under the [MIT License](LICENSE). It is free for both research and commercial use cases.
159
+
160
+ ## 📚 Citations
161
+
162
+ ```bibtex
163
+ @inproceedings{amphion,
164
+ author={Zhang, Xueyao and Xue, Liumeng and Gu, Yicheng and Wang, Yuancheng and Li, Jiaqi and He, Haorui and Wang, Chaoren and Song, Ting and Chen, Xi and Fang, Zihao and Chen, Haopeng and Zhang, Junan and Tang, Tze Ying and Zou, Lexiao and Wang, Mingxuan and Han, Jun and Chen, Kai and Li, Haizhou and Wu, Zhizheng},
165
+ title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
166
+ booktitle={{IEEE} Spoken Language Technology Workshop, {SLT} 2024},
167
+ year={2024}
168
+ }
169
+ ```
models/__init__.py ADDED
File without changes
models/base/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .new_trainer import BaseTrainer
7
+ from .new_inference import BaseInference
models/base/base_dataset.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import numpy as np
8
+ import torch.utils.data
9
+ from torch.nn.utils.rnn import pad_sequence
10
+ import librosa
11
+
12
+ from utils.data_utils import *
13
+ from processors.acoustic_extractor import cal_normalized_mel
14
+ from text import text_to_sequence
15
+ from text.text_token_collation import phoneIDCollation
16
+
17
+
18
+ class BaseOfflineDataset(torch.utils.data.Dataset):
19
+ def __init__(self, cfg, dataset, is_valid=False):
20
+ """
21
+ Args:
22
+ cfg: config
23
+ dataset: dataset name
24
+ is_valid: whether to use train or valid dataset
25
+ """
26
+
27
+ assert isinstance(dataset, str)
28
+
29
+ # self.data_root = processed_data_dir
30
+ self.cfg = cfg
31
+
32
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
33
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
34
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
35
+ self.metadata = self.get_metadata()
36
+
37
+ """
38
+ load spk2id and utt2spk from json file
39
+ spk2id: {spk1: 0, spk2: 1, ...}
40
+ utt2spk: {dataset_uid: spk1, ...}
41
+ """
42
+ if cfg.preprocess.use_spkid:
43
+ spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
44
+ with open(spk2id_path, "r") as f:
45
+ self.spk2id = json.load(f)
46
+
47
+ utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
48
+ self.utt2spk = dict()
49
+ with open(utt2spk_path, "r") as f:
50
+ for line in f.readlines():
51
+ utt, spk = line.strip().split("\t")
52
+ self.utt2spk[utt] = spk
53
+
54
+ if cfg.preprocess.use_uv:
55
+ self.utt2uv_path = {}
56
+ for utt_info in self.metadata:
57
+ dataset = utt_info["Dataset"]
58
+ uid = utt_info["Uid"]
59
+ utt = "{}_{}".format(dataset, uid)
60
+ self.utt2uv_path[utt] = os.path.join(
61
+ cfg.preprocess.processed_dir,
62
+ dataset,
63
+ cfg.preprocess.uv_dir,
64
+ uid + ".npy",
65
+ )
66
+
67
+ if cfg.preprocess.use_frame_pitch:
68
+ self.utt2frame_pitch_path = {}
69
+ for utt_info in self.metadata:
70
+ dataset = utt_info["Dataset"]
71
+ uid = utt_info["Uid"]
72
+ utt = "{}_{}".format(dataset, uid)
73
+
74
+ self.utt2frame_pitch_path[utt] = os.path.join(
75
+ cfg.preprocess.processed_dir,
76
+ dataset,
77
+ cfg.preprocess.pitch_dir,
78
+ uid + ".npy",
79
+ )
80
+
81
+ if cfg.preprocess.use_frame_energy:
82
+ self.utt2frame_energy_path = {}
83
+ for utt_info in self.metadata:
84
+ dataset = utt_info["Dataset"]
85
+ uid = utt_info["Uid"]
86
+ utt = "{}_{}".format(dataset, uid)
87
+
88
+ self.utt2frame_energy_path[utt] = os.path.join(
89
+ cfg.preprocess.processed_dir,
90
+ dataset,
91
+ cfg.preprocess.energy_dir,
92
+ uid + ".npy",
93
+ )
94
+
95
+ if cfg.preprocess.use_mel:
96
+ self.utt2mel_path = {}
97
+ for utt_info in self.metadata:
98
+ dataset = utt_info["Dataset"]
99
+ uid = utt_info["Uid"]
100
+ utt = "{}_{}".format(dataset, uid)
101
+
102
+ self.utt2mel_path[utt] = os.path.join(
103
+ cfg.preprocess.processed_dir,
104
+ dataset,
105
+ cfg.preprocess.mel_dir,
106
+ uid + ".npy",
107
+ )
108
+
109
+ if cfg.preprocess.use_linear:
110
+ self.utt2linear_path = {}
111
+ for utt_info in self.metadata:
112
+ dataset = utt_info["Dataset"]
113
+ uid = utt_info["Uid"]
114
+ utt = "{}_{}".format(dataset, uid)
115
+
116
+ self.utt2linear_path[utt] = os.path.join(
117
+ cfg.preprocess.processed_dir,
118
+ dataset,
119
+ cfg.preprocess.linear_dir,
120
+ uid + ".npy",
121
+ )
122
+
123
+ if cfg.preprocess.use_audio:
124
+ self.utt2audio_path = {}
125
+ for utt_info in self.metadata:
126
+ dataset = utt_info["Dataset"]
127
+ uid = utt_info["Uid"]
128
+ utt = "{}_{}".format(dataset, uid)
129
+
130
+ self.utt2audio_path[utt] = os.path.join(
131
+ cfg.preprocess.processed_dir,
132
+ dataset,
133
+ cfg.preprocess.audio_dir,
134
+ uid + ".npy",
135
+ )
136
+ elif cfg.preprocess.use_label:
137
+ self.utt2label_path = {}
138
+ for utt_info in self.metadata:
139
+ dataset = utt_info["Dataset"]
140
+ uid = utt_info["Uid"]
141
+ utt = "{}_{}".format(dataset, uid)
142
+
143
+ self.utt2label_path[utt] = os.path.join(
144
+ cfg.preprocess.processed_dir,
145
+ dataset,
146
+ cfg.preprocess.label_dir,
147
+ uid + ".npy",
148
+ )
149
+ elif cfg.preprocess.use_one_hot:
150
+ self.utt2one_hot_path = {}
151
+ for utt_info in self.metadata:
152
+ dataset = utt_info["Dataset"]
153
+ uid = utt_info["Uid"]
154
+ utt = "{}_{}".format(dataset, uid)
155
+
156
+ self.utt2one_hot_path[utt] = os.path.join(
157
+ cfg.preprocess.processed_dir,
158
+ dataset,
159
+ cfg.preprocess.one_hot_dir,
160
+ uid + ".npy",
161
+ )
162
+
163
+ if cfg.preprocess.use_text or cfg.preprocess.use_phone:
164
+ self.utt2seq = {}
165
+ for utt_info in self.metadata:
166
+ dataset = utt_info["Dataset"]
167
+ uid = utt_info["Uid"]
168
+ utt = "{}_{}".format(dataset, uid)
169
+
170
+ if cfg.preprocess.use_text:
171
+ text = utt_info["Text"]
172
+ sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
173
+ elif cfg.preprocess.use_phone:
174
+ # load phoneme squence from phone file
175
+ phone_path = os.path.join(
176
+ processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
177
+ )
178
+ with open(phone_path, "r") as fin:
179
+ phones = fin.readlines()
180
+ assert len(phones) == 1
181
+ phones = phones[0].strip()
182
+ phones_seq = phones.split(" ")
183
+
184
+ phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
185
+ sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
186
+
187
+ self.utt2seq[utt] = sequence
188
+
189
+ def get_metadata(self):
190
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
191
+ metadata = json.load(f)
192
+
193
+ return metadata
194
+
195
+ def get_dataset_name(self):
196
+ return self.metadata[0]["Dataset"]
197
+
198
+ def __getitem__(self, index):
199
+ utt_info = self.metadata[index]
200
+
201
+ dataset = utt_info["Dataset"]
202
+ uid = utt_info["Uid"]
203
+ utt = "{}_{}".format(dataset, uid)
204
+
205
+ single_feature = dict()
206
+
207
+ if self.cfg.preprocess.use_spkid:
208
+ single_feature["spk_id"] = np.array(
209
+ [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
210
+ )
211
+
212
+ if self.cfg.preprocess.use_mel:
213
+ mel = np.load(self.utt2mel_path[utt])
214
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
215
+ if self.cfg.preprocess.use_min_max_norm_mel:
216
+ # do mel norm
217
+ mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)
218
+
219
+ if "target_len" not in single_feature.keys():
220
+ single_feature["target_len"] = mel.shape[1]
221
+ single_feature["mel"] = mel.T # [T, n_mels]
222
+
223
+ if self.cfg.preprocess.use_linear:
224
+ linear = np.load(self.utt2linear_path[utt])
225
+ if "target_len" not in single_feature.keys():
226
+ single_feature["target_len"] = linear.shape[1]
227
+ single_feature["linear"] = linear.T # [T, n_linear]
228
+
229
+ if self.cfg.preprocess.use_frame_pitch:
230
+ frame_pitch_path = self.utt2frame_pitch_path[utt]
231
+ frame_pitch = np.load(frame_pitch_path)
232
+ if "target_len" not in single_feature.keys():
233
+ single_feature["target_len"] = len(frame_pitch)
234
+ aligned_frame_pitch = align_length(
235
+ frame_pitch, single_feature["target_len"]
236
+ )
237
+ single_feature["frame_pitch"] = aligned_frame_pitch
238
+
239
+ if self.cfg.preprocess.use_uv:
240
+ frame_uv_path = self.utt2uv_path[utt]
241
+ frame_uv = np.load(frame_uv_path)
242
+ aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
243
+ aligned_frame_uv = [
244
+ 0 if frame_uv else 1 for frame_uv in aligned_frame_uv
245
+ ]
246
+ aligned_frame_uv = np.array(aligned_frame_uv)
247
+ single_feature["frame_uv"] = aligned_frame_uv
248
+
249
+ if self.cfg.preprocess.use_frame_energy:
250
+ frame_energy_path = self.utt2frame_energy_path[utt]
251
+ frame_energy = np.load(frame_energy_path)
252
+ if "target_len" not in single_feature.keys():
253
+ single_feature["target_len"] = len(frame_energy)
254
+ aligned_frame_energy = align_length(
255
+ frame_energy, single_feature["target_len"]
256
+ )
257
+ single_feature["frame_energy"] = aligned_frame_energy
258
+
259
+ if self.cfg.preprocess.use_audio:
260
+ audio = np.load(self.utt2audio_path[utt])
261
+ single_feature["audio"] = audio
262
+ single_feature["audio_len"] = audio.shape[0]
263
+
264
+ if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
265
+ single_feature["phone_seq"] = np.array(self.utt2seq[utt])
266
+ single_feature["phone_len"] = len(self.utt2seq[utt])
267
+
268
+ return single_feature
269
+
270
+ def __len__(self):
271
+ return len(self.metadata)
272
+
273
+
274
+ class BaseOfflineCollator(object):
275
+ """Zero-pads model inputs and targets based on number of frames per step"""
276
+
277
+ def __init__(self, cfg):
278
+ self.cfg = cfg
279
+
280
+ def __call__(self, batch):
281
+ packed_batch_features = dict()
282
+
283
+ # mel: [b, T, n_mels]
284
+ # frame_pitch, frame_energy: [1, T]
285
+ # target_len: [b]
286
+ # spk_id: [b, 1]
287
+ # mask: [b, T, 1]
288
+
289
+ for key in batch[0].keys():
290
+ if key == "target_len":
291
+ packed_batch_features["target_len"] = torch.LongTensor(
292
+ [b["target_len"] for b in batch]
293
+ )
294
+ masks = [
295
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
296
+ ]
297
+ packed_batch_features["mask"] = pad_sequence(
298
+ masks, batch_first=True, padding_value=0
299
+ )
300
+ elif key == "phone_len":
301
+ packed_batch_features["phone_len"] = torch.LongTensor(
302
+ [b["phone_len"] for b in batch]
303
+ )
304
+ masks = [
305
+ torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
306
+ ]
307
+ packed_batch_features["phn_mask"] = pad_sequence(
308
+ masks, batch_first=True, padding_value=0
309
+ )
310
+ elif key == "audio_len":
311
+ packed_batch_features["audio_len"] = torch.LongTensor(
312
+ [b["audio_len"] for b in batch]
313
+ )
314
+ masks = [
315
+ torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
316
+ ]
317
+ else:
318
+ values = [torch.from_numpy(b[key]) for b in batch]
319
+ packed_batch_features[key] = pad_sequence(
320
+ values, batch_first=True, padding_value=0
321
+ )
322
+ return packed_batch_features
323
+
324
+
325
+ class BaseOnlineDataset(torch.utils.data.Dataset):
326
+ def __init__(self, cfg, dataset, is_valid=False):
327
+ """
328
+ Args:
329
+ cfg: config
330
+ dataset: dataset name
331
+ is_valid: whether to use train or valid dataset
332
+ """
333
+ assert isinstance(dataset, str)
334
+
335
+ self.cfg = cfg
336
+ self.sample_rate = cfg.preprocess.sample_rate
337
+ self.hop_size = self.cfg.preprocess.hop_size
338
+
339
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
340
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
341
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
342
+ self.metadata = self.get_metadata()
343
+
344
+ """
345
+ load spk2id and utt2spk from json file
346
+ spk2id: {spk1: 0, spk2: 1, ...}
347
+ utt2spk: {dataset_uid: spk1, ...}
348
+ """
349
+ if cfg.preprocess.use_spkid:
350
+ spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
351
+ with open(spk2id_path, "r") as f:
352
+ self.spk2id = json.load(f)
353
+
354
+ utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
355
+ self.utt2spk = dict()
356
+ with open(utt2spk_path, "r") as f:
357
+ for line in f.readlines():
358
+ utt, spk = line.strip().split("\t")
359
+ self.utt2spk[utt] = spk
360
+
361
+ def get_metadata(self):
362
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
363
+ metadata = json.load(f)
364
+
365
+ return metadata
366
+
367
+ def get_dataset_name(self):
368
+ return self.metadata[0]["Dataset"]
369
+
370
+ def __getitem__(self, index):
371
+ """
372
+ single_feature:
373
+ wav: (T)
374
+ wav_len: int
375
+ target_len: int
376
+ mask: (n_frames, 1)
377
+ spk_id: (1)
378
+ """
379
+ utt_item = self.metadata[index]
380
+
381
+ wav_path = utt_item["Path"]
382
+ wav, _ = librosa.load(wav_path, sr=self.sample_rate)
383
+ # wav: (T)
384
+ wav = torch.as_tensor(wav, dtype=torch.float32)
385
+ wav_len = len(wav)
386
+ # mask: (n_frames, 1)
387
+ frame_len = wav_len // self.hop_size
388
+ mask = torch.ones(frame_len, 1, dtype=torch.long)
389
+
390
+ single_feature = {
391
+ "wav": wav,
392
+ "wav_len": wav_len,
393
+ "target_len": frame_len,
394
+ "mask": mask,
395
+ }
396
+
397
+ if self.cfg.preprocess.use_spkid:
398
+ utt = "{}_{}".format(utt_item["Dataset"], utt_item["Uid"])
399
+ single_feature["spk_id"] = torch.tensor(
400
+ [self.spk2id[self.utt2spk[utt]]], dtype=torch.int32
401
+ )
402
+
403
+ return single_feature
404
+
405
+ def __len__(self):
406
+ return len(self.metadata)
407
+
408
+
409
+ class BaseOnlineCollator(object):
410
+ """Zero-pads model inputs and targets based on number of frames per step (For on-the-fly features extraction, whose iterative item contains only wavs)"""
411
+
412
+ def __init__(self, cfg):
413
+ self.cfg = cfg
414
+
415
+ def __call__(self, batch):
416
+ """
417
+ BaseOnlineDataset.__getitem__:
418
+ wav: (T,)
419
+ wav_len: int
420
+ target_len: int
421
+ mask: (n_frames, 1)
422
+ spk_id: (1)
423
+
424
+ Returns:
425
+ wav: (B, T), torch.float32
426
+ wav_len: (B), torch.long
427
+ target_len: (B), torch.long
428
+ mask: (B, n_frames, 1), torch.long
429
+ spk_id: (B, 1), torch.int32
430
+ """
431
+ packed_batch_features = dict()
432
+
433
+ for key in batch[0].keys():
434
+ if key in ["wav_len", "target_len"]:
435
+ packed_batch_features[key] = torch.LongTensor([b[key] for b in batch])
436
+ else:
437
+ packed_batch_features[key] = pad_sequence(
438
+ [b[key] for b in batch], batch_first=True, padding_value=0
439
+ )
440
+ return packed_batch_features
441
+
442
+
443
+ class BaseTestDataset(torch.utils.data.Dataset):
444
+ def __init__(self, cfg, args):
445
+ raise NotImplementedError
446
+
447
+ def get_metadata(self):
448
+ raise NotImplementedError
449
+
450
+ def __getitem__(self, index):
451
+ raise NotImplementedError
452
+
453
+ def __len__(self):
454
+ return len(self.metadata)
455
+
456
+
457
+ class BaseTestCollator(object):
458
+ """Zero-pads model inputs and targets based on number of frames per step"""
459
+
460
+ def __init__(self, cfg):
461
+ raise NotImplementedError
462
+
463
+ def __call__(self, batch):
464
+ raise NotImplementedError
models/base/base_inference.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import os
8
+ import re
9
+ import time
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ from torch.utils.data import DataLoader
14
+ from tqdm import tqdm
15
+
16
+ from models.vocoders.vocoder_inference import synthesis
17
+ from torch.utils.data import DataLoader
18
+ from utils.util import set_all_random_seed
19
+ from utils.util import load_config
20
+
21
+
22
+ def parse_vocoder(vocoder_dir):
23
+ r"""Parse vocoder config"""
24
+ vocoder_dir = os.path.abspath(vocoder_dir)
25
+ ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
26
+ ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
27
+ ckpt_path = str(ckpt_list[0])
28
+ vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True)
29
+ vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder
30
+ return vocoder_cfg, ckpt_path
31
+
32
+
33
+ class BaseInference(object):
34
+ def __init__(self, cfg, args):
35
+ self.cfg = cfg
36
+ self.args = args
37
+ self.model_type = cfg.model_type
38
+ self.avg_rtf = list()
39
+ set_all_random_seed(10086)
40
+ os.makedirs(args.output_dir, exist_ok=True)
41
+
42
+ if torch.cuda.is_available():
43
+ self.device = torch.device("cuda")
44
+ else:
45
+ self.device = torch.device("cpu")
46
+ torch.set_num_threads(10) # inference on 1 core cpu.
47
+
48
+ # Load acoustic model
49
+ self.model = self.create_model().to(self.device)
50
+ state_dict = self.load_state_dict()
51
+ self.load_model(state_dict)
52
+ self.model.eval()
53
+
54
+ # Load vocoder model if necessary
55
+ if self.args.checkpoint_dir_vocoder is not None:
56
+ self.get_vocoder_info()
57
+
58
+ def create_model(self):
59
+ raise NotImplementedError
60
+
61
+ def load_state_dict(self):
62
+ self.checkpoint_file = self.args.checkpoint_file
63
+ if self.checkpoint_file is None:
64
+ assert self.args.checkpoint_dir is not None
65
+ checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint")
66
+ checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
67
+ self.checkpoint_file = os.path.join(
68
+ self.args.checkpoint_dir, checkpoint_filename
69
+ )
70
+
71
+ self.checkpoint_dir = os.path.split(self.checkpoint_file)[0]
72
+
73
+ print("Restore acoustic model from {}".format(self.checkpoint_file))
74
+ raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device)
75
+ self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0]
76
+
77
+ return raw_state_dict
78
+
79
+ def load_model(self, model):
80
+ raise NotImplementedError
81
+
82
+ def get_vocoder_info(self):
83
+ self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder
84
+ self.vocoder_cfg = os.path.join(
85
+ os.path.dirname(self.checkpoint_dir_vocoder), "args.json"
86
+ )
87
+ self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True)
88
+ self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1]
89
+ self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0]
90
+
91
+ def build_test_utt_data(self):
92
+ raise NotImplementedError
93
+
94
+ def build_testdata_loader(self, args, target_speaker=None):
95
+ datasets, collate = self.build_test_dataset()
96
+ self.test_dataset = datasets(self.cfg, args, target_speaker)
97
+ self.test_collate = collate(self.cfg)
98
+ self.test_batch_size = min(
99
+ self.cfg.train.batch_size, len(self.test_dataset.metadata)
100
+ )
101
+ test_loader = DataLoader(
102
+ self.test_dataset,
103
+ collate_fn=self.test_collate,
104
+ num_workers=self.args.num_workers,
105
+ batch_size=self.test_batch_size,
106
+ shuffle=False,
107
+ )
108
+ return test_loader
109
+
110
+ def inference_each_batch(self, batch_data):
111
+ raise NotImplementedError
112
+
113
+ def inference_for_batches(self, args, target_speaker=None):
114
+ ###### Construct test_batch ######
115
+ loader = self.build_testdata_loader(args, target_speaker)
116
+
117
+ n_batch = len(loader)
118
+ now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
119
+ print(
120
+ "Model eval time: {}, batch_size = {}, n_batch = {}".format(
121
+ now, self.test_batch_size, n_batch
122
+ )
123
+ )
124
+ self.model.eval()
125
+
126
+ ###### Inference for each batch ######
127
+ pred_res = []
128
+ with torch.no_grad():
129
+ for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)):
130
+ # Put the data to device
131
+ for k, v in batch_data.items():
132
+ batch_data[k] = batch_data[k].to(self.device)
133
+
134
+ y_pred, stats = self.inference_each_batch(batch_data)
135
+
136
+ pred_res += y_pred
137
+
138
+ return pred_res
139
+
140
+ def inference(self, feature):
141
+ raise NotImplementedError
142
+
143
+ def synthesis_by_vocoder(self, pred):
144
+ audios_pred = synthesis(
145
+ self.vocoder_cfg,
146
+ self.checkpoint_dir_vocoder,
147
+ len(pred),
148
+ pred,
149
+ )
150
+ return audios_pred
151
+
152
+ def __call__(self, utt):
153
+ feature = self.build_test_utt_data(utt)
154
+ start_time = time.time()
155
+ with torch.no_grad():
156
+ outputs = self.inference(feature)[0]
157
+ time_used = time.time() - start_time
158
+ rtf = time_used / (
159
+ outputs.shape[1]
160
+ * self.cfg.preprocess.hop_size
161
+ / self.cfg.preprocess.sample_rate
162
+ )
163
+ print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf))
164
+ self.avg_rtf.append(rtf)
165
+ audios = outputs.cpu().squeeze().numpy().reshape(-1, 1)
166
+ return audios
167
+
168
+
169
+ def base_parser():
170
+ parser = argparse.ArgumentParser()
171
+ parser.add_argument(
172
+ "--config", default="config.json", help="json files for configurations."
173
+ )
174
+ parser.add_argument("--use_ddp_inference", default=False)
175
+ parser.add_argument("--n_workers", default=1, type=int)
176
+ parser.add_argument("--local_rank", default=-1, type=int)
177
+ parser.add_argument(
178
+ "--batch_size", default=1, type=int, help="Batch size for inference"
179
+ )
180
+ parser.add_argument(
181
+ "--num_workers",
182
+ default=1,
183
+ type=int,
184
+ help="Worker number for inference dataloader",
185
+ )
186
+ parser.add_argument(
187
+ "--checkpoint_dir",
188
+ type=str,
189
+ default=None,
190
+ help="Checkpoint dir including model file and configuration",
191
+ )
192
+ parser.add_argument(
193
+ "--checkpoint_file", help="checkpoint file", type=str, default=None
194
+ )
195
+ parser.add_argument(
196
+ "--test_list", help="test utterance list for testing", type=str, default=None
197
+ )
198
+ parser.add_argument(
199
+ "--checkpoint_dir_vocoder",
200
+ help="Vocoder's checkpoint dir including model file and configuration",
201
+ type=str,
202
+ default=None,
203
+ )
204
+ parser.add_argument(
205
+ "--output_dir",
206
+ type=str,
207
+ default=None,
208
+ help="Output dir for saving generated results",
209
+ )
210
+ return parser
211
+
212
+
213
+ if __name__ == "__main__":
214
+ parser = base_parser()
215
+ args = parser.parse_args()
216
+ cfg = load_config(args.config)
217
+
218
+ # Build inference
219
+ inference = BaseInference(cfg, args)
220
+ inference()
models/base/base_sampler.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import random
8
+
9
+ from torch.utils.data import ConcatDataset, Dataset
10
+ from torch.utils.data.sampler import (
11
+ BatchSampler,
12
+ RandomSampler,
13
+ Sampler,
14
+ SequentialSampler,
15
+ )
16
+
17
+
18
+ class ScheduledSampler(Sampler):
19
+ """A sampler that samples data from a given concat-dataset.
20
+
21
+ Args:
22
+ concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
23
+ batch_size (int): batch size
24
+ holistic_shuffle (bool): whether to shuffle the whole dataset or not
25
+ logger (logging.Logger): logger to print warning message
26
+
27
+ Usage:
28
+ For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
29
+ >>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]])))
30
+ [3, 4, 5, 0, 1, 2, 6, 7, 8]
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ concat_dataset,
36
+ batch_size,
37
+ holistic_shuffle,
38
+ logger=None,
39
+ loader_type="train",
40
+ ):
41
+ if not isinstance(concat_dataset, ConcatDataset):
42
+ raise ValueError(
43
+ "concat_dataset must be an instance of ConcatDataset, but got {}".format(
44
+ type(concat_dataset)
45
+ )
46
+ )
47
+ if not isinstance(batch_size, int):
48
+ raise ValueError(
49
+ "batch_size must be an integer, but got {}".format(type(batch_size))
50
+ )
51
+ if not isinstance(holistic_shuffle, bool):
52
+ raise ValueError(
53
+ "holistic_shuffle must be a boolean, but got {}".format(
54
+ type(holistic_shuffle)
55
+ )
56
+ )
57
+
58
+ self.concat_dataset = concat_dataset
59
+ self.batch_size = batch_size
60
+ self.holistic_shuffle = holistic_shuffle
61
+
62
+ affected_dataset_name = []
63
+ affected_dataset_len = []
64
+ for dataset in concat_dataset.datasets:
65
+ dataset_len = len(dataset)
66
+ dataset_name = dataset.get_dataset_name()
67
+ if dataset_len < batch_size:
68
+ affected_dataset_name.append(dataset_name)
69
+ affected_dataset_len.append(dataset_len)
70
+
71
+ self.type = loader_type
72
+ for dataset_name, dataset_len in zip(
73
+ affected_dataset_name, affected_dataset_len
74
+ ):
75
+ if not loader_type == "valid":
76
+ logger.warning(
77
+ "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
78
+ loader_type, dataset_name, dataset_len, batch_size
79
+ )
80
+ )
81
+
82
+ def __len__(self):
83
+ # the number of batches with drop last
84
+ num_of_batches = sum(
85
+ [
86
+ math.floor(len(dataset) / self.batch_size)
87
+ for dataset in self.concat_dataset.datasets
88
+ ]
89
+ )
90
+ # if samples are not enough for one batch, we don't drop last
91
+ if self.type == "valid" and num_of_batches < 1:
92
+ return len(self.concat_dataset)
93
+ return num_of_batches * self.batch_size
94
+
95
+ def __iter__(self):
96
+ iters = []
97
+ for dataset in self.concat_dataset.datasets:
98
+ iters.append(
99
+ SequentialSampler(dataset).__iter__()
100
+ if not self.holistic_shuffle
101
+ else RandomSampler(dataset).__iter__()
102
+ )
103
+ # e.g. [0, 200, 400]
104
+ init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
105
+ output_batches = []
106
+ for dataset_idx in range(len(self.concat_dataset.datasets)):
107
+ cur_batch = []
108
+ for idx in iters[dataset_idx]:
109
+ cur_batch.append(idx + init_indices[dataset_idx])
110
+ if len(cur_batch) == self.batch_size:
111
+ output_batches.append(cur_batch)
112
+ cur_batch = []
113
+ # if loader_type is valid, we don't need to drop last
114
+ if self.type == "valid" and len(cur_batch) > 0:
115
+ output_batches.append(cur_batch)
116
+
117
+ # force drop last in training
118
+ random.shuffle(output_batches)
119
+ output_indices = [item for sublist in output_batches for item in sublist]
120
+ return iter(output_indices)
121
+
122
+
123
+ def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type):
124
+ sampler = ScheduledSampler(
125
+ concat_dataset,
126
+ cfg.train.batch_size,
127
+ cfg.train.sampler.holistic_shuffle,
128
+ logger,
129
+ loader_type,
130
+ )
131
+ batch_sampler = BatchSampler(
132
+ sampler,
133
+ cfg.train.batch_size,
134
+ cfg.train.sampler.drop_last if not loader_type == "valid" else False,
135
+ )
136
+ return sampler, batch_sampler
137
+
138
+
139
+ class VariableSampler(BatchSampler):
140
+ def __init__(self, sampler, drop_last: bool, use_random_sampler=False):
141
+ self.data_list = sampler
142
+ if use_random_sampler:
143
+ self.sampler = RandomSampler(sampler)
144
+ else:
145
+ self.sampler = SequentialSampler(sampler)
146
+
147
+ super().__init__(self.sampler, 1, drop_last)
148
+
149
+ def __iter__(self):
150
+ for batch_ids in self.data_list:
151
+ yield batch_ids
152
+
153
+ def __len__(self):
154
+ if self.drop_last:
155
+ return len(self.sampler) // self.batch_size
156
+ else:
157
+ return (len(self.sampler) + self.batch_size - 1) // self.batch_size
models/base/base_trainer.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import collections
7
+ import json
8
+ import os
9
+ import sys
10
+ import time
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ from torch.nn.parallel import DistributedDataParallel
15
+ from torch.utils.data import ConcatDataset, DataLoader
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ from models.base.base_sampler import BatchSampler
19
+ from utils.util import (
20
+ Logger,
21
+ remove_older_ckpt,
22
+ save_config,
23
+ set_all_random_seed,
24
+ ValueWindow,
25
+ )
26
+
27
+
28
+ class BaseTrainer(object):
29
+ def __init__(self, args, cfg):
30
+ self.args = args
31
+ self.log_dir = args.log_dir
32
+ self.cfg = cfg
33
+
34
+ self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
35
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
36
+ if not cfg.train.ddp or args.local_rank == 0:
37
+ self.sw = SummaryWriter(os.path.join(args.log_dir, "events"))
38
+ self.logger = self.build_logger()
39
+ self.time_window = ValueWindow(50)
40
+
41
+ self.step = 0
42
+ self.epoch = -1
43
+ self.max_epochs = self.cfg.train.epochs
44
+ self.max_steps = self.cfg.train.max_steps
45
+
46
+ # set random seed & init distributed training
47
+ set_all_random_seed(self.cfg.train.random_seed)
48
+ if cfg.train.ddp:
49
+ dist.init_process_group(backend="nccl")
50
+
51
+ if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]:
52
+ self.singers = self.build_singers_lut()
53
+
54
+ # setup data_loader
55
+ self.data_loader = self.build_data_loader()
56
+
57
+ # setup model & enable distributed training
58
+ self.model = self.build_model()
59
+ print(self.model)
60
+
61
+ if isinstance(self.model, dict):
62
+ for key, value in self.model.items():
63
+ value.cuda(self.args.local_rank)
64
+ if key == "PQMF":
65
+ continue
66
+ if cfg.train.ddp:
67
+ self.model[key] = DistributedDataParallel(
68
+ value, device_ids=[self.args.local_rank]
69
+ )
70
+ else:
71
+ self.model.cuda(self.args.local_rank)
72
+ if cfg.train.ddp:
73
+ self.model = DistributedDataParallel(
74
+ self.model, device_ids=[self.args.local_rank]
75
+ )
76
+
77
+ # create criterion
78
+ self.criterion = self.build_criterion()
79
+ if isinstance(self.criterion, dict):
80
+ for key, value in self.criterion.items():
81
+ self.criterion[key].cuda(args.local_rank)
82
+ else:
83
+ self.criterion.cuda(self.args.local_rank)
84
+
85
+ # optimizer
86
+ self.optimizer = self.build_optimizer()
87
+ self.scheduler = self.build_scheduler()
88
+
89
+ # save config file
90
+ self.config_save_path = os.path.join(self.checkpoint_dir, "args.json")
91
+
92
+ def build_logger(self):
93
+ log_file = os.path.join(self.checkpoint_dir, "train.log")
94
+ logger = Logger(log_file, level=self.args.log_level).logger
95
+
96
+ return logger
97
+
98
+ def build_dataset(self):
99
+ raise NotImplementedError
100
+
101
+ def build_data_loader(self):
102
+ Dataset, Collator = self.build_dataset()
103
+ # build dataset instance for each dataset and combine them by ConcatDataset
104
+ datasets_list = []
105
+ for dataset in self.cfg.dataset:
106
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
107
+ datasets_list.append(subdataset)
108
+ train_dataset = ConcatDataset(datasets_list)
109
+
110
+ train_collate = Collator(self.cfg)
111
+ # TODO: multi-GPU training
112
+ if self.cfg.train.ddp:
113
+ raise NotImplementedError("DDP is not supported yet.")
114
+
115
+ # sampler will provide indices to batch_sampler, which will perform batching and yield batch indices
116
+ batch_sampler = BatchSampler(
117
+ cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list
118
+ )
119
+
120
+ # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
121
+ train_loader = DataLoader(
122
+ train_dataset,
123
+ collate_fn=train_collate,
124
+ num_workers=self.args.num_workers,
125
+ batch_sampler=batch_sampler,
126
+ pin_memory=False,
127
+ )
128
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
129
+ datasets_list = []
130
+ for dataset in self.cfg.dataset:
131
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
132
+ datasets_list.append(subdataset)
133
+ valid_dataset = ConcatDataset(datasets_list)
134
+ valid_collate = Collator(self.cfg)
135
+ batch_sampler = BatchSampler(
136
+ cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list
137
+ )
138
+ valid_loader = DataLoader(
139
+ valid_dataset,
140
+ collate_fn=valid_collate,
141
+ num_workers=1,
142
+ batch_sampler=batch_sampler,
143
+ )
144
+ else:
145
+ raise NotImplementedError("DDP is not supported yet.")
146
+ # valid_loader = None
147
+ data_loader = {"train": train_loader, "valid": valid_loader}
148
+ return data_loader
149
+
150
+ def build_singers_lut(self):
151
+ # combine singers
152
+ if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)):
153
+ singers = collections.OrderedDict()
154
+ else:
155
+ with open(
156
+ os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r"
157
+ ) as singer_file:
158
+ singers = json.load(singer_file)
159
+ singer_count = len(singers)
160
+ for dataset in self.cfg.dataset:
161
+ singer_lut_path = os.path.join(
162
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
163
+ )
164
+ with open(singer_lut_path, "r") as singer_lut_path:
165
+ singer_lut = json.load(singer_lut_path)
166
+ for singer in singer_lut.keys():
167
+ if singer not in singers:
168
+ singers[singer] = singer_count
169
+ singer_count += 1
170
+ with open(
171
+ os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w"
172
+ ) as singer_file:
173
+ json.dump(singers, singer_file, indent=4, ensure_ascii=False)
174
+ print(
175
+ "singers have been dumped to {}".format(
176
+ os.path.join(self.log_dir, self.cfg.preprocess.spk2id)
177
+ )
178
+ )
179
+ return singers
180
+
181
+ def build_model(self):
182
+ raise NotImplementedError()
183
+
184
+ def build_optimizer(self):
185
+ raise NotImplementedError
186
+
187
+ def build_scheduler(self):
188
+ raise NotImplementedError()
189
+
190
+ def build_criterion(self):
191
+ raise NotImplementedError
192
+
193
+ def get_state_dict(self):
194
+ raise NotImplementedError
195
+
196
+ def save_config_file(self):
197
+ save_config(self.config_save_path, self.cfg)
198
+
199
+ # TODO, save without module.
200
+ def save_checkpoint(self, state_dict, saved_model_path):
201
+ torch.save(state_dict, saved_model_path)
202
+
203
+ def load_checkpoint(self):
204
+ checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint")
205
+ assert os.path.exists(checkpoint_path)
206
+ checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
207
+ model_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
208
+ assert os.path.exists(model_path)
209
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
210
+ self.logger.info(f"Re(store) from {model_path}")
211
+ checkpoint = torch.load(model_path, map_location="cpu")
212
+ return checkpoint
213
+
214
+ def load_model(self, checkpoint):
215
+ raise NotImplementedError
216
+
217
+ def restore(self):
218
+ checkpoint = self.load_checkpoint()
219
+ self.load_model(checkpoint)
220
+
221
+ def train_step(self, data):
222
+ raise NotImplementedError(
223
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
224
+ f"your sub-class of {self.__class__.__name__}. "
225
+ )
226
+
227
+ @torch.no_grad()
228
+ def eval_step(self):
229
+ raise NotImplementedError(
230
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
231
+ f"your sub-class of {self.__class__.__name__}. "
232
+ )
233
+
234
+ def write_summary(self, losses, stats):
235
+ raise NotImplementedError(
236
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
237
+ f"your sub-class of {self.__class__.__name__}. "
238
+ )
239
+
240
+ def write_valid_summary(self, losses, stats):
241
+ raise NotImplementedError(
242
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
243
+ f"your sub-class of {self.__class__.__name__}. "
244
+ )
245
+
246
+ def echo_log(self, losses, mode="Training"):
247
+ message = [
248
+ "{} - Epoch {} Step {}: [{:.3f} s/step]".format(
249
+ mode, self.epoch + 1, self.step, self.time_window.average
250
+ )
251
+ ]
252
+
253
+ for key in sorted(losses.keys()):
254
+ if isinstance(losses[key], dict):
255
+ for k, v in losses[key].items():
256
+ message.append(
257
+ str(k).split("/")[-1] + "=" + str(round(float(v), 5))
258
+ )
259
+ else:
260
+ message.append(
261
+ str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
262
+ )
263
+ self.logger.info(", ".join(message))
264
+
265
+ def eval_epoch(self):
266
+ self.logger.info("Validation...")
267
+ valid_losses = {}
268
+ for i, batch_data in enumerate(self.data_loader["valid"]):
269
+ for k, v in batch_data.items():
270
+ if isinstance(v, torch.Tensor):
271
+ batch_data[k] = v.cuda()
272
+ valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i)
273
+ for key in valid_loss:
274
+ if key not in valid_losses:
275
+ valid_losses[key] = 0
276
+ valid_losses[key] += valid_loss[key]
277
+
278
+ # Add mel and audio to the Tensorboard
279
+ # Average loss
280
+ for key in valid_losses:
281
+ valid_losses[key] /= i + 1
282
+ self.echo_log(valid_losses, "Valid")
283
+ return valid_losses, valid_stats
284
+
285
+ def train_epoch(self):
286
+ for i, batch_data in enumerate(self.data_loader["train"]):
287
+ start_time = time.time()
288
+ # Put the data to cuda device
289
+ for k, v in batch_data.items():
290
+ if isinstance(v, torch.Tensor):
291
+ batch_data[k] = v.cuda(self.args.local_rank)
292
+
293
+ # Training step
294
+ train_losses, train_stats, total_loss = self.train_step(batch_data)
295
+ self.time_window.append(time.time() - start_time)
296
+
297
+ if self.args.local_rank == 0 or not self.cfg.train.ddp:
298
+ if self.step % self.args.stdout_interval == 0:
299
+ self.echo_log(train_losses, "Training")
300
+
301
+ if self.step % self.cfg.train.save_summary_steps == 0:
302
+ self.logger.info(f"Save summary as step {self.step}")
303
+ self.write_summary(train_losses, train_stats)
304
+
305
+ if (
306
+ self.step % self.cfg.train.save_checkpoints_steps == 0
307
+ and self.step != 0
308
+ ):
309
+ saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format(
310
+ self.step, total_loss
311
+ )
312
+ saved_model_path = os.path.join(
313
+ self.checkpoint_dir, saved_model_name
314
+ )
315
+ saved_state_dict = self.get_state_dict()
316
+ self.save_checkpoint(saved_state_dict, saved_model_path)
317
+ self.save_config_file()
318
+ # keep max n models
319
+ remove_older_ckpt(
320
+ saved_model_name,
321
+ self.checkpoint_dir,
322
+ max_to_keep=self.cfg.train.keep_checkpoint_max,
323
+ )
324
+
325
+ if self.step != 0 and self.step % self.cfg.train.valid_interval == 0:
326
+ if isinstance(self.model, dict):
327
+ for key in self.model.keys():
328
+ self.model[key].eval()
329
+ else:
330
+ self.model.eval()
331
+ # Evaluate one epoch and get average loss
332
+ valid_losses, valid_stats = self.eval_epoch()
333
+ if isinstance(self.model, dict):
334
+ for key in self.model.keys():
335
+ self.model[key].train()
336
+ else:
337
+ self.model.train()
338
+ # Write validation losses to summary.
339
+ self.write_valid_summary(valid_losses, valid_stats)
340
+ self.step += 1
341
+
342
+ def train(self):
343
+ for epoch in range(max(0, self.epoch), self.max_epochs):
344
+ self.train_epoch()
345
+ self.epoch += 1
346
+ if self.step > self.max_steps:
347
+ self.logger.info("Training finished!")
348
+ break
models/base/new_dataset.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ from abc import abstractmethod
9
+ from pathlib import Path
10
+
11
+ import json5
12
+ import torch
13
+ import yaml
14
+
15
+
16
+ # TODO: for training and validating
17
+ class BaseDataset(torch.utils.data.Dataset):
18
+ r"""Base dataset for training and validating."""
19
+
20
+ def __init__(self, args, cfg, is_valid=False):
21
+ pass
22
+
23
+
24
+ class BaseTestDataset(torch.utils.data.Dataset):
25
+ r"""Test dataset for inference."""
26
+
27
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
28
+ assert infer_type in ["from_dataset", "from_file"]
29
+
30
+ self.args = args
31
+ self.cfg = cfg
32
+ self.infer_type = infer_type
33
+
34
+ @abstractmethod
35
+ def __getitem__(self, index):
36
+ pass
37
+
38
+ def __len__(self):
39
+ return len(self.metadata)
40
+
41
+ def get_metadata(self):
42
+ path = Path(self.args.source)
43
+ if path.suffix == ".json" or path.suffix == ".jsonc":
44
+ metadata = json5.load(open(self.args.source, "r"))
45
+ elif path.suffix == ".yaml" or path.suffix == ".yml":
46
+ metadata = yaml.full_load(open(self.args.source, "r"))
47
+ else:
48
+ raise ValueError(f"Unsupported file type: {path.suffix}")
49
+
50
+ return metadata
models/base/new_inference.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import random
8
+ import re
9
+ import time
10
+ from abc import abstractmethod
11
+ from pathlib import Path
12
+
13
+ import accelerate
14
+ import json5
15
+ import numpy as np
16
+ import torch
17
+ from accelerate.logging import get_logger
18
+ from torch.utils.data import DataLoader
19
+
20
+ from models.vocoders.vocoder_inference import synthesis
21
+ from utils.io import save_audio
22
+ from utils.util import load_config
23
+ from utils.audio_slicer import is_silence
24
+
25
+ EPS = 1.0e-12
26
+
27
+
28
+ class BaseInference(object):
29
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
30
+ super().__init__()
31
+
32
+ start = time.monotonic_ns()
33
+ self.args = args
34
+ self.cfg = cfg
35
+
36
+ assert infer_type in ["from_dataset", "from_file"]
37
+ self.infer_type = infer_type
38
+
39
+ # init with accelerate
40
+ self.accelerator = accelerate.Accelerator()
41
+ self.accelerator.wait_for_everyone()
42
+
43
+ # Use accelerate logger for distributed inference
44
+ with self.accelerator.main_process_first():
45
+ self.logger = get_logger("inference", log_level=args.log_level)
46
+
47
+ # Log some info
48
+ self.logger.info("=" * 56)
49
+ self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
50
+ self.logger.info("=" * 56)
51
+ self.logger.info("\n")
52
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
53
+
54
+ self.acoustics_dir = args.acoustics_dir
55
+ self.logger.debug(f"Acoustic dir: {args.acoustics_dir}")
56
+ self.vocoder_dir = args.vocoder_dir
57
+ self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
58
+ # should be in svc inferencer
59
+ # self.target_singer = args.target_singer
60
+ # self.logger.info(f"Target singers: {args.target_singer}")
61
+ # self.trans_key = args.trans_key
62
+ # self.logger.info(f"Trans key: {args.trans_key}")
63
+
64
+ os.makedirs(args.output_dir, exist_ok=True)
65
+
66
+ # set random seed
67
+ with self.accelerator.main_process_first():
68
+ start = time.monotonic_ns()
69
+ self._set_random_seed(self.cfg.train.random_seed)
70
+ end = time.monotonic_ns()
71
+ self.logger.debug(
72
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
73
+ )
74
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
75
+
76
+ # setup data_loader
77
+ with self.accelerator.main_process_first():
78
+ self.logger.info("Building dataset...")
79
+ start = time.monotonic_ns()
80
+ self.test_dataloader = self._build_dataloader()
81
+ end = time.monotonic_ns()
82
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
83
+
84
+ # setup model
85
+ with self.accelerator.main_process_first():
86
+ self.logger.info("Building model...")
87
+ start = time.monotonic_ns()
88
+ self.model = self._build_model()
89
+ end = time.monotonic_ns()
90
+ # self.logger.debug(self.model)
91
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
92
+
93
+ # init with accelerate
94
+ self.logger.info("Initializing accelerate...")
95
+ start = time.monotonic_ns()
96
+ self.accelerator = accelerate.Accelerator()
97
+ self.model = self.accelerator.prepare(self.model)
98
+ end = time.monotonic_ns()
99
+ self.accelerator.wait_for_everyone()
100
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
101
+
102
+ with self.accelerator.main_process_first():
103
+ self.logger.info("Loading checkpoint...")
104
+ start = time.monotonic_ns()
105
+ # TODO: Also, suppose only use latest one yet
106
+ self.__load_model(os.path.join(args.acoustics_dir, "checkpoint"))
107
+ end = time.monotonic_ns()
108
+ self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
109
+
110
+ self.model.eval()
111
+ self.accelerator.wait_for_everyone()
112
+
113
+ ### Abstract methods ###
114
+ @abstractmethod
115
+ def _build_test_dataset(self):
116
+ pass
117
+
118
+ @abstractmethod
119
+ def _build_model(self):
120
+ pass
121
+
122
+ @abstractmethod
123
+ @torch.inference_mode()
124
+ def _inference_each_batch(self, batch_data):
125
+ pass
126
+
127
+ ### Abstract methods end ###
128
+
129
+ @torch.inference_mode()
130
+ def inference(self):
131
+ for i, batch in enumerate(self.test_dataloader):
132
+ y_pred = self._inference_each_batch(batch).cpu()
133
+
134
+ # Judge whether the min-max normliazation is used
135
+ if self.cfg.preprocess.use_min_max_norm_mel:
136
+ mel_min, mel_max = self.test_dataset.target_mel_extrema
137
+ y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min
138
+
139
+ y_ls = y_pred.chunk(self.test_batch_size)
140
+ tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
141
+ j = 0
142
+ for it, l in zip(y_ls, tgt_ls):
143
+ l = l.item()
144
+ it = it.squeeze(0)[:l]
145
+ uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
146
+ torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
147
+ j += 1
148
+
149
+ vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
150
+
151
+ res = synthesis(
152
+ cfg=vocoder_cfg,
153
+ vocoder_weight_file=vocoder_ckpt,
154
+ n_samples=None,
155
+ pred=[
156
+ torch.load(
157
+ os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"]))
158
+ ).numpy(force=True)
159
+ for i in self.test_dataset.metadata
160
+ ],
161
+ )
162
+
163
+ output_audio_files = []
164
+ for it, wav in zip(self.test_dataset.metadata, res):
165
+ uid = it["Uid"]
166
+ file = os.path.join(self.args.output_dir, f"{uid}.wav")
167
+ output_audio_files.append(file)
168
+
169
+ wav = wav.numpy(force=True)
170
+ save_audio(
171
+ file,
172
+ wav,
173
+ self.cfg.preprocess.sample_rate,
174
+ add_silence=False,
175
+ turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
176
+ )
177
+ os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))
178
+
179
+ return sorted(output_audio_files)
180
+
181
+ # TODO: LEGACY CODE
182
+ def _build_dataloader(self):
183
+ datasets, collate = self._build_test_dataset()
184
+ self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
185
+ self.test_collate = collate(self.cfg)
186
+ self.test_batch_size = min(
187
+ self.cfg.train.batch_size, len(self.test_dataset.metadata)
188
+ )
189
+ test_dataloader = DataLoader(
190
+ self.test_dataset,
191
+ collate_fn=self.test_collate,
192
+ num_workers=1,
193
+ batch_size=self.test_batch_size,
194
+ shuffle=False,
195
+ )
196
+ return test_dataloader
197
+
198
+ def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None):
199
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
200
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
201
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
202
+ method after** ``accelerator.prepare()``.
203
+ """
204
+ if checkpoint_path is None:
205
+ ls = []
206
+ for i in Path(checkpoint_dir).iterdir():
207
+ if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)):
208
+ ls.append(i)
209
+ ls.sort(
210
+ key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True
211
+ )
212
+ checkpoint_path = ls[0]
213
+ else:
214
+ checkpoint_path = Path(checkpoint_path)
215
+ self.accelerator.load_state(str(checkpoint_path))
216
+ # set epoch and step
217
+ self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1])
218
+ self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1])
219
+ return str(checkpoint_path)
220
+
221
+ @staticmethod
222
+ def _set_random_seed(seed):
223
+ r"""Set random seed for all possible random modules."""
224
+ random.seed(seed)
225
+ np.random.seed(seed)
226
+ torch.random.manual_seed(seed)
227
+
228
+ @staticmethod
229
+ def _parse_vocoder(vocoder_dir):
230
+ r"""Parse vocoder config"""
231
+ vocoder_dir = os.path.abspath(vocoder_dir)
232
+ ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
233
+ ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
234
+ ckpt_path = str(ckpt_list[0])
235
+ vocoder_cfg = load_config(
236
+ os.path.join(vocoder_dir, "args.json"), lowercase=True
237
+ )
238
+ return vocoder_cfg, ckpt_path
239
+
240
+ @staticmethod
241
+ def __count_parameters(model):
242
+ return sum(p.numel() for p in model.parameters())
243
+
244
+ def __dump_cfg(self, path):
245
+ os.makedirs(os.path.dirname(path), exist_ok=True)
246
+ json5.dump(
247
+ self.cfg,
248
+ open(path, "w"),
249
+ indent=4,
250
+ sort_keys=True,
251
+ ensure_ascii=False,
252
+ quote_keys=True,
253
+ )
models/base/new_trainer.py ADDED
@@ -0,0 +1,727 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ import random
9
+ import shutil
10
+ import time
11
+ from abc import abstractmethod
12
+ from pathlib import Path
13
+
14
+ import accelerate
15
+ import json5
16
+ import numpy as np
17
+ import torch
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import ProjectConfiguration
20
+ from torch.utils.data import ConcatDataset, DataLoader
21
+ from tqdm import tqdm
22
+
23
+ from models.base.base_sampler import build_samplers
24
+ from optimizer.optimizers import NoamLR
25
+
26
+
27
+ class BaseTrainer(object):
28
+ r"""The base trainer for all tasks. Any trainer should inherit from this class."""
29
+
30
+ def __init__(self, args=None, cfg=None):
31
+ super().__init__()
32
+
33
+ self.args = args
34
+ self.cfg = cfg
35
+
36
+ cfg.exp_name = args.exp_name
37
+
38
+ # init with accelerate
39
+ self._init_accelerator()
40
+ self.accelerator.wait_for_everyone()
41
+
42
+ # Use accelerate logger for distributed training
43
+ with self.accelerator.main_process_first():
44
+ self.logger = get_logger(args.exp_name, log_level=args.log_level)
45
+
46
+ # Log some info
47
+ self.logger.info("=" * 56)
48
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
49
+ self.logger.info("=" * 56)
50
+ self.logger.info("\n")
51
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
52
+ self.logger.info(f"Experiment name: {args.exp_name}")
53
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
54
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
55
+ if self.accelerator.is_main_process:
56
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
57
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
58
+
59
+ # init counts
60
+ self.batch_count: int = 0
61
+ self.step: int = 0
62
+ self.epoch: int = 0
63
+ self.max_epoch = (
64
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
65
+ )
66
+ self.logger.info(
67
+ "Max epoch: {}".format(
68
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
69
+ )
70
+ )
71
+
72
+ # Check values
73
+ if self.accelerator.is_main_process:
74
+ self.__check_basic_configs()
75
+ # Set runtime configs
76
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
77
+ self.checkpoints_path = [
78
+ [] for _ in range(len(self.save_checkpoint_stride))
79
+ ]
80
+ self.keep_last = [
81
+ i if i > 0 else float("inf") for i in self.cfg.train.keep_last
82
+ ]
83
+ self.run_eval = self.cfg.train.run_eval
84
+
85
+ # set random seed
86
+ with self.accelerator.main_process_first():
87
+ start = time.monotonic_ns()
88
+ self._set_random_seed(self.cfg.train.random_seed)
89
+ end = time.monotonic_ns()
90
+ self.logger.debug(
91
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
92
+ )
93
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
94
+
95
+ # setup data_loader
96
+ with self.accelerator.main_process_first():
97
+ self.logger.info("Building dataset...")
98
+ start = time.monotonic_ns()
99
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
100
+ end = time.monotonic_ns()
101
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
102
+
103
+ # setup model
104
+ with self.accelerator.main_process_first():
105
+ self.logger.info("Building model...")
106
+ start = time.monotonic_ns()
107
+ self.model = self._build_model()
108
+ end = time.monotonic_ns()
109
+ self.logger.debug(self.model)
110
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
111
+ self.logger.info(
112
+ f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
113
+ )
114
+ # optimizer & scheduler
115
+ with self.accelerator.main_process_first():
116
+ self.logger.info("Building optimizer and scheduler...")
117
+ start = time.monotonic_ns()
118
+ self.optimizer = self._build_optimizer()
119
+ self.scheduler = self._build_scheduler()
120
+ end = time.monotonic_ns()
121
+ self.logger.info(
122
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
123
+ )
124
+
125
+ # accelerate prepare
126
+ self.logger.info("Initializing accelerate...")
127
+ start = time.monotonic_ns()
128
+ self._accelerator_prepare()
129
+ end = time.monotonic_ns()
130
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
131
+
132
+ # create criterion
133
+ with self.accelerator.main_process_first():
134
+ self.logger.info("Building criterion...")
135
+ start = time.monotonic_ns()
136
+ self.criterion = self._build_criterion()
137
+ end = time.monotonic_ns()
138
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
139
+
140
+ # Resume or Finetune
141
+ with self.accelerator.main_process_first():
142
+ if args.resume:
143
+ if args.resume_from_ckpt_path == "":
144
+ ## Automatically resume according to the current exprimental name
145
+ self.logger.info(
146
+ "Automatically resuming from latest checkpoint in {}...".format(
147
+ self.checkpoint_dir
148
+ )
149
+ )
150
+ start = time.monotonic_ns()
151
+ ckpt_path = self._load_model(
152
+ checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type
153
+ )
154
+ end = time.monotonic_ns()
155
+ self.logger.info(
156
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
157
+ )
158
+ self.checkpoints_path = json.load(
159
+ open(os.path.join(ckpt_path, "ckpts.json"), "r")
160
+ )
161
+ else:
162
+ ## Resume from the given checkpoint path
163
+ if not os.path.exists(args.resume_from_ckpt_path):
164
+ raise ValueError(
165
+ "[Error] The resumed checkpoint path {} don't exist.".format(
166
+ args.resume_from_ckpt_path
167
+ )
168
+ )
169
+ self.logger.info(
170
+ "Resuming from {}...".format(args.resume_from_ckpt_path)
171
+ )
172
+ start = time.monotonic_ns()
173
+ ckpt_path = self._load_model(
174
+ checkpoint_path=args.resume_from_ckpt_path,
175
+ resume_type=args.resume_type,
176
+ )
177
+ end = time.monotonic_ns()
178
+ self.logger.info(
179
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
180
+ )
181
+
182
+ # save config file path
183
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
184
+
185
+ def _accelerator_prepare(self):
186
+ (
187
+ self.train_dataloader,
188
+ self.valid_dataloader,
189
+ self.model,
190
+ self.optimizer,
191
+ self.scheduler,
192
+ ) = self.accelerator.prepare(
193
+ self.train_dataloader,
194
+ self.valid_dataloader,
195
+ self.model,
196
+ self.optimizer,
197
+ self.scheduler,
198
+ )
199
+
200
+ ### Following are abstract methods that should be implemented in child classes ###
201
+ @abstractmethod
202
+ def _build_dataset(self):
203
+ r"""Build dataset for model training/validating/evaluating."""
204
+ pass
205
+
206
+ @staticmethod
207
+ @abstractmethod
208
+ def _build_criterion():
209
+ r"""Build criterion function for model loss calculation."""
210
+ pass
211
+
212
+ @abstractmethod
213
+ def _build_model(self):
214
+ r"""Build model for training/validating/evaluating."""
215
+ pass
216
+
217
+ @abstractmethod
218
+ def _forward_step(self, batch):
219
+ r"""One forward step of the neural network. This abstract method is trying to
220
+ unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation.
221
+ However, for special case that using different forward step pattern for
222
+ training and validating, you could just override this method with ``pass`` and
223
+ implement ``_train_step`` and ``_valid_step`` separately.
224
+ """
225
+ pass
226
+
227
+ @abstractmethod
228
+ def _save_auxiliary_states(self):
229
+ r"""To save some auxiliary states when saving model's ckpt"""
230
+ pass
231
+
232
+ ### Abstract methods end ###
233
+
234
+ ### THIS IS MAIN ENTRY ###
235
+ def train_loop(self):
236
+ r"""Training loop. The public entry of training process."""
237
+ # Wait everyone to prepare before we move on
238
+ self.accelerator.wait_for_everyone()
239
+ # dump config file
240
+ if self.accelerator.is_main_process:
241
+ self.__dump_cfg(self.config_save_path)
242
+ self.model.train()
243
+ self.optimizer.zero_grad()
244
+ # Wait to ensure good to go
245
+ self.accelerator.wait_for_everyone()
246
+ while self.epoch < self.max_epoch:
247
+ self.logger.info("\n")
248
+ self.logger.info("-" * 32)
249
+ self.logger.info("Epoch {}: ".format(self.epoch))
250
+
251
+ ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
252
+ ### It's inconvenient for the model with multiple losses
253
+ # Do training & validating epoch
254
+ train_loss = self._train_epoch()
255
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
256
+ valid_loss = self._valid_epoch()
257
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
258
+ self.accelerator.log(
259
+ {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
260
+ step=self.epoch,
261
+ )
262
+
263
+ self.accelerator.wait_for_everyone()
264
+ # TODO: what is scheduler?
265
+ self.scheduler.step(valid_loss) # FIXME: use epoch track correct?
266
+
267
+ # Check if hit save_checkpoint_stride and run_eval
268
+ run_eval = False
269
+ if self.accelerator.is_main_process:
270
+ save_checkpoint = False
271
+ hit_dix = []
272
+ for i, num in enumerate(self.save_checkpoint_stride):
273
+ if self.epoch % num == 0:
274
+ save_checkpoint = True
275
+ hit_dix.append(i)
276
+ run_eval |= self.run_eval[i]
277
+
278
+ self.accelerator.wait_for_everyone()
279
+ if self.accelerator.is_main_process and save_checkpoint:
280
+ path = os.path.join(
281
+ self.checkpoint_dir,
282
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
283
+ self.epoch, self.step, train_loss
284
+ ),
285
+ )
286
+ self.tmp_checkpoint_save_path = path
287
+ self.accelerator.save_state(path)
288
+ print(f"save checkpoint in {path}")
289
+ json.dump(
290
+ self.checkpoints_path,
291
+ open(os.path.join(path, "ckpts.json"), "w"),
292
+ ensure_ascii=False,
293
+ indent=4,
294
+ )
295
+ self._save_auxiliary_states()
296
+
297
+ # Remove old checkpoints
298
+ to_remove = []
299
+ for idx in hit_dix:
300
+ self.checkpoints_path[idx].append(path)
301
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
302
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
303
+
304
+ # Search conflicts
305
+ total = set()
306
+ for i in self.checkpoints_path:
307
+ total |= set(i)
308
+ do_remove = set()
309
+ for idx, path in to_remove[::-1]:
310
+ if path in total:
311
+ self.checkpoints_path[idx].insert(0, path)
312
+ else:
313
+ do_remove.add(path)
314
+
315
+ # Remove old checkpoints
316
+ for path in do_remove:
317
+ shutil.rmtree(path, ignore_errors=True)
318
+ self.logger.debug(f"Remove old checkpoint: {path}")
319
+
320
+ self.accelerator.wait_for_everyone()
321
+ if run_eval:
322
+ # TODO: run evaluation
323
+ pass
324
+
325
+ # Update info for each epoch
326
+ self.epoch += 1
327
+
328
+ # Finish training and save final checkpoint
329
+ self.accelerator.wait_for_everyone()
330
+ if self.accelerator.is_main_process:
331
+ self.accelerator.save_state(
332
+ os.path.join(
333
+ self.checkpoint_dir,
334
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
335
+ self.epoch, self.step, valid_loss
336
+ ),
337
+ )
338
+ )
339
+ self._save_auxiliary_states()
340
+
341
+ self.accelerator.end_training()
342
+
343
+ ### Following are methods that can be used directly in child classes ###
344
+ def _train_epoch(self):
345
+ r"""Training epoch. Should return average loss of a batch (sample) over
346
+ one epoch. See ``train_loop`` for usage.
347
+ """
348
+ self.model.train()
349
+ epoch_sum_loss: float = 0.0
350
+ epoch_step: int = 0
351
+ for batch in tqdm(
352
+ self.train_dataloader,
353
+ desc=f"Training Epoch {self.epoch}",
354
+ unit="batch",
355
+ colour="GREEN",
356
+ leave=False,
357
+ dynamic_ncols=True,
358
+ smoothing=0.04,
359
+ disable=not self.accelerator.is_main_process,
360
+ ):
361
+ # Do training step and BP
362
+ with self.accelerator.accumulate(self.model):
363
+ loss = self._train_step(batch)
364
+ self.accelerator.backward(loss)
365
+ self.optimizer.step()
366
+ self.optimizer.zero_grad()
367
+ self.batch_count += 1
368
+
369
+ # Update info for each step
370
+ # TODO: step means BP counts or batch counts?
371
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
372
+ epoch_sum_loss += loss
373
+ self.accelerator.log(
374
+ {
375
+ "Step/Train Loss": loss,
376
+ "Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
377
+ },
378
+ step=self.step,
379
+ )
380
+ self.step += 1
381
+ epoch_step += 1
382
+
383
+ self.accelerator.wait_for_everyone()
384
+ return (
385
+ epoch_sum_loss
386
+ / len(self.train_dataloader)
387
+ * self.cfg.train.gradient_accumulation_step
388
+ )
389
+
390
+ @torch.inference_mode()
391
+ def _valid_epoch(self):
392
+ r"""Testing epoch. Should return average loss of a batch (sample) over
393
+ one epoch. See ``train_loop`` for usage.
394
+ """
395
+ self.model.eval()
396
+ epoch_sum_loss = 0.0
397
+ for batch in tqdm(
398
+ self.valid_dataloader,
399
+ desc=f"Validating Epoch {self.epoch}",
400
+ unit="batch",
401
+ colour="GREEN",
402
+ leave=False,
403
+ dynamic_ncols=True,
404
+ smoothing=0.04,
405
+ disable=not self.accelerator.is_main_process,
406
+ ):
407
+ batch_loss = self._valid_step(batch)
408
+ epoch_sum_loss += batch_loss.item()
409
+
410
+ self.accelerator.wait_for_everyone()
411
+ return epoch_sum_loss / len(self.valid_dataloader)
412
+
413
+ def _train_step(self, batch):
414
+ r"""Training forward step. Should return average loss of a sample over
415
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
416
+ See ``_train_epoch`` for usage.
417
+ """
418
+ return self._forward_step(batch)
419
+
420
+ @torch.inference_mode()
421
+ def _valid_step(self, batch):
422
+ r"""Testing forward step. Should return average loss of a sample over
423
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
424
+ See ``_test_epoch`` for usage.
425
+ """
426
+ return self._forward_step(batch)
427
+
428
+ def _load_model(
429
+ self,
430
+ checkpoint_dir: str = None,
431
+ checkpoint_path: str = None,
432
+ resume_type: str = "",
433
+ ):
434
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
435
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
436
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
437
+ method after** ``accelerator.prepare()``.
438
+ """
439
+ if checkpoint_path is None:
440
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
441
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
442
+ checkpoint_path = ls[0]
443
+ self.logger.info("Resume from {}...".format(checkpoint_path))
444
+
445
+ if resume_type in ["resume", ""]:
446
+ # Load all the things, including model weights, optimizer, scheduler, and random states.
447
+ self.accelerator.load_state(input_dir=checkpoint_path)
448
+
449
+ # set epoch and step
450
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
451
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
452
+
453
+ elif resume_type == "finetune":
454
+ # Load only the model weights
455
+ accelerate.load_checkpoint_and_dispatch(
456
+ self.accelerator.unwrap_model(self.model),
457
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
458
+ )
459
+ self.logger.info("Load model weights for finetune...")
460
+
461
+ else:
462
+ raise ValueError("Resume_type must be `resume` or `finetune`.")
463
+
464
+ return checkpoint_path
465
+
466
+ def _build_dataloader(self):
467
+ Dataset, Collator = self._build_dataset()
468
+
469
+ # build dataset instance for each dataset and combine them by ConcatDataset
470
+ datasets_list = []
471
+ for dataset in self.cfg.dataset:
472
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
473
+ datasets_list.append(subdataset)
474
+ train_dataset = ConcatDataset(datasets_list)
475
+ train_collate = Collator(self.cfg)
476
+ _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
477
+ self.logger.debug(f"train batch_sampler: {list(batch_sampler)}")
478
+ self.logger.debug(f"length: {train_dataset.cumulative_sizes}")
479
+ # TODO: use config instead of (sampler, shuffle, drop_last, batch_size)
480
+ train_loader = DataLoader(
481
+ train_dataset,
482
+ # shuffle=True,
483
+ collate_fn=train_collate,
484
+ batch_sampler=batch_sampler,
485
+ num_workers=self.cfg.train.dataloader.num_worker,
486
+ pin_memory=self.cfg.train.dataloader.pin_memory,
487
+ )
488
+
489
+ # Build valid dataloader
490
+ datasets_list = []
491
+ for dataset in self.cfg.dataset:
492
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
493
+ datasets_list.append(subdataset)
494
+ valid_dataset = ConcatDataset(datasets_list)
495
+ valid_collate = Collator(self.cfg)
496
+ _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
497
+ self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}")
498
+ self.logger.debug(f"length: {valid_dataset.cumulative_sizes}")
499
+ valid_loader = DataLoader(
500
+ valid_dataset,
501
+ collate_fn=valid_collate,
502
+ batch_sampler=batch_sampler,
503
+ num_workers=self.cfg.train.dataloader.num_worker,
504
+ pin_memory=self.cfg.train.dataloader.pin_memory,
505
+ )
506
+ return train_loader, valid_loader
507
+
508
+ @staticmethod
509
+ def _set_random_seed(seed):
510
+ r"""Set random seed for all possible random modules."""
511
+ random.seed(seed)
512
+ np.random.seed(seed)
513
+ torch.random.manual_seed(seed)
514
+
515
+ def _check_nan(self, loss, y_pred, y_gt):
516
+ if torch.any(torch.isnan(loss)):
517
+ self.logger.error("Fatal Error: Training is down since loss has Nan!")
518
+ self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
519
+
520
+ ### y_pred ###
521
+ if torch.any(torch.isnan(y_pred)):
522
+ self.logger.error(
523
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
524
+ )
525
+ self.logger.error(f"y_pred: {y_pred}", in_order=True)
526
+ else:
527
+ self.logger.debug(
528
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
529
+ )
530
+ self.logger.debug(f"y_pred: {y_pred}", in_order=True)
531
+
532
+ ### y_gt ###
533
+ if torch.any(torch.isnan(y_gt)):
534
+ self.logger.error(
535
+ f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True
536
+ )
537
+ self.logger.error(f"y_gt: {y_gt}", in_order=True)
538
+ else:
539
+ self.logger.debug(
540
+ f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True
541
+ )
542
+ self.logger.debug(f"y_gt: {y_gt}", in_order=True)
543
+
544
+ self.accelerator.end_training()
545
+ raise RuntimeError("Loss has Nan! See log for more info.")
546
+
547
+ ### Protected methods end ###
548
+
549
+ ## Following are private methods ##
550
+ def _build_optimizer(self):
551
+ r"""Build optimizer for model."""
552
+ # Make case-insensitive matching
553
+ if self.cfg.train.optimizer.lower() == "adadelta":
554
+ optimizer = torch.optim.Adadelta(
555
+ self.model.parameters(), **self.cfg.train.adadelta
556
+ )
557
+ self.logger.info("Using Adadelta optimizer.")
558
+ elif self.cfg.train.optimizer.lower() == "adagrad":
559
+ optimizer = torch.optim.Adagrad(
560
+ self.model.parameters(), **self.cfg.train.adagrad
561
+ )
562
+ self.logger.info("Using Adagrad optimizer.")
563
+ elif self.cfg.train.optimizer.lower() == "adam":
564
+ optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
565
+ self.logger.info("Using Adam optimizer.")
566
+ elif self.cfg.train.optimizer.lower() == "adamw":
567
+ optimizer = torch.optim.AdamW(
568
+ self.model.parameters(), **self.cfg.train.adamw
569
+ )
570
+ elif self.cfg.train.optimizer.lower() == "sparseadam":
571
+ optimizer = torch.optim.SparseAdam(
572
+ self.model.parameters(), **self.cfg.train.sparseadam
573
+ )
574
+ elif self.cfg.train.optimizer.lower() == "adamax":
575
+ optimizer = torch.optim.Adamax(
576
+ self.model.parameters(), **self.cfg.train.adamax
577
+ )
578
+ elif self.cfg.train.optimizer.lower() == "asgd":
579
+ optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd)
580
+ elif self.cfg.train.optimizer.lower() == "lbfgs":
581
+ optimizer = torch.optim.LBFGS(
582
+ self.model.parameters(), **self.cfg.train.lbfgs
583
+ )
584
+ elif self.cfg.train.optimizer.lower() == "nadam":
585
+ optimizer = torch.optim.NAdam(
586
+ self.model.parameters(), **self.cfg.train.nadam
587
+ )
588
+ elif self.cfg.train.optimizer.lower() == "radam":
589
+ optimizer = torch.optim.RAdam(
590
+ self.model.parameters(), **self.cfg.train.radam
591
+ )
592
+ elif self.cfg.train.optimizer.lower() == "rmsprop":
593
+ optimizer = torch.optim.RMSprop(
594
+ self.model.parameters(), **self.cfg.train.rmsprop
595
+ )
596
+ elif self.cfg.train.optimizer.lower() == "rprop":
597
+ optimizer = torch.optim.Rprop(
598
+ self.model.parameters(), **self.cfg.train.rprop
599
+ )
600
+ elif self.cfg.train.optimizer.lower() == "sgd":
601
+ optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd)
602
+ else:
603
+ raise NotImplementedError(
604
+ f"Optimizer {self.cfg.train.optimizer} not supported yet!"
605
+ )
606
+ return optimizer
607
+
608
+ def _build_scheduler(self):
609
+ r"""Build scheduler for optimizer."""
610
+ # Make case-insensitive matching
611
+ if self.cfg.train.scheduler.lower() == "lambdalr":
612
+ scheduler = torch.optim.lr_scheduler.LambdaLR(
613
+ self.optimizer, **self.cfg.train.lambdalr
614
+ )
615
+ elif self.cfg.train.scheduler.lower() == "multiplicativelr":
616
+ scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
617
+ self.optimizer, **self.cfg.train.multiplicativelr
618
+ )
619
+ elif self.cfg.train.scheduler.lower() == "steplr":
620
+ scheduler = torch.optim.lr_scheduler.StepLR(
621
+ self.optimizer, **self.cfg.train.steplr
622
+ )
623
+ elif self.cfg.train.scheduler.lower() == "multisteplr":
624
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(
625
+ self.optimizer, **self.cfg.train.multisteplr
626
+ )
627
+ elif self.cfg.train.scheduler.lower() == "constantlr":
628
+ scheduler = torch.optim.lr_scheduler.ConstantLR(
629
+ self.optimizer, **self.cfg.train.constantlr
630
+ )
631
+ elif self.cfg.train.scheduler.lower() == "linearlr":
632
+ scheduler = torch.optim.lr_scheduler.LinearLR(
633
+ self.optimizer, **self.cfg.train.linearlr
634
+ )
635
+ elif self.cfg.train.scheduler.lower() == "exponentiallr":
636
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(
637
+ self.optimizer, **self.cfg.train.exponentiallr
638
+ )
639
+ elif self.cfg.train.scheduler.lower() == "polynomiallr":
640
+ scheduler = torch.optim.lr_scheduler.PolynomialLR(
641
+ self.optimizer, **self.cfg.train.polynomiallr
642
+ )
643
+ elif self.cfg.train.scheduler.lower() == "cosineannealinglr":
644
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
645
+ self.optimizer, **self.cfg.train.cosineannealinglr
646
+ )
647
+ elif self.cfg.train.scheduler.lower() == "sequentiallr":
648
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
649
+ self.optimizer, **self.cfg.train.sequentiallr
650
+ )
651
+ elif self.cfg.train.scheduler.lower() == "reducelronplateau":
652
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
653
+ self.optimizer, **self.cfg.train.reducelronplateau
654
+ )
655
+ elif self.cfg.train.scheduler.lower() == "cycliclr":
656
+ scheduler = torch.optim.lr_scheduler.CyclicLR(
657
+ self.optimizer, **self.cfg.train.cycliclr
658
+ )
659
+ elif self.cfg.train.scheduler.lower() == "onecyclelr":
660
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
661
+ self.optimizer, **self.cfg.train.onecyclelr
662
+ )
663
+ elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts":
664
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
665
+ self.optimizer, **self.cfg.train.cosineannearingwarmrestarts
666
+ )
667
+ elif self.cfg.train.scheduler.lower() == "noamlr":
668
+ scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
669
+ else:
670
+ raise NotImplementedError(
671
+ f"Scheduler {self.cfg.train.scheduler} not supported yet!"
672
+ )
673
+ return scheduler
674
+
675
+ def _init_accelerator(self):
676
+ self.exp_dir = os.path.join(
677
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
678
+ )
679
+ project_config = ProjectConfiguration(
680
+ project_dir=self.exp_dir,
681
+ logging_dir=os.path.join(self.exp_dir, "log"),
682
+ )
683
+ self.accelerator = accelerate.Accelerator(
684
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
685
+ log_with=self.cfg.train.tracker,
686
+ project_config=project_config,
687
+ )
688
+ if self.accelerator.is_main_process:
689
+ os.makedirs(project_config.project_dir, exist_ok=True)
690
+ os.makedirs(project_config.logging_dir, exist_ok=True)
691
+ with self.accelerator.main_process_first():
692
+ self.accelerator.init_trackers(self.args.exp_name)
693
+
694
+ def __check_basic_configs(self):
695
+ if self.cfg.train.gradient_accumulation_step <= 0:
696
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
697
+ self.logger.error(
698
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
699
+ )
700
+ self.accelerator.end_training()
701
+ raise ValueError(
702
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
703
+ )
704
+ # TODO: check other values
705
+
706
+ @staticmethod
707
+ def __count_parameters(model):
708
+ model_param = 0.0
709
+ if isinstance(model, dict):
710
+ for key, value in model.items():
711
+ model_param += sum(p.numel() for p in model[key].parameters())
712
+ else:
713
+ model_param = sum(p.numel() for p in model.parameters())
714
+ return model_param
715
+
716
+ def __dump_cfg(self, path):
717
+ os.makedirs(os.path.dirname(path), exist_ok=True)
718
+ json5.dump(
719
+ self.cfg,
720
+ open(path, "w"),
721
+ indent=4,
722
+ sort_keys=True,
723
+ ensure_ascii=False,
724
+ quote_keys=True,
725
+ )
726
+
727
+ ### Private methods end ###
models/codec/__init__.py ADDED
File without changes
models/codec/amphion_codec/codec.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+ from torch.nn.utils import weight_norm
13
+
14
+ from models.codec.amphion_codec.quantize import (
15
+ ResidualVQ,
16
+ VectorQuantize,
17
+ FactorizedVectorQuantize,
18
+ LookupFreeQuantize,
19
+ )
20
+
21
+ from models.codec.amphion_codec.vocos import Vocos
22
+
23
+
24
+ def WNConv1d(*args, **kwargs):
25
+ return weight_norm(nn.Conv1d(*args, **kwargs))
26
+
27
+
28
+ def WNConvTranspose1d(*args, **kwargs):
29
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
30
+
31
+
32
+ # Scripting this brings model speed up 1.4x
33
+ @torch.jit.script
34
+ def snake(x, alpha):
35
+ shape = x.shape
36
+ x = x.reshape(shape[0], shape[1], -1)
37
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
38
+ x = x.reshape(shape)
39
+ return x
40
+
41
+
42
+ class Snake1d(nn.Module):
43
+ def __init__(self, channels):
44
+ super().__init__()
45
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
46
+
47
+ def forward(self, x):
48
+ return snake(x, self.alpha)
49
+
50
+
51
+ def init_weights(m):
52
+ if isinstance(m, nn.Conv1d):
53
+ nn.init.trunc_normal_(m.weight, std=0.02)
54
+ nn.init.constant_(m.bias, 0)
55
+ if isinstance(m, nn.Linear):
56
+ nn.init.trunc_normal_(m.weight, std=0.02)
57
+ nn.init.constant_(m.bias, 0)
58
+
59
+
60
+ class ResidualUnit(nn.Module):
61
+ def __init__(self, dim: int = 16, dilation: int = 1):
62
+ super().__init__()
63
+ pad = ((7 - 1) * dilation) // 2
64
+ self.block = nn.Sequential(
65
+ Snake1d(dim),
66
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
67
+ Snake1d(dim),
68
+ WNConv1d(dim, dim, kernel_size=1),
69
+ )
70
+
71
+ def forward(self, x):
72
+ y = self.block(x)
73
+ pad = (x.shape[-1] - y.shape[-1]) // 2
74
+ if pad > 0:
75
+ x = x[..., pad:-pad]
76
+ return x + y
77
+
78
+
79
+ class EncoderBlock(nn.Module):
80
+ def __init__(self, dim: int = 16, stride: int = 1):
81
+ super().__init__()
82
+ self.block = nn.Sequential(
83
+ ResidualUnit(dim // 2, dilation=1),
84
+ ResidualUnit(dim // 2, dilation=3),
85
+ ResidualUnit(dim // 2, dilation=9),
86
+ Snake1d(dim // 2),
87
+ WNConv1d(
88
+ dim // 2,
89
+ dim,
90
+ kernel_size=2 * stride,
91
+ stride=stride,
92
+ padding=math.ceil(stride / 2),
93
+ ),
94
+ )
95
+
96
+ def forward(self, x):
97
+ return self.block(x)
98
+
99
+
100
+ class CodecEncoder(nn.Module):
101
+ def __init__(
102
+ self,
103
+ d_model: int = 64,
104
+ up_ratios: list = [4, 5, 5, 6],
105
+ out_channels: int = 256,
106
+ use_tanh: bool = False,
107
+ cfg=None,
108
+ ):
109
+ super().__init__()
110
+
111
+ d_model = cfg.d_model if cfg is not None else d_model
112
+ up_ratios = cfg.up_ratios if cfg is not None else up_ratios
113
+ out_channels = cfg.out_channels if cfg is not None else out_channels
114
+ use_tanh = cfg.use_tanh if cfg is not None else use_tanh
115
+
116
+ # Create first convolution
117
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
118
+
119
+ # Create EncoderBlocks that double channels as they downsample by `stride`
120
+ for stride in up_ratios:
121
+ d_model *= 2
122
+ self.block += [EncoderBlock(d_model, stride=stride)]
123
+
124
+ # Create last convolution
125
+ self.block += [
126
+ Snake1d(d_model),
127
+ WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
128
+ ]
129
+
130
+ if use_tanh:
131
+ self.block += [nn.Tanh()]
132
+
133
+ # Wrap black into nn.Sequential
134
+ self.block = nn.Sequential(*self.block)
135
+ self.enc_dim = d_model
136
+
137
+ self.reset_parameters()
138
+
139
+ def forward(self, x):
140
+ return self.block(x)
141
+
142
+ def reset_parameters(self):
143
+ self.apply(init_weights)
144
+
145
+
146
+ class DecoderBlock(nn.Module):
147
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
148
+ super().__init__()
149
+ self.block = nn.Sequential(
150
+ Snake1d(input_dim),
151
+ WNConvTranspose1d(
152
+ input_dim,
153
+ output_dim,
154
+ kernel_size=2 * stride,
155
+ stride=stride,
156
+ padding=stride // 2 + stride % 2,
157
+ output_padding=stride % 2,
158
+ ),
159
+ ResidualUnit(output_dim, dilation=1),
160
+ ResidualUnit(output_dim, dilation=3),
161
+ ResidualUnit(output_dim, dilation=9),
162
+ )
163
+
164
+ def forward(self, x):
165
+ return self.block(x)
166
+
167
+
168
+ class CodecDecoder(nn.Module):
169
+ def __init__(
170
+ self,
171
+ in_channels: int = 256,
172
+ upsample_initial_channel: int = 1536,
173
+ up_ratios: list = [5, 5, 4, 2],
174
+ num_quantizers: int = 8,
175
+ codebook_size: int = 1024,
176
+ codebook_dim: int = 256,
177
+ quantizer_type: str = "vq",
178
+ quantizer_dropout: float = 0.5,
179
+ commitment: float = 0.25,
180
+ codebook_loss_weight: float = 1.0,
181
+ use_l2_normlize: bool = False,
182
+ codebook_type: str = "euclidean",
183
+ kmeans_init: bool = False,
184
+ kmeans_iters: int = 10,
185
+ decay: float = 0.8,
186
+ eps: float = 1e-5,
187
+ threshold_ema_dead_code: int = 2,
188
+ weight_init: bool = False,
189
+ use_vocos: bool = False,
190
+ vocos_dim: int = 384,
191
+ vocos_intermediate_dim: int = 1152,
192
+ vocos_num_layers: int = 8,
193
+ n_fft: int = 800,
194
+ hop_size: int = 200,
195
+ padding: str = "same",
196
+ cfg=None,
197
+ ):
198
+ super().__init__()
199
+
200
+ in_channels = (
201
+ cfg.in_channels
202
+ if cfg is not None and hasattr(cfg, "in_channels")
203
+ else in_channels
204
+ )
205
+ upsample_initial_channel = (
206
+ cfg.upsample_initial_channel
207
+ if cfg is not None and hasattr(cfg, "upsample_initial_channel")
208
+ else upsample_initial_channel
209
+ )
210
+ up_ratios = (
211
+ cfg.up_ratios
212
+ if cfg is not None and hasattr(cfg, "up_ratios")
213
+ else up_ratios
214
+ )
215
+ num_quantizers = (
216
+ cfg.num_quantizers
217
+ if cfg is not None and hasattr(cfg, "num_quantizers")
218
+ else num_quantizers
219
+ )
220
+ codebook_size = (
221
+ cfg.codebook_size
222
+ if cfg is not None and hasattr(cfg, "codebook_size")
223
+ else codebook_size
224
+ )
225
+ codebook_dim = (
226
+ cfg.codebook_dim
227
+ if cfg is not None and hasattr(cfg, "codebook_dim")
228
+ else codebook_dim
229
+ )
230
+ quantizer_type = (
231
+ cfg.quantizer_type
232
+ if cfg is not None and hasattr(cfg, "quantizer_type")
233
+ else quantizer_type
234
+ )
235
+ quantizer_dropout = (
236
+ cfg.quantizer_dropout
237
+ if cfg is not None and hasattr(cfg, "quantizer_dropout")
238
+ else quantizer_dropout
239
+ )
240
+ commitment = (
241
+ cfg.commitment
242
+ if cfg is not None and hasattr(cfg, "commitment")
243
+ else commitment
244
+ )
245
+ codebook_loss_weight = (
246
+ cfg.codebook_loss_weight
247
+ if cfg is not None and hasattr(cfg, "codebook_loss_weight")
248
+ else codebook_loss_weight
249
+ )
250
+ use_l2_normlize = (
251
+ cfg.use_l2_normlize
252
+ if cfg is not None and hasattr(cfg, "use_l2_normlize")
253
+ else use_l2_normlize
254
+ )
255
+ codebook_type = (
256
+ cfg.codebook_type
257
+ if cfg is not None and hasattr(cfg, "codebook_type")
258
+ else codebook_type
259
+ )
260
+ kmeans_init = (
261
+ cfg.kmeans_init
262
+ if cfg is not None and hasattr(cfg, "kmeans_init")
263
+ else kmeans_init
264
+ )
265
+ kmeans_iters = (
266
+ cfg.kmeans_iters
267
+ if cfg is not None and hasattr(cfg, "kmeans_iters")
268
+ else kmeans_iters
269
+ )
270
+ decay = cfg.decay if cfg is not None and hasattr(cfg, "decay") else decay
271
+ eps = cfg.eps if cfg is not None and hasattr(cfg, "eps") else eps
272
+ threshold_ema_dead_code = (
273
+ cfg.threshold_ema_dead_code
274
+ if cfg is not None and hasattr(cfg, "threshold_ema_dead_code")
275
+ else threshold_ema_dead_code
276
+ )
277
+ weight_init = (
278
+ cfg.weight_init
279
+ if cfg is not None and hasattr(cfg, "weight_init")
280
+ else weight_init
281
+ )
282
+ use_vocos = (
283
+ cfg.use_vocos
284
+ if cfg is not None and hasattr(cfg, "use_vocos")
285
+ else use_vocos
286
+ )
287
+ vocos_dim = (
288
+ cfg.vocos_dim
289
+ if cfg is not None and hasattr(cfg, "vocos_dim")
290
+ else vocos_dim
291
+ )
292
+ vocos_intermediate_dim = (
293
+ cfg.vocos_intermediate_dim
294
+ if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
295
+ else vocos_intermediate_dim
296
+ )
297
+ vocos_num_layers = (
298
+ cfg.vocos_num_layers
299
+ if cfg is not None and hasattr(cfg, "vocos_num_layers")
300
+ else vocos_num_layers
301
+ )
302
+ n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
303
+ hop_size = (
304
+ cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
305
+ )
306
+ padding = (
307
+ cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
308
+ )
309
+
310
+ if quantizer_type == "vq":
311
+ self.quantizer = ResidualVQ(
312
+ input_dim=in_channels,
313
+ num_quantizers=num_quantizers,
314
+ codebook_size=codebook_size,
315
+ codebook_dim=codebook_dim,
316
+ quantizer_type=quantizer_type,
317
+ quantizer_dropout=quantizer_dropout,
318
+ commitment=commitment,
319
+ codebook_loss_weight=codebook_loss_weight,
320
+ use_l2_normlize=use_l2_normlize,
321
+ codebook_type=codebook_type,
322
+ kmeans_init=kmeans_init,
323
+ kmeans_iters=kmeans_iters,
324
+ decay=decay,
325
+ eps=eps,
326
+ threshold_ema_dead_code=threshold_ema_dead_code,
327
+ weight_init=weight_init,
328
+ )
329
+ elif quantizer_type == "fvq":
330
+ self.quantizer = ResidualVQ(
331
+ input_dim=in_channels,
332
+ num_quantizers=num_quantizers,
333
+ codebook_size=codebook_size,
334
+ codebook_dim=codebook_dim,
335
+ quantizer_type=quantizer_type,
336
+ quantizer_dropout=quantizer_dropout,
337
+ commitment=commitment,
338
+ codebook_loss_weight=codebook_loss_weight,
339
+ use_l2_normlize=use_l2_normlize,
340
+ )
341
+ elif quantizer_type == "lfq":
342
+ self.quantizer = ResidualVQ(
343
+ input_dim=in_channels,
344
+ num_quantizers=num_quantizers,
345
+ codebook_size=codebook_size,
346
+ codebook_dim=codebook_dim,
347
+ quantizer_type=quantizer_type,
348
+ )
349
+ else:
350
+ raise ValueError(f"Unknown quantizer type {quantizer_type}")
351
+
352
+ if not use_vocos:
353
+ # Add first conv layer
354
+ channels = upsample_initial_channel
355
+ layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
356
+
357
+ # Add upsampling + MRF blocks
358
+ for i, stride in enumerate(up_ratios):
359
+ input_dim = channels // 2**i
360
+ output_dim = channels // 2 ** (i + 1)
361
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
362
+
363
+ # Add final conv layer
364
+ layers += [
365
+ Snake1d(output_dim),
366
+ WNConv1d(output_dim, 1, kernel_size=7, padding=3),
367
+ nn.Tanh(),
368
+ ]
369
+
370
+ self.model = nn.Sequential(*layers)
371
+
372
+ if use_vocos:
373
+ self.model = Vocos(
374
+ input_channels=in_channels,
375
+ dim=vocos_dim,
376
+ intermediate_dim=vocos_intermediate_dim,
377
+ num_layers=vocos_num_layers,
378
+ adanorm_num_embeddings=None,
379
+ n_fft=n_fft,
380
+ hop_size=hop_size,
381
+ padding=padding,
382
+ )
383
+
384
+ self.reset_parameters()
385
+
386
+ def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None):
387
+ """
388
+ if vq is True, x = encoder output, then return quantized output;
389
+ else, x = quantized output, then return decoder output
390
+ """
391
+ if vq is True:
392
+ if eval_vq:
393
+ self.quantizer.eval()
394
+ (
395
+ quantized_out,
396
+ all_indices,
397
+ all_commit_losses,
398
+ all_codebook_losses,
399
+ all_quantized,
400
+ ) = self.quantizer(x, n_quantizers=n_quantizers)
401
+ return (
402
+ quantized_out,
403
+ all_indices,
404
+ all_commit_losses,
405
+ all_codebook_losses,
406
+ all_quantized,
407
+ )
408
+
409
+ return self.model(x)
410
+
411
+ def quantize(self, x, n_quantizers=None):
412
+ self.quantizer.eval()
413
+ quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers)
414
+ return quantized_out, vq
415
+
416
+ # TODO: check consistency of vq2emb and quantize
417
+ def vq2emb(self, vq, n_quantizers=None):
418
+ return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers)
419
+
420
+ def decode(self, x):
421
+ return self.model(x)
422
+
423
+ def latent2dist(self, x, n_quantizers=None):
424
+ return self.quantizer.latent2dist(x, n_quantizers=n_quantizers)
425
+
426
+ def reset_parameters(self):
427
+ self.apply(init_weights)
models/codec/amphion_codec/quantize/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
7
+ FactorizedVectorQuantize,
8
+ )
9
+ from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
10
+ from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
11
+ from models.codec.amphion_codec.quantize.residual_vq import ResidualVQ
models/codec/amphion_codec/quantize/factorized_vector_quantize.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+ from torch.nn.utils import weight_norm
12
+
13
+
14
+ def WNConv1d(*args, **kwargs):
15
+ return weight_norm(nn.Conv1d(*args, **kwargs))
16
+
17
+
18
+ def WNConvTranspose1d(*args, **kwargs):
19
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
20
+
21
+
22
+ class FactorizedVectorQuantize(nn.Module):
23
+ def __init__(
24
+ self,
25
+ input_dim,
26
+ codebook_size,
27
+ codebook_dim,
28
+ commitment=0.005,
29
+ codebook_loss_weight=1.0,
30
+ use_l2_normlize=True,
31
+ ):
32
+ super().__init__()
33
+ self.input_dim = input_dim
34
+ self.codebook_size = codebook_size
35
+ self.codebook_dim = codebook_dim
36
+ self.commitment = commitment
37
+ self.codebook_loss_weight = codebook_loss_weight
38
+ self.use_l2_normlize = use_l2_normlize
39
+
40
+ if self.input_dim != self.codebook_dim:
41
+ self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
42
+ self.out_project = WNConv1d(
43
+ self.codebook_dim, self.input_dim, kernel_size=1
44
+ )
45
+
46
+ else:
47
+ self.in_project = nn.Identity()
48
+ self.out_project = nn.Identity()
49
+
50
+ self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
51
+
52
+ def forward(self, z):
53
+ """
54
+ Parameters
55
+ ----------
56
+ z: torch.Tensor[B x D x T]
57
+
58
+ Returns
59
+ -------
60
+ z_q: torch.Tensor[B x D x T]
61
+ Quantized continuous representation of input
62
+ commit_loss: Tensor[B]
63
+ Commitment loss to train encoder to predict vectors closer to codebook entries
64
+ codebook_loss: Tensor[B]
65
+ Codebook loss to update the codebook
66
+ indices: torch.Tensor[B x T]
67
+ Codebook indices (quantized discrete representation of input)
68
+ z_e: torch.Tensor[B x D x T]
69
+ Projected latents (continuous representation of input before quantization)
70
+ """
71
+
72
+ # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
73
+ z_e = self.in_project(z)
74
+ z_q, indices = self.decode_latents(z_e)
75
+
76
+ # Compute commitment loss and codebook loss
77
+ if self.training:
78
+ commit_loss = (
79
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
80
+ * self.commitment
81
+ )
82
+ codebook_loss = (
83
+ F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
84
+ * self.codebook_loss_weight
85
+ )
86
+ else:
87
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
88
+ codebook_loss = torch.zeros(z.shape[0], device=z.device)
89
+
90
+ z_q = z_e + (z_q - z_e).detach()
91
+
92
+ z_q = self.out_project(z_q)
93
+
94
+ return z_q, commit_loss, codebook_loss, indices, z_e
95
+
96
+ def embed_code(self, embed_id):
97
+ return F.embedding(embed_id, self.codebook.weight)
98
+
99
+ def decode_code(self, embed_id):
100
+ return self.embed_code(embed_id).transpose(1, 2)
101
+
102
+ def decode_latents(self, latents):
103
+ encodings = rearrange(latents, "b d t -> (b t) d")
104
+ codebook = self.codebook.weight
105
+
106
+ # L2 normalize encodings and codebook
107
+ if self.use_l2_normlize:
108
+ encodings = F.normalize(encodings)
109
+ codebook = F.normalize(codebook)
110
+
111
+ # Compute euclidean distance between encodings and codebook,
112
+ # if use_l2_normlize is True, the distance is equal to cosine distance
113
+ dist = (
114
+ encodings.pow(2).sum(1, keepdim=True)
115
+ - 2 * encodings @ codebook.t()
116
+ + codebook.pow(2).sum(1, keepdim=True).t()
117
+ )
118
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
119
+ z_q = self.decode_code(indices)
120
+
121
+ return z_q, indices
122
+
123
+ def vq2emb(self, vq, out_proj=True):
124
+ emb = self.decode_code(vq)
125
+ if out_proj:
126
+ emb = self.out_project(emb)
127
+ return emb
128
+
129
+ def latent2dist(self, latents):
130
+ encodings = rearrange(latents, "b d t -> (b t) d")
131
+ codebook = self.codebook.weight
132
+
133
+ # L2 normalize encodings and codebook
134
+ if self.use_l2_normlize:
135
+ encodings = F.normalize(encodings)
136
+ codebook = F.normalize(codebook)
137
+
138
+ # Compute euclidean distance between encodings and codebook,
139
+ # if use_l2_normlize is True, the distance is equal to cosine distance
140
+ dist = (
141
+ encodings.pow(2).sum(1, keepdim=True)
142
+ - 2 * encodings @ codebook.t()
143
+ + codebook.pow(2).sum(1, keepdim=True).t()
144
+ ) # (b*t, k)
145
+
146
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
147
+ dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
148
+ z_q = self.decode_code(indices)
149
+
150
+ return -dist, indices, z_q
models/codec/amphion_codec/quantize/lookup_free_quantize.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+ from torch.nn.utils import weight_norm
12
+
13
+
14
+ def WNConv1d(*args, **kwargs):
15
+ return weight_norm(nn.Conv1d(*args, **kwargs))
16
+
17
+
18
+ def WNConvTranspose1d(*args, **kwargs):
19
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
20
+
21
+
22
+ class LookupFreeQuantize(nn.Module):
23
+ def __init__(
24
+ self,
25
+ input_dim,
26
+ codebook_size,
27
+ codebook_dim,
28
+ ):
29
+ super().__init__()
30
+ self.input_dim = input_dim
31
+ self.codebook_size = codebook_size
32
+ self.codebook_dim = codebook_dim
33
+
34
+ assert 2**codebook_dim == codebook_size
35
+
36
+ if self.input_dim != self.codebook_dim:
37
+ self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
38
+ self.out_project = WNConv1d(
39
+ self.codebook_dim, self.input_dim, kernel_size=1
40
+ )
41
+
42
+ else:
43
+ self.in_project = nn.Identity()
44
+ self.out_project = nn.Identity()
45
+
46
+ def forward(self, z):
47
+ z_e = self.in_project(z)
48
+ z_e = F.sigmoid(z_e)
49
+
50
+ z_q = z_e + (torch.round(z_e) - z_e).detach()
51
+
52
+ z_q = self.out_project(z_q)
53
+
54
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
55
+ codebook_loss = torch.zeros(z.shape[0], device=z.device)
56
+
57
+ bits = (
58
+ 2
59
+ ** torch.arange(self.codebook_dim, device=z.device)
60
+ .unsqueeze(0)
61
+ .unsqueeze(-1)
62
+ .long()
63
+ ) # (1, d, 1)
64
+ indices = (torch.round(z_e.clone().detach()).long() * bits).sum(1).long()
65
+
66
+ return z_q, commit_loss, codebook_loss, indices, z_e
67
+
68
+ def vq2emb(self, vq, out_proj=True):
69
+ emb = torch.zeros(
70
+ vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device
71
+ ) # (B, d, T)
72
+ for i in range(self.codebook_dim):
73
+ emb[:, i, :] = (vq % 2).float()
74
+ vq = vq // 2
75
+ if out_proj:
76
+ emb = self.out_project(emb)
77
+ return emb
models/codec/amphion_codec/quantize/residual_vq.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+ from torch.nn.utils import weight_norm
14
+
15
+ from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
16
+ FactorizedVectorQuantize,
17
+ )
18
+ from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
19
+ from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
20
+
21
+
22
+ class ResidualVQ(nn.Module):
23
+ """
24
+ Introduced in SoundStream: An end2end neural audio codec
25
+ https://arxiv.org/abs/2107.03312
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ input_dim: int = 256,
31
+ num_quantizers: int = 8,
32
+ codebook_size: int = 1024,
33
+ codebook_dim: int = 256,
34
+ quantizer_type: str = "vq", # "vq" or "fvq" or "lfq"
35
+ quantizer_dropout: float = 0.5,
36
+ **kwargs,
37
+ ):
38
+ super().__init__()
39
+
40
+ self.input_dim = input_dim
41
+ self.num_quantizers = num_quantizers
42
+ self.codebook_size = codebook_size
43
+ self.codebook_dim = codebook_dim
44
+ self.quantizer_type = quantizer_type
45
+ self.quantizer_dropout = quantizer_dropout
46
+
47
+ if quantizer_type == "vq":
48
+ VQ = VectorQuantize
49
+ elif quantizer_type == "fvq":
50
+ VQ = FactorizedVectorQuantize
51
+ elif quantizer_type == "lfq":
52
+ VQ = LookupFreeQuantize
53
+ else:
54
+ raise ValueError(f"Unknown quantizer type {quantizer_type}")
55
+
56
+ self.quantizers = nn.ModuleList(
57
+ [
58
+ VQ(
59
+ input_dim=input_dim,
60
+ codebook_size=codebook_size,
61
+ codebook_dim=codebook_dim,
62
+ **kwargs,
63
+ )
64
+ for _ in range(num_quantizers)
65
+ ]
66
+ )
67
+
68
+ def forward(self, z, n_quantizers: int = None):
69
+ """
70
+ Parameters
71
+ ----------
72
+ z : Tensor[B x D x T]
73
+ n_quantizers : int, optional
74
+ No. of quantizers to use
75
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
76
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
77
+ when in training mode, and a random number of quantizers is used.
78
+ Returns
79
+ -------
80
+ "quantized_out" : Tensor[B x D x T]
81
+ Quantized continuous representation of input
82
+ "all_indices" : Tensor[N x B x T]
83
+ Codebook indices for each codebook
84
+ (quantized discrete representation of input)
85
+ "all_commit_losses" : Tensor[N]
86
+ "all_codebook_losses" : Tensor[N]
87
+ "all_quantized" : Tensor[N x B x D x T]
88
+ """
89
+
90
+ quantized_out = 0.0
91
+ residual = z
92
+
93
+ all_commit_losses = []
94
+ all_codebook_losses = []
95
+ all_indices = []
96
+ all_quantized = []
97
+
98
+ if n_quantizers is None:
99
+ n_quantizers = self.num_quantizers
100
+
101
+ if self.training:
102
+ n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
103
+ dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
104
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
105
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
106
+ n_quantizers = n_quantizers.to(z.device)
107
+
108
+ for i, quantizer in enumerate(self.quantizers):
109
+ if self.training is False and i >= n_quantizers:
110
+ break
111
+
112
+ z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
113
+ residual
114
+ )
115
+
116
+ # Create mask to apply quantizer dropout
117
+ mask = (
118
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
119
+ )
120
+ quantized_out = quantized_out + z_q_i * mask[:, None, None]
121
+ residual = residual - z_q_i
122
+
123
+ commit_loss_i = (commit_loss_i * mask).mean()
124
+ codebook_loss_i = (codebook_loss_i * mask).mean()
125
+
126
+ all_commit_losses.append(commit_loss_i)
127
+ all_codebook_losses.append(codebook_loss_i)
128
+ all_indices.append(indices_i)
129
+ all_quantized.append(z_q_i)
130
+
131
+ all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
132
+ torch.stack,
133
+ (all_commit_losses, all_codebook_losses, all_indices, all_quantized),
134
+ )
135
+
136
+ return (
137
+ quantized_out,
138
+ all_indices,
139
+ all_commit_losses,
140
+ all_codebook_losses,
141
+ all_quantized,
142
+ )
143
+
144
+ def vq2emb(self, vq, n_quantizers=None):
145
+ quantized_out = 0.0
146
+ if n_quantizers is None:
147
+ n_quantizers = self.num_quantizers
148
+ for idx, quantizer in enumerate(self.quantizers):
149
+ if idx >= n_quantizers:
150
+ break
151
+ quantized_out += quantizer.vq2emb(vq[idx])
152
+ return quantized_out
153
+
154
+ def latent2dist(self, z, n_quantizers=None):
155
+ quantized_out = 0.0
156
+ residual = z
157
+
158
+ all_dists = []
159
+ all_indices = []
160
+
161
+ if n_quantizers is None:
162
+ n_quantizers = self.num_quantizers
163
+
164
+ for i, quantizer in enumerate(self.quantizers):
165
+ if self.training is False and i >= n_quantizers:
166
+ break
167
+ dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
168
+ all_dists.append(dist_i)
169
+ all_indices.append(indices_i)
170
+
171
+ quantized_out = quantized_out + z_q_i
172
+ residual = residual - z_q_i
173
+
174
+ all_dists = torch.stack(all_dists)
175
+ all_indices = torch.stack(all_indices)
176
+
177
+ return all_dists, all_indices
models/codec/amphion_codec/quantize/vector_quantize.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange, repeat
11
+ from torch.nn.utils import weight_norm
12
+
13
+
14
+ def WNConv1d(*args, **kwargs):
15
+ return weight_norm(nn.Conv1d(*args, **kwargs))
16
+
17
+
18
+ def WNConvTranspose1d(*args, **kwargs):
19
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
20
+
21
+
22
+ def l2norm(t):
23
+ return F.normalize(t, p=2, dim=-1)
24
+
25
+
26
+ def ema_inplace(moving_avg, new, decay):
27
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
28
+
29
+
30
+ def laplace_smoothing(x, n_categories, eps=1e-5):
31
+ return (x + eps) / (x.sum() + n_categories * eps)
32
+
33
+
34
+ def sample_vectors(samples, num):
35
+ num_samples, device = samples.shape[0], samples.device
36
+
37
+ if num_samples >= num:
38
+ indices = torch.randperm(num_samples, device=device)[:num]
39
+ else:
40
+ indices = torch.randint(0, num_samples, (num,), device=device)
41
+
42
+ return samples[indices]
43
+
44
+
45
+ def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
46
+ dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
47
+
48
+ means = sample_vectors(samples, num_clusters)
49
+
50
+ for _ in range(num_iters):
51
+ if use_cosine_sim:
52
+ dists = samples @ means.t()
53
+ else:
54
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
55
+ means, "c d -> () c d"
56
+ )
57
+ dists = -(diffs**2).sum(dim=-1)
58
+
59
+ buckets = dists.max(dim=-1).indices
60
+ bins = torch.bincount(buckets, minlength=num_clusters)
61
+ zero_mask = bins == 0
62
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
63
+
64
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
65
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
66
+ new_means = new_means / bins_min_clamped[..., None]
67
+
68
+ if use_cosine_sim:
69
+ new_means = l2norm(new_means)
70
+
71
+ means = torch.where(zero_mask[..., None], means, new_means)
72
+
73
+ return means, bins
74
+
75
+
76
+ class EuclideanCodebook(nn.Module):
77
+ def __init__(
78
+ self,
79
+ dim,
80
+ codebook_size,
81
+ kmeans_init=False,
82
+ kmeans_iters=10,
83
+ decay=0.8,
84
+ eps=1e-5,
85
+ threshold_ema_dead_code=2,
86
+ weight_init=False,
87
+ ):
88
+ super().__init__()
89
+
90
+ self.decay = decay
91
+ init_fn = torch.randn if not weight_init else torch.zeros
92
+ embed = init_fn(codebook_size, dim)
93
+
94
+ if weight_init:
95
+ nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)
96
+
97
+ self.codebook_size = codebook_size
98
+ self.kmeans_iters = kmeans_iters
99
+ self.eps = eps
100
+ self.threshold_ema_dead_code = threshold_ema_dead_code
101
+
102
+ self.register_buffer(
103
+ "initted", torch.Tensor([not kmeans_init])
104
+ ) # if kmeans_init is True, then initted is False; otherwise, initted is True
105
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
106
+ self.register_buffer("embed", embed)
107
+ self.register_buffer("embed_avg", embed.clone())
108
+
109
+ def init_embed_(self, data):
110
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
111
+ self.embed.data.copy_(embed)
112
+ self.embed_avg.data.copy_(embed)
113
+ self.cluster_size.data.copy_(cluster_size)
114
+ self.initted.data.copy_(torch.Tensor([True]))
115
+
116
+ def replace(self, samples, mask):
117
+ modified_codebook = torch.where(
118
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
119
+ )
120
+ self.embed.data.copy_(modified_codebook)
121
+
122
+ def expire_codes_(self, batch_samples):
123
+ if self.threshold_ema_dead_code == 0:
124
+ return
125
+
126
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
127
+ if not torch.any(expired_codes):
128
+ return
129
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
130
+ self.replace(batch_samples, mask=expired_codes)
131
+
132
+ def forward(self, x):
133
+ shape, dtype = x.shape, x.dtype
134
+ flatten = rearrange(x, "... d -> (...) d")
135
+ embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
136
+
137
+ if not self.initted:
138
+ self.init_embed_(flatten)
139
+
140
+ dist = -(
141
+ flatten.pow(2).sum(1, keepdim=True)
142
+ - 2 * flatten @ embed
143
+ + embed.pow(2).sum(0, keepdim=True)
144
+ )
145
+
146
+ embed_ind = dist.max(dim=-1).indices
147
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
148
+ embed_ind = embed_ind.view(*shape[:-1])
149
+ quantize = F.embedding(embed_ind, self.embed)
150
+
151
+ if self.training:
152
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
153
+ embed_sum = (
154
+ flatten.t() @ embed_onehot
155
+ ) # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
156
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
157
+ cluster_size = (
158
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
159
+ * self.cluster_size.sum()
160
+ )
161
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
162
+ self.embed.data.copy_(embed_normalized)
163
+ self.expire_codes_(x)
164
+
165
+ return quantize, embed_ind
166
+
167
+ def vq2emb(self, vq):
168
+ quantize = F.embedding(vq, self.embed)
169
+ return quantize
170
+
171
+ def latent2dist(self, x):
172
+ shape, dtype = x.shape, x.dtype
173
+ flatten = rearrange(x, "... d -> (...) d")
174
+ embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
175
+
176
+ if not self.initted:
177
+ self.init_embed_(flatten)
178
+
179
+ dist = -(
180
+ flatten.pow(2).sum(1, keepdim=True)
181
+ - 2 * flatten @ embed
182
+ + embed.pow(2).sum(0, keepdim=True)
183
+ )
184
+
185
+ embed_ind = dist.max(dim=-1).indices
186
+ embed_ind = embed_ind.view(*shape[:-1])
187
+ quantize = F.embedding(embed_ind, self.embed)
188
+
189
+ dist = dist.view(*shape[:-1], -1)
190
+
191
+ return dist, embed_ind, quantize
192
+
193
+
194
+ class SimpleCodebook(nn.Module):
195
+ def __init__(
196
+ self,
197
+ dim,
198
+ codebook_size,
199
+ use_l2_normlize=False,
200
+ ):
201
+ super().__init__()
202
+
203
+ self.dim = dim
204
+ self.codebook_size = codebook_size
205
+ self.use_l2_normlize = use_l2_normlize
206
+
207
+ self.embed = nn.Embedding(self.codebook_size, self.dim)
208
+
209
+ def forward(self, x):
210
+ shape, dtype = x.shape, x.dtype
211
+ flatten = rearrange(x, "... d -> (...) d")
212
+ embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
213
+
214
+ if self.use_l2_normlize:
215
+ flatten = F.normalize(flatten)
216
+ embed = F.normalize(embed)
217
+
218
+ dist = -(
219
+ flatten.pow(2).sum(1, keepdim=True)
220
+ - 2 * flatten @ embed
221
+ + embed.pow(2).sum(0, keepdim=True)
222
+ )
223
+
224
+ embed_ind = dist.max(dim=-1).indices
225
+ embed_ind = embed_ind.view(*shape[:-1])
226
+ quantize = F.embedding(embed_ind, self.embed)
227
+
228
+ return quantize, embed_ind
229
+
230
+ def vq2emb(self, vq):
231
+ quantize = F.embedding(vq, self.embed.weight)
232
+ return quantize
233
+
234
+ def latent2dist(self, x):
235
+ shape, dtype = x.shape, x.dtype
236
+ flatten = rearrange(x, "... d -> (...) d")
237
+ embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
238
+
239
+ if self.use_l2_normlize:
240
+ flatten = F.normalize(flatten)
241
+ embed = F.normalize(embed)
242
+
243
+ dist = -(
244
+ flatten.pow(2).sum(1, keepdim=True)
245
+ - 2 * flatten @ embed
246
+ + embed.pow(2).sum(0, keepdim=True)
247
+ )
248
+
249
+ embed_ind = dist.max(dim=-1).indices
250
+ embed_ind = embed_ind.view(*shape[:-1])
251
+ quantize = F.embedding(embed_ind, self.embed)
252
+
253
+ dist = dist.view(*shape[:-1], -1)
254
+
255
+ return dist, embed_ind, quantize
256
+
257
+
258
+ class VectorQuantize(nn.Module):
259
+ """Vector quantization and factorized vecotor quantization implementation
260
+ Args:
261
+ input_dim (int): Dimension of input.
262
+ codebook_size (int): Codebook size.
263
+ codebook_dim (int): Codebook dimension. We suggest use codebook_dim = input_dim
264
+ if use codebook_type == "euclidean", otherwise, if you want to use
265
+ factorized vector quantization, use codebook_dim as small number (e.g. 8 or 32).
266
+ commitment (float): Weight for commitment loss.
267
+ use_l2_normlize (bool): Whether to use l2 normlized codes for factorized vecotor quantization,
268
+ we suggest use it as True if you want to use factorized vector quantization
269
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
270
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
271
+ decay (float): Decay for exponential moving average over the codebooks.
272
+ epsilon (float): Epsilon value for numerical stability.
273
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
274
+ that have an exponential moving average cluster size less than the specified threshold with
275
+ randomly selected vector from the current batch.
276
+ """
277
+
278
+ def __init__(
279
+ self,
280
+ input_dim,
281
+ codebook_size,
282
+ codebook_dim,
283
+ commitment=0.005,
284
+ codebook_loss_weight=1.0,
285
+ use_l2_normlize=False,
286
+ codebook_type="euclidean", # "euclidean" or "simple"
287
+ kmeans_init=False,
288
+ kmeans_iters=10,
289
+ decay=0.8,
290
+ eps=1e-5,
291
+ threshold_ema_dead_code=2,
292
+ weight_init=False,
293
+ ):
294
+ super().__init__()
295
+ self.input_dim = input_dim
296
+ self.codebook_size = codebook_size
297
+ self.codebook_dim = codebook_dim
298
+ self.commitment = commitment
299
+ self.codebook_loss_weight = codebook_loss_weight
300
+ self.use_l2_normlize = use_l2_normlize
301
+ self.codebook_type = codebook_type
302
+ self.kmeans_init = kmeans_init
303
+ self.kmeans_iters = kmeans_iters
304
+ self.decay = decay
305
+ self.eps = eps
306
+ self.threshold_ema_dead_code = threshold_ema_dead_code
307
+ self.weight_init = weight_init
308
+
309
+ if self.input_dim != self.codebook_dim:
310
+ self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
311
+ self.out_project = WNConv1d(
312
+ self.codebook_dim, self.input_dim, kernel_size=1
313
+ )
314
+
315
+ else:
316
+ self.in_project = nn.Identity()
317
+ self.out_project = nn.Identity()
318
+
319
+ if self.codebook_type == "euclidean":
320
+ self.codebook = EuclideanCodebook(
321
+ self.codebook_dim,
322
+ codebook_size=self.codebook_size,
323
+ kmeans_init=self.kmeans_init,
324
+ kmeans_iters=self.kmeans_iters,
325
+ decay=self.decay,
326
+ eps=self.eps,
327
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
328
+ weight_init=self.weight_init,
329
+ )
330
+ elif self.codebook_type == "simple":
331
+ self.codebook = SimpleCodebook(
332
+ self.codebook_dim,
333
+ codebook_size=self.codebook_size,
334
+ use_l2_normlize=self.use_l2_normlize,
335
+ )
336
+ else:
337
+ raise NotImplementedError(
338
+ f"codebook_type {self.codebook_type} is not implemented!"
339
+ )
340
+
341
+ def forward(self, z):
342
+ """
343
+ Parameters
344
+ ----------
345
+ z: torch.Tensor[B x D x T]
346
+
347
+ Returns
348
+ -------
349
+ z_q: torch.Tensor[B x D x T]
350
+ Quantized continuous representation of input
351
+ commit_loss: Tensor[B]
352
+ Commitment loss to train encoder to predict vectors closer to codebook entries
353
+ codebook_loss: Tensor[B]
354
+ Codebook loss to update the codebook
355
+ indices: torch.Tensor[B x T]
356
+ Codebook indices (quantized discrete representation of input)
357
+ z_e: torch.Tensor[B x D x T]
358
+ Projected latents (continuous representation of input before quantization)
359
+ """
360
+
361
+ # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
362
+ z_e = self.in_project(z)
363
+ z_q, indices = self.decode_latents(z_e)
364
+
365
+ # Compute commitment loss and codebook loss
366
+ if self.training:
367
+ commit_loss = (
368
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
369
+ * self.commitment
370
+ )
371
+ codebook_loss = (
372
+ F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
373
+ * self.codebook_loss_weight
374
+ )
375
+ else:
376
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
377
+ codebook_loss = torch.zeros(z.shape[0], device=z.device)
378
+
379
+ z_q = z_e + (z_q - z_e).detach()
380
+
381
+ z_q = self.out_project(z_q)
382
+
383
+ return z_q, commit_loss, codebook_loss, indices, z_e
384
+
385
+ def decode_latents(self, latents):
386
+ encodings = rearrange(latents, "b d t -> b t d")
387
+ z_q, indices = self.codebook(encodings)
388
+ z_q = z_q.transpose(1, 2)
389
+ return z_q, indices
390
+
391
+ def vq2emb(self, vq, out_proj=True):
392
+ emb = self.codebook.vq2emb(vq)
393
+ emb = emb.transpose(1, 2)
394
+ if out_proj:
395
+ emb = self.out_project(emb)
396
+ return emb
397
+
398
+ def latent2dist(self, latents):
399
+ latents = rearrange(latents, "b d t -> b t d")
400
+ dist, embed_ind, quantize = self.codebook.latent2dist(latents)
401
+ return dist, embed_ind, quantize.transpose(1, 2)
models/codec/amphion_codec/vocos.py ADDED
@@ -0,0 +1,881 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Optional, Tuple
7
+
8
+ import numpy as np
9
+ import scipy
10
+ import torch
11
+ from torch import nn, view_as_real, view_as_complex
12
+ from torch import nn
13
+ from torch.nn.utils import weight_norm, remove_weight_norm
14
+ from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
15
+ import librosa
16
+
17
+
18
+ def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
19
+ """
20
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
21
+
22
+ Args:
23
+ x (Tensor): Input tensor.
24
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
25
+
26
+ Returns:
27
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
28
+ """
29
+ return torch.log(torch.clip(x, min=clip_val))
30
+
31
+
32
+ def symlog(x: torch.Tensor) -> torch.Tensor:
33
+ return torch.sign(x) * torch.log1p(x.abs())
34
+
35
+
36
+ def symexp(x: torch.Tensor) -> torch.Tensor:
37
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
38
+
39
+
40
+ class STFT(nn.Module):
41
+ def __init__(
42
+ self,
43
+ n_fft: int,
44
+ hop_length: int,
45
+ win_length: int,
46
+ center=True,
47
+ ):
48
+ super().__init__()
49
+ self.center = center
50
+ self.n_fft = n_fft
51
+ self.hop_length = hop_length
52
+ self.win_length = win_length
53
+ window = torch.hann_window(win_length)
54
+ self.register_buffer("window", window)
55
+
56
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ # x: (B, T * hop_length)
58
+
59
+ if not self.center:
60
+ pad = self.win_length - self.hop_length
61
+ x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
62
+
63
+ stft_spec = torch.stft(
64
+ x,
65
+ self.n_fft,
66
+ hop_length=self.hop_length,
67
+ win_length=self.win_length,
68
+ window=self.window,
69
+ center=self.center,
70
+ return_complex=False,
71
+ ) # (B, n_fft // 2 + 1, T, 2)
72
+
73
+ rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2)
74
+ imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2)
75
+
76
+ log_mag = torch.log(
77
+ torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
78
+ ) # (B, n_fft // 2 + 1, T)
79
+ phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
80
+
81
+ return log_mag, phase
82
+
83
+
84
+ class ISTFT(nn.Module):
85
+ """
86
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
87
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
88
+ See issue: https://github.com/pytorch/pytorch/issues/62323
89
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
90
+ The NOLA constraint is met as we trim padded samples anyway.
91
+
92
+ Args:
93
+ n_fft (int): Size of Fourier transform.
94
+ hop_length (int): The distance between neighboring sliding window frames.
95
+ win_length (int): The size of window frame and STFT filter.
96
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
97
+ """
98
+
99
+ def __init__(
100
+ self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
101
+ ):
102
+ super().__init__()
103
+ if padding not in ["center", "same"]:
104
+ raise ValueError("Padding must be 'center' or 'same'.")
105
+ self.padding = padding
106
+ self.n_fft = n_fft
107
+ self.hop_length = hop_length
108
+ self.win_length = win_length
109
+ window = torch.hann_window(win_length)
110
+ self.register_buffer("window", window)
111
+
112
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
113
+ """
114
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
115
+
116
+ Args:
117
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
118
+ N is the number of frequency bins, and T is the number of time frames.
119
+
120
+ Returns:
121
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
122
+ """
123
+ if self.padding == "center":
124
+ # Fallback to pytorch native implementation
125
+ return torch.istft(
126
+ spec,
127
+ self.n_fft,
128
+ self.hop_length,
129
+ self.win_length,
130
+ self.window,
131
+ center=True,
132
+ )
133
+ elif self.padding == "same":
134
+ pad = (self.win_length - self.hop_length) // 2
135
+ else:
136
+ raise ValueError("Padding must be 'center' or 'same'.")
137
+
138
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
139
+ B, N, T = spec.shape
140
+
141
+ # Inverse FFT
142
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
143
+ ifft = ifft * self.window[None, :, None]
144
+
145
+ # Overlap and Add
146
+ output_size = (T - 1) * self.hop_length + self.win_length
147
+ y = torch.nn.functional.fold(
148
+ ifft,
149
+ output_size=(1, output_size),
150
+ kernel_size=(1, self.win_length),
151
+ stride=(1, self.hop_length),
152
+ )[:, 0, 0, pad:-pad]
153
+
154
+ # Window envelope
155
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
156
+ window_envelope = torch.nn.functional.fold(
157
+ window_sq,
158
+ output_size=(1, output_size),
159
+ kernel_size=(1, self.win_length),
160
+ stride=(1, self.hop_length),
161
+ ).squeeze()[pad:-pad]
162
+
163
+ # Normalize
164
+ assert (window_envelope > 1e-11).all()
165
+ y = y / window_envelope
166
+
167
+ return y
168
+
169
+
170
+ class MDCT(nn.Module):
171
+ """
172
+ Modified Discrete Cosine Transform (MDCT) module.
173
+
174
+ Args:
175
+ frame_len (int): Length of the MDCT frame.
176
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
177
+ """
178
+
179
+ def __init__(self, frame_len: int, padding: str = "same"):
180
+ super().__init__()
181
+ if padding not in ["center", "same"]:
182
+ raise ValueError("Padding must be 'center' or 'same'.")
183
+ self.padding = padding
184
+ self.frame_len = frame_len
185
+ N = frame_len // 2
186
+ n0 = (N + 1) / 2
187
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
188
+ self.register_buffer("window", window)
189
+
190
+ pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
191
+ post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
192
+ # view_as_real: NCCL Backend does not support ComplexFloat data type
193
+ # https://github.com/pytorch/pytorch/issues/71613
194
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
195
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
196
+
197
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
198
+ """
199
+ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
200
+
201
+ Args:
202
+ audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
203
+ and T is the length of the audio.
204
+
205
+ Returns:
206
+ Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
207
+ and N is the number of frequency bins.
208
+ """
209
+ if self.padding == "center":
210
+ audio = torch.nn.functional.pad(
211
+ audio, (self.frame_len // 2, self.frame_len // 2)
212
+ )
213
+ elif self.padding == "same":
214
+ # hop_length is 1/2 frame_len
215
+ audio = torch.nn.functional.pad(
216
+ audio, (self.frame_len // 4, self.frame_len // 4)
217
+ )
218
+ else:
219
+ raise ValueError("Padding must be 'center' or 'same'.")
220
+
221
+ x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
222
+ N = self.frame_len // 2
223
+ x = x * self.window.expand(x.shape)
224
+ X = torch.fft.fft(
225
+ x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
226
+ )[..., :N]
227
+ res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
228
+ return torch.real(res) * np.sqrt(2)
229
+
230
+
231
+ class IMDCT(nn.Module):
232
+ """
233
+ Inverse Modified Discrete Cosine Transform (IMDCT) module.
234
+
235
+ Args:
236
+ frame_len (int): Length of the MDCT frame.
237
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
238
+ """
239
+
240
+ def __init__(self, frame_len: int, padding: str = "same"):
241
+ super().__init__()
242
+ if padding not in ["center", "same"]:
243
+ raise ValueError("Padding must be 'center' or 'same'.")
244
+ self.padding = padding
245
+ self.frame_len = frame_len
246
+ N = frame_len // 2
247
+ n0 = (N + 1) / 2
248
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
249
+ self.register_buffer("window", window)
250
+
251
+ pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
252
+ post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
253
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
254
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
255
+
256
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
257
+ """
258
+ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
259
+
260
+ Args:
261
+ X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
262
+ L is the number of frames, and N is the number of frequency bins.
263
+
264
+ Returns:
265
+ Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
266
+ """
267
+ B, L, N = X.shape
268
+ Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
269
+ Y[..., :N] = X
270
+ Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
271
+ y = torch.fft.ifft(
272
+ Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
273
+ )
274
+ y = (
275
+ torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
276
+ * np.sqrt(N)
277
+ * np.sqrt(2)
278
+ )
279
+ result = y * self.window.expand(y.shape)
280
+ output_size = (1, (L + 1) * N)
281
+ audio = torch.nn.functional.fold(
282
+ result.transpose(1, 2),
283
+ output_size=output_size,
284
+ kernel_size=(1, self.frame_len),
285
+ stride=(1, self.frame_len // 2),
286
+ )[:, 0, 0, :]
287
+
288
+ if self.padding == "center":
289
+ pad = self.frame_len // 2
290
+ elif self.padding == "same":
291
+ pad = self.frame_len // 4
292
+ else:
293
+ raise ValueError("Padding must be 'center' or 'same'.")
294
+
295
+ audio = audio[:, pad:-pad]
296
+ return audio
297
+
298
+
299
+ class FourierHead(nn.Module):
300
+ """Base class for inverse fourier modules."""
301
+
302
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
303
+ """
304
+ Args:
305
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
306
+ L is the sequence length, and H denotes the model dimension.
307
+
308
+ Returns:
309
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
310
+ """
311
+ raise NotImplementedError("Subclasses must implement the forward method.")
312
+
313
+
314
+ class ISTFTHead(FourierHead):
315
+ """
316
+ ISTFT Head module for predicting STFT complex coefficients.
317
+
318
+ Args:
319
+ dim (int): Hidden dimension of the model.
320
+ n_fft (int): Size of Fourier transform.
321
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
322
+ the resolution of the input features.
323
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
324
+ """
325
+
326
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
327
+ super().__init__()
328
+ out_dim = n_fft + 2
329
+ self.out = torch.nn.Linear(dim, out_dim)
330
+ self.istft = ISTFT(
331
+ n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
332
+ )
333
+
334
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
335
+ """
336
+ Forward pass of the ISTFTHead module.
337
+
338
+ Args:
339
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
340
+ L is the sequence length, and H denotes the model dimension.
341
+
342
+ Returns:
343
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
344
+ """
345
+ x = self.out(x).transpose(1, 2)
346
+ mag, p = x.chunk(2, dim=1)
347
+ mag = torch.exp(mag)
348
+ mag = torch.clip(
349
+ mag, max=1e2
350
+ ) # safeguard to prevent excessively large magnitudes
351
+ # wrapping happens here. These two lines produce real and imaginary value
352
+ x = torch.cos(p)
353
+ y = torch.sin(p)
354
+ # recalculating phase here does not produce anything new
355
+ # only costs time
356
+ # phase = torch.atan2(y, x)
357
+ # S = mag * torch.exp(phase * 1j)
358
+ # better directly produce the complex value
359
+ S = mag * (x + 1j * y)
360
+ audio = self.istft(S)
361
+ return audio
362
+
363
+
364
+ class IMDCTSymExpHead(FourierHead):
365
+ """
366
+ IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
367
+
368
+ Args:
369
+ dim (int): Hidden dimension of the model.
370
+ mdct_frame_len (int): Length of the MDCT frame.
371
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
372
+ sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
373
+ based on perceptual scaling. Defaults to None.
374
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
375
+ """
376
+
377
+ def __init__(
378
+ self,
379
+ dim: int,
380
+ mdct_frame_len: int,
381
+ padding: str = "same",
382
+ sample_rate: Optional[int] = None,
383
+ clip_audio: bool = False,
384
+ ):
385
+ super().__init__()
386
+ out_dim = mdct_frame_len // 2
387
+ self.out = nn.Linear(dim, out_dim)
388
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
389
+ self.clip_audio = clip_audio
390
+
391
+ if sample_rate is not None:
392
+ # optionally init the last layer following mel-scale
393
+ m_max = _hz_to_mel(sample_rate // 2)
394
+ m_pts = torch.linspace(0, m_max, out_dim)
395
+ f_pts = _mel_to_hz(m_pts)
396
+ scale = 1 - (f_pts / f_pts.max())
397
+
398
+ with torch.no_grad():
399
+ self.out.weight.mul_(scale.view(-1, 1))
400
+
401
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
402
+ """
403
+ Forward pass of the IMDCTSymExpHead module.
404
+
405
+ Args:
406
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
407
+ L is the sequence length, and H denotes the model dimension.
408
+
409
+ Returns:
410
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
411
+ """
412
+ x = self.out(x)
413
+ x = symexp(x)
414
+ x = torch.clip(
415
+ x, min=-1e2, max=1e2
416
+ ) # safeguard to prevent excessively large magnitudes
417
+ audio = self.imdct(x)
418
+ if self.clip_audio:
419
+ audio = torch.clip(x, min=-1.0, max=1.0)
420
+
421
+ return audio
422
+
423
+
424
+ class IMDCTCosHead(FourierHead):
425
+ """
426
+ IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
427
+
428
+ Args:
429
+ dim (int): Hidden dimension of the model.
430
+ mdct_frame_len (int): Length of the MDCT frame.
431
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
432
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
433
+ """
434
+
435
+ def __init__(
436
+ self,
437
+ dim: int,
438
+ mdct_frame_len: int,
439
+ padding: str = "same",
440
+ clip_audio: bool = False,
441
+ ):
442
+ super().__init__()
443
+ self.clip_audio = clip_audio
444
+ self.out = nn.Linear(dim, mdct_frame_len)
445
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
446
+
447
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
448
+ """
449
+ Forward pass of the IMDCTCosHead module.
450
+
451
+ Args:
452
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
453
+ L is the sequence length, and H denotes the model dimension.
454
+
455
+ Returns:
456
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
457
+ """
458
+ x = self.out(x)
459
+ m, p = x.chunk(2, dim=2)
460
+ m = torch.exp(m).clip(
461
+ max=1e2
462
+ ) # safeguard to prevent excessively large magnitudes
463
+ audio = self.imdct(m * torch.cos(p))
464
+ if self.clip_audio:
465
+ audio = torch.clip(x, min=-1.0, max=1.0)
466
+ return audio
467
+
468
+
469
+ class ConvNeXtBlock(nn.Module):
470
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
471
+
472
+ Args:
473
+ dim (int): Number of input channels.
474
+ intermediate_dim (int): Dimensionality of the intermediate layer.
475
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
476
+ Defaults to None.
477
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
478
+ None means non-conditional LayerNorm. Defaults to None.
479
+ """
480
+
481
+ def __init__(
482
+ self,
483
+ dim: int,
484
+ intermediate_dim: int,
485
+ layer_scale_init_value: float,
486
+ adanorm_num_embeddings: Optional[int] = None,
487
+ ):
488
+ super().__init__()
489
+ self.dwconv = nn.Conv1d(
490
+ dim, dim, kernel_size=7, padding=3, groups=dim
491
+ ) # depthwise conv
492
+ self.adanorm = adanorm_num_embeddings is not None
493
+ if adanorm_num_embeddings:
494
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
495
+ else:
496
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
497
+ self.pwconv1 = nn.Linear(
498
+ dim, intermediate_dim
499
+ ) # pointwise/1x1 convs, implemented with linear layers
500
+ self.act = nn.GELU()
501
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
502
+ self.gamma = (
503
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
504
+ if layer_scale_init_value > 0
505
+ else None
506
+ )
507
+
508
+ def forward(
509
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
510
+ ) -> torch.Tensor:
511
+ residual = x
512
+ x = self.dwconv(x)
513
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
514
+ if self.adanorm:
515
+ assert cond_embedding_id is not None
516
+ x = self.norm(x, cond_embedding_id)
517
+ else:
518
+ x = self.norm(x)
519
+ x = self.pwconv1(x)
520
+ x = self.act(x)
521
+ x = self.pwconv2(x)
522
+ if self.gamma is not None:
523
+ x = self.gamma * x
524
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
525
+
526
+ x = residual + x
527
+ return x
528
+
529
+
530
+ class AdaLayerNorm(nn.Module):
531
+ """
532
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
533
+
534
+ Args:
535
+ num_embeddings (int): Number of embeddings.
536
+ embedding_dim (int): Dimension of the embeddings.
537
+ """
538
+
539
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
540
+ super().__init__()
541
+ self.eps = eps
542
+ self.dim = embedding_dim
543
+ self.scale = nn.Embedding(
544
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
545
+ )
546
+ self.shift = nn.Embedding(
547
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
548
+ )
549
+ torch.nn.init.ones_(self.scale.weight)
550
+ torch.nn.init.zeros_(self.shift.weight)
551
+
552
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
553
+ scale = self.scale(cond_embedding_id)
554
+ shift = self.shift(cond_embedding_id)
555
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
556
+ x = x * scale + shift
557
+ return x
558
+
559
+
560
+ class ResBlock1(nn.Module):
561
+ """
562
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
563
+ but without upsampling layers.
564
+
565
+ Args:
566
+ dim (int): Number of input channels.
567
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
568
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
569
+ Defaults to (1, 3, 5).
570
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
571
+ Defaults to 0.1.
572
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
573
+ Defaults to None.
574
+ """
575
+
576
+ def __init__(
577
+ self,
578
+ dim: int,
579
+ kernel_size: int = 3,
580
+ dilation: Tuple[int, int, int] = (1, 3, 5),
581
+ lrelu_slope: float = 0.1,
582
+ layer_scale_init_value: Optional[float] = None,
583
+ ):
584
+ super().__init__()
585
+ self.lrelu_slope = lrelu_slope
586
+ self.convs1 = nn.ModuleList(
587
+ [
588
+ weight_norm(
589
+ nn.Conv1d(
590
+ dim,
591
+ dim,
592
+ kernel_size,
593
+ 1,
594
+ dilation=dilation[0],
595
+ padding=self.get_padding(kernel_size, dilation[0]),
596
+ )
597
+ ),
598
+ weight_norm(
599
+ nn.Conv1d(
600
+ dim,
601
+ dim,
602
+ kernel_size,
603
+ 1,
604
+ dilation=dilation[1],
605
+ padding=self.get_padding(kernel_size, dilation[1]),
606
+ )
607
+ ),
608
+ weight_norm(
609
+ nn.Conv1d(
610
+ dim,
611
+ dim,
612
+ kernel_size,
613
+ 1,
614
+ dilation=dilation[2],
615
+ padding=self.get_padding(kernel_size, dilation[2]),
616
+ )
617
+ ),
618
+ ]
619
+ )
620
+
621
+ self.convs2 = nn.ModuleList(
622
+ [
623
+ weight_norm(
624
+ nn.Conv1d(
625
+ dim,
626
+ dim,
627
+ kernel_size,
628
+ 1,
629
+ dilation=1,
630
+ padding=self.get_padding(kernel_size, 1),
631
+ )
632
+ ),
633
+ weight_norm(
634
+ nn.Conv1d(
635
+ dim,
636
+ dim,
637
+ kernel_size,
638
+ 1,
639
+ dilation=1,
640
+ padding=self.get_padding(kernel_size, 1),
641
+ )
642
+ ),
643
+ weight_norm(
644
+ nn.Conv1d(
645
+ dim,
646
+ dim,
647
+ kernel_size,
648
+ 1,
649
+ dilation=1,
650
+ padding=self.get_padding(kernel_size, 1),
651
+ )
652
+ ),
653
+ ]
654
+ )
655
+
656
+ self.gamma = nn.ParameterList(
657
+ [
658
+ (
659
+ nn.Parameter(
660
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
661
+ )
662
+ if layer_scale_init_value is not None
663
+ else None
664
+ ),
665
+ (
666
+ nn.Parameter(
667
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
668
+ )
669
+ if layer_scale_init_value is not None
670
+ else None
671
+ ),
672
+ (
673
+ nn.Parameter(
674
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
675
+ )
676
+ if layer_scale_init_value is not None
677
+ else None
678
+ ),
679
+ ]
680
+ )
681
+
682
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
683
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
684
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
685
+ xt = c1(xt)
686
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
687
+ xt = c2(xt)
688
+ if gamma is not None:
689
+ xt = gamma * xt
690
+ x = xt + x
691
+ return x
692
+
693
+ def remove_weight_norm(self):
694
+ for l in self.convs1:
695
+ remove_weight_norm(l)
696
+ for l in self.convs2:
697
+ remove_weight_norm(l)
698
+
699
+ @staticmethod
700
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
701
+ return int((kernel_size * dilation - dilation) / 2)
702
+
703
+
704
+ class Backbone(nn.Module):
705
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
706
+
707
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
708
+ """
709
+ Args:
710
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
711
+ C denotes output features, and L is the sequence length.
712
+
713
+ Returns:
714
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
715
+ and H denotes the model dimension.
716
+ """
717
+ raise NotImplementedError("Subclasses must implement the forward method.")
718
+
719
+
720
+ class VocosBackbone(Backbone):
721
+ """
722
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
723
+
724
+ Args:
725
+ input_channels (int): Number of input features channels.
726
+ dim (int): Hidden dimension of the model.
727
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
728
+ num_layers (int): Number of ConvNeXtBlock layers.
729
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
730
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
731
+ None means non-conditional model. Defaults to None.
732
+ """
733
+
734
+ def __init__(
735
+ self,
736
+ input_channels: int,
737
+ dim: int,
738
+ intermediate_dim: int,
739
+ num_layers: int,
740
+ layer_scale_init_value: Optional[float] = None,
741
+ adanorm_num_embeddings: Optional[int] = None,
742
+ ):
743
+ super().__init__()
744
+ self.input_channels = input_channels
745
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
746
+ self.adanorm = adanorm_num_embeddings is not None
747
+ if adanorm_num_embeddings:
748
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
749
+ else:
750
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
751
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
752
+ self.convnext = nn.ModuleList(
753
+ [
754
+ ConvNeXtBlock(
755
+ dim=dim,
756
+ intermediate_dim=intermediate_dim,
757
+ layer_scale_init_value=layer_scale_init_value,
758
+ adanorm_num_embeddings=adanorm_num_embeddings,
759
+ )
760
+ for _ in range(num_layers)
761
+ ]
762
+ )
763
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
764
+ self.apply(self._init_weights)
765
+
766
+ def _init_weights(self, m):
767
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
768
+ nn.init.trunc_normal_(m.weight, std=0.02)
769
+ nn.init.constant_(m.bias, 0)
770
+
771
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
772
+ bandwidth_id = kwargs.get("bandwidth_id", None)
773
+ x = self.embed(x)
774
+ if self.adanorm:
775
+ assert bandwidth_id is not None
776
+ x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
777
+ else:
778
+ x = self.norm(x.transpose(1, 2))
779
+ x = x.transpose(1, 2)
780
+ for conv_block in self.convnext:
781
+ x = conv_block(x, cond_embedding_id=bandwidth_id)
782
+ x = self.final_layer_norm(x.transpose(1, 2))
783
+ return x
784
+
785
+
786
+ class VocosResNetBackbone(Backbone):
787
+ """
788
+ Vocos backbone module built with ResBlocks.
789
+
790
+ Args:
791
+ input_channels (int): Number of input features channels.
792
+ dim (int): Hidden dimension of the model.
793
+ num_blocks (int): Number of ResBlock1 blocks.
794
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
795
+ """
796
+
797
+ def __init__(
798
+ self,
799
+ input_channels,
800
+ dim,
801
+ num_blocks,
802
+ layer_scale_init_value=None,
803
+ ):
804
+ super().__init__()
805
+ self.input_channels = input_channels
806
+ self.embed = weight_norm(
807
+ nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
808
+ )
809
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
810
+ self.resnet = nn.Sequential(
811
+ *[
812
+ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
813
+ for _ in range(num_blocks)
814
+ ]
815
+ )
816
+
817
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
818
+ x = self.embed(x)
819
+ x = self.resnet(x)
820
+ x = x.transpose(1, 2)
821
+ return x
822
+
823
+
824
+ class Vocos(nn.Module):
825
+ def __init__(
826
+ self,
827
+ input_channels: int = 256,
828
+ dim: int = 384,
829
+ intermediate_dim: int = 1152,
830
+ num_layers: int = 8,
831
+ n_fft: int = 800,
832
+ hop_size: int = 200,
833
+ padding: str = "same",
834
+ adanorm_num_embeddings=None,
835
+ cfg=None,
836
+ ):
837
+ super().__init__()
838
+
839
+ input_channels = (
840
+ cfg.input_channels
841
+ if cfg is not None and hasattr(cfg, "input_channels")
842
+ else input_channels
843
+ )
844
+ dim = cfg.dim if cfg is not None and hasattr(cfg, "dim") else dim
845
+ intermediate_dim = (
846
+ cfg.intermediate_dim
847
+ if cfg is not None and hasattr(cfg, "intermediate_dim")
848
+ else intermediate_dim
849
+ )
850
+ num_layers = (
851
+ cfg.num_layers
852
+ if cfg is not None and hasattr(cfg, "num_layers")
853
+ else num_layers
854
+ )
855
+ adanorm_num_embeddings = (
856
+ cfg.adanorm_num_embeddings
857
+ if cfg is not None and hasattr(cfg, "adanorm_num_embeddings")
858
+ else adanorm_num_embeddings
859
+ )
860
+ n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
861
+ hop_size = (
862
+ cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
863
+ )
864
+ padding = (
865
+ cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
866
+ )
867
+
868
+ self.backbone = VocosBackbone(
869
+ input_channels=input_channels,
870
+ dim=dim,
871
+ intermediate_dim=intermediate_dim,
872
+ num_layers=num_layers,
873
+ adanorm_num_embeddings=adanorm_num_embeddings,
874
+ )
875
+ self.head = ISTFTHead(dim, n_fft, hop_size, padding)
876
+
877
+ def forward(self, x):
878
+ x = self.backbone(x)
879
+ x = self.head(x)
880
+
881
+ return x[:, None, :]
models/codec/codec_dataset.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Iterable
7
+ import torch
8
+ import numpy as np
9
+ import torch.utils.data
10
+ from torch.nn.utils.rnn import pad_sequence
11
+ from utils.data_utils import *
12
+ from torch.utils.data import ConcatDataset, Dataset
13
+
14
+
15
+ class CodecDataset(torch.utils.data.Dataset):
16
+ def __init__(self, cfg, dataset, is_valid=False):
17
+ """
18
+ Args:
19
+ cfg: config
20
+ dataset: dataset name
21
+ is_valid: whether to use train or valid dataset
22
+ """
23
+ assert isinstance(dataset, str)
24
+
25
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
26
+
27
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
28
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
29
+ self.metadata = self.get_metadata()
30
+
31
+ self.data_root = processed_data_dir
32
+ self.cfg = cfg
33
+
34
+ if cfg.preprocess.use_audio:
35
+ self.utt2audio_path = {}
36
+ for utt_info in self.metadata:
37
+ dataset = utt_info["Dataset"]
38
+ uid = utt_info["Uid"]
39
+ utt = "{}_{}".format(dataset, uid)
40
+
41
+ self.utt2audio_path[utt] = os.path.join(
42
+ cfg.preprocess.processed_dir,
43
+ dataset,
44
+ cfg.preprocess.audio_dir,
45
+ uid + ".npy",
46
+ )
47
+ elif cfg.preprocess.use_label:
48
+ self.utt2label_path = {}
49
+ for utt_info in self.metadata:
50
+ dataset = utt_info["Dataset"]
51
+ uid = utt_info["Uid"]
52
+ utt = "{}_{}".format(dataset, uid)
53
+
54
+ self.utt2label_path[utt] = os.path.join(
55
+ cfg.preprocess.processed_dir,
56
+ dataset,
57
+ cfg.preprocess.label_dir,
58
+ uid + ".npy",
59
+ )
60
+ elif cfg.preprocess.use_one_hot:
61
+ self.utt2one_hot_path = {}
62
+ for utt_info in self.metadata:
63
+ dataset = utt_info["Dataset"]
64
+ uid = utt_info["Uid"]
65
+ utt = "{}_{}".format(dataset, uid)
66
+
67
+ self.utt2one_hot_path[utt] = os.path.join(
68
+ cfg.preprocess.processed_dir,
69
+ dataset,
70
+ cfg.preprocess.one_hot_dir,
71
+ uid + ".npy",
72
+ )
73
+
74
+ if cfg.preprocess.use_mel:
75
+ self.utt2mel_path = {}
76
+ for utt_info in self.metadata:
77
+ dataset = utt_info["Dataset"]
78
+ uid = utt_info["Uid"]
79
+ utt = "{}_{}".format(dataset, uid)
80
+
81
+ self.utt2mel_path[utt] = os.path.join(
82
+ cfg.preprocess.processed_dir,
83
+ dataset,
84
+ cfg.preprocess.mel_dir,
85
+ uid + ".npy",
86
+ )
87
+
88
+ if cfg.preprocess.use_frame_pitch:
89
+ self.utt2frame_pitch_path = {}
90
+ for utt_info in self.metadata:
91
+ dataset = utt_info["Dataset"]
92
+ uid = utt_info["Uid"]
93
+ utt = "{}_{}".format(dataset, uid)
94
+
95
+ self.utt2frame_pitch_path[utt] = os.path.join(
96
+ cfg.preprocess.processed_dir,
97
+ dataset,
98
+ cfg.preprocess.pitch_dir,
99
+ uid + ".npy",
100
+ )
101
+
102
+ if cfg.preprocess.use_uv:
103
+ self.utt2uv_path = {}
104
+ for utt_info in self.metadata:
105
+ dataset = utt_info["Dataset"]
106
+ uid = utt_info["Uid"]
107
+ utt = "{}_{}".format(dataset, uid)
108
+ self.utt2uv_path[utt] = os.path.join(
109
+ cfg.preprocess.processed_dir,
110
+ dataset,
111
+ cfg.preprocess.uv_dir,
112
+ uid + ".npy",
113
+ )
114
+
115
+ if cfg.preprocess.use_amplitude_phase:
116
+ self.utt2logamp_path = {}
117
+ self.utt2pha_path = {}
118
+ self.utt2rea_path = {}
119
+ self.utt2imag_path = {}
120
+ for utt_info in self.metadata:
121
+ dataset = utt_info["Dataset"]
122
+ uid = utt_info["Uid"]
123
+ utt = "{}_{}".format(dataset, uid)
124
+ self.utt2logamp_path[utt] = os.path.join(
125
+ cfg.preprocess.processed_dir,
126
+ dataset,
127
+ cfg.preprocess.log_amplitude_dir,
128
+ uid + ".npy",
129
+ )
130
+ self.utt2pha_path[utt] = os.path.join(
131
+ cfg.preprocess.processed_dir,
132
+ dataset,
133
+ cfg.preprocess.phase_dir,
134
+ uid + ".npy",
135
+ )
136
+ self.utt2rea_path[utt] = os.path.join(
137
+ cfg.preprocess.processed_dir,
138
+ dataset,
139
+ cfg.preprocess.real_dir,
140
+ uid + ".npy",
141
+ )
142
+ self.utt2imag_path[utt] = os.path.join(
143
+ cfg.preprocess.processed_dir,
144
+ dataset,
145
+ cfg.preprocess.imaginary_dir,
146
+ uid + ".npy",
147
+ )
148
+
149
+ def __getitem__(self, index):
150
+ utt_info = self.metadata[index]
151
+
152
+ dataset = utt_info["Dataset"]
153
+ uid = utt_info["Uid"]
154
+ utt = "{}_{}".format(dataset, uid)
155
+
156
+ single_feature = dict()
157
+
158
+ if self.cfg.preprocess.use_mel:
159
+ mel = np.load(self.utt2mel_path[utt])
160
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
161
+
162
+ if "target_len" not in single_feature.keys():
163
+ single_feature["target_len"] = mel.shape[1]
164
+
165
+ single_feature["mel"] = mel
166
+
167
+ if self.cfg.preprocess.use_frame_pitch:
168
+ frame_pitch = np.load(self.utt2frame_pitch_path[utt])
169
+
170
+ if "target_len" not in single_feature.keys():
171
+ single_feature["target_len"] = len(frame_pitch)
172
+
173
+ aligned_frame_pitch = align_length(
174
+ frame_pitch, single_feature["target_len"]
175
+ )
176
+
177
+ single_feature["frame_pitch"] = aligned_frame_pitch
178
+
179
+ if self.cfg.preprocess.use_audio:
180
+ audio = np.load(self.utt2audio_path[utt])
181
+
182
+ single_feature["audio"] = audio
183
+
184
+ return single_feature
185
+
186
+ def get_metadata(self):
187
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
188
+ metadata = json.load(f)
189
+
190
+ return metadata
191
+
192
+ def get_dataset_name(self):
193
+ return self.metadata[0]["Dataset"]
194
+
195
+ def __len__(self):
196
+ return len(self.metadata)
197
+
198
+
199
+ class CodecConcatDataset(ConcatDataset):
200
+ def __init__(self, datasets: Iterable[Dataset], full_audio_inference=False):
201
+ """Concatenate a series of datasets with their random inference audio merged."""
202
+ super().__init__(datasets)
203
+
204
+ self.cfg = self.datasets[0].cfg
205
+
206
+ self.metadata = []
207
+
208
+ # Merge metadata
209
+ for dataset in self.datasets:
210
+ self.metadata += dataset.metadata
211
+
212
+ # Merge random inference features
213
+ if full_audio_inference:
214
+ self.eval_audios = []
215
+ self.eval_dataset_names = []
216
+ if self.cfg.preprocess.use_mel:
217
+ self.eval_mels = []
218
+ if self.cfg.preprocess.use_frame_pitch:
219
+ self.eval_pitchs = []
220
+ for dataset in self.datasets:
221
+ self.eval_audios.append(dataset.eval_audio)
222
+ self.eval_dataset_names.append(dataset.get_dataset_name())
223
+ if self.cfg.preprocess.use_mel:
224
+ self.eval_mels.append(dataset.eval_mel)
225
+ if self.cfg.preprocess.use_frame_pitch:
226
+ self.eval_pitchs.append(dataset.eval_pitch)
227
+
228
+
229
+ class CodecCollator(object):
230
+ """Zero-pads model inputs and targets based on number of frames per step"""
231
+
232
+ def __init__(self, cfg):
233
+ self.cfg = cfg
234
+
235
+ def __call__(self, batch):
236
+ packed_batch_features = dict()
237
+
238
+ # mel: [b, n_mels, frame]
239
+ # frame_pitch: [b, frame]
240
+ # audios: [b, frame * hop_size]
241
+
242
+ for key in batch[0].keys():
243
+ if key == "target_len":
244
+ packed_batch_features["target_len"] = torch.LongTensor(
245
+ [b["target_len"] for b in batch]
246
+ )
247
+ masks = [
248
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
249
+ ]
250
+ packed_batch_features["mask"] = pad_sequence(
251
+ masks, batch_first=True, padding_value=0
252
+ )
253
+ elif key == "mel":
254
+ values = [torch.from_numpy(b[key]).T for b in batch]
255
+ packed_batch_features[key] = pad_sequence(
256
+ values, batch_first=True, padding_value=0
257
+ )
258
+ else:
259
+ values = [torch.from_numpy(b[key]) for b in batch]
260
+ packed_batch_features[key] = pad_sequence(
261
+ values, batch_first=True, padding_value=0
262
+ )
263
+
264
+ return packed_batch_features
models/codec/codec_inference.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import torch
8
+ import json
9
+ import json5
10
+ import time
11
+ import accelerate
12
+ import random
13
+ import numpy as np
14
+ import shutil
15
+
16
+ from pathlib import Path
17
+ from tqdm import tqdm
18
+ from glob import glob
19
+ from accelerate.logging import get_logger
20
+ from torch.utils.data import DataLoader
21
+
22
+ from models.vocoders.vocoder_dataset import (
23
+ VocoderDataset,
24
+ VocoderCollator,
25
+ VocoderConcatDataset,
26
+ )
27
+
28
+ from models.vocoders.gan.generator import bigvgan, hifigan, melgan, nsfhifigan, apnet
29
+ from models.vocoders.flow.waveglow import waveglow
30
+ from models.vocoders.diffusion.diffwave import diffwave
31
+ from models.vocoders.autoregressive.wavenet import wavenet
32
+ from models.vocoders.autoregressive.wavernn import wavernn
33
+
34
+ from models.vocoders.gan import gan_vocoder_inference
35
+ from models.vocoders.diffusion import diffusion_vocoder_inference
36
+
37
+ from utils.io import save_audio
38
+
39
+ _vocoders = {
40
+ "diffwave": diffwave.DiffWave,
41
+ "wavernn": wavernn.WaveRNN,
42
+ "wavenet": wavenet.WaveNet,
43
+ "waveglow": waveglow.WaveGlow,
44
+ "nsfhifigan": nsfhifigan.NSFHiFiGAN,
45
+ "bigvgan": bigvgan.BigVGAN,
46
+ "hifigan": hifigan.HiFiGAN,
47
+ "melgan": melgan.MelGAN,
48
+ "apnet": apnet.APNet,
49
+ }
50
+
51
+ # Forward call for generalized Inferencor
52
+ _vocoder_forward_funcs = {
53
+ # "world": world_inference.synthesis_audios,
54
+ # "wavernn": wavernn_inference.synthesis_audios,
55
+ # "wavenet": wavenet_inference.synthesis_audios,
56
+ "diffwave": diffusion_vocoder_inference.vocoder_inference,
57
+ "nsfhifigan": gan_vocoder_inference.vocoder_inference,
58
+ "bigvgan": gan_vocoder_inference.vocoder_inference,
59
+ "melgan": gan_vocoder_inference.vocoder_inference,
60
+ "hifigan": gan_vocoder_inference.vocoder_inference,
61
+ "apnet": gan_vocoder_inference.vocoder_inference,
62
+ }
63
+
64
+ # APIs for other tasks. e.g. SVC, TTS, TTA...
65
+ _vocoder_infer_funcs = {
66
+ # "world": world_inference.synthesis_audios,
67
+ # "wavernn": wavernn_inference.synthesis_audios,
68
+ # "wavenet": wavenet_inference.synthesis_audios,
69
+ "diffwave": diffusion_vocoder_inference.synthesis_audios,
70
+ "nsfhifigan": gan_vocoder_inference.synthesis_audios,
71
+ "bigvgan": gan_vocoder_inference.synthesis_audios,
72
+ "melgan": gan_vocoder_inference.synthesis_audios,
73
+ "hifigan": gan_vocoder_inference.synthesis_audios,
74
+ "apnet": gan_vocoder_inference.synthesis_audios,
75
+ }
76
+
77
+
78
+ class VocoderInference(object):
79
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
80
+ super().__init__()
81
+
82
+ start = time.monotonic_ns()
83
+ self.args = args
84
+ self.cfg = cfg
85
+ self.infer_type = infer_type
86
+
87
+ # Init accelerator
88
+ self.accelerator = accelerate.Accelerator()
89
+ self.accelerator.wait_for_everyone()
90
+
91
+ # Get logger
92
+ with self.accelerator.main_process_first():
93
+ self.logger = get_logger("inference", log_level=args.log_level)
94
+
95
+ # Log some info
96
+ self.logger.info("=" * 56)
97
+ self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
98
+ self.logger.info("=" * 56)
99
+ self.logger.info("\n")
100
+
101
+ self.vocoder_dir = args.vocoder_dir
102
+ self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
103
+
104
+ os.makedirs(args.output_dir, exist_ok=True)
105
+ if os.path.exists(os.path.join(args.output_dir, "pred")):
106
+ shutil.rmtree(os.path.join(args.output_dir, "pred"))
107
+ if os.path.exists(os.path.join(args.output_dir, "gt")):
108
+ shutil.rmtree(os.path.join(args.output_dir, "gt"))
109
+ os.makedirs(os.path.join(args.output_dir, "pred"), exist_ok=True)
110
+ os.makedirs(os.path.join(args.output_dir, "gt"), exist_ok=True)
111
+
112
+ # Set random seed
113
+ with self.accelerator.main_process_first():
114
+ start = time.monotonic_ns()
115
+ self._set_random_seed(self.cfg.train.random_seed)
116
+ end = time.monotonic_ns()
117
+ self.logger.debug(
118
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
119
+ )
120
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
121
+
122
+ # Setup inference mode
123
+ if self.infer_type == "infer_from_dataset":
124
+ self.cfg.dataset = self.args.infer_datasets
125
+ elif self.infer_type == "infer_from_feature":
126
+ self._build_tmp_dataset_from_feature()
127
+ self.cfg.dataset = ["tmp"]
128
+ elif self.infer_type == "infer_from_audio":
129
+ self._build_tmp_dataset_from_audio()
130
+ self.cfg.dataset = ["tmp"]
131
+
132
+ # Setup data loader
133
+ with self.accelerator.main_process_first():
134
+ self.logger.info("Building dataset...")
135
+ start = time.monotonic_ns()
136
+ self.test_dataloader = self._build_dataloader()
137
+ end = time.monotonic_ns()
138
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
139
+
140
+ # Build model
141
+ with self.accelerator.main_process_first():
142
+ self.logger.info("Building model...")
143
+ start = time.monotonic_ns()
144
+ self.model = self._build_model()
145
+ end = time.monotonic_ns()
146
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
147
+
148
+ # Init with accelerate
149
+ self.logger.info("Initializing accelerate...")
150
+ start = time.monotonic_ns()
151
+ self.accelerator = accelerate.Accelerator()
152
+ (self.model, self.test_dataloader) = self.accelerator.prepare(
153
+ self.model, self.test_dataloader
154
+ )
155
+ end = time.monotonic_ns()
156
+ self.accelerator.wait_for_everyone()
157
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
158
+
159
+ with self.accelerator.main_process_first():
160
+ self.logger.info("Loading checkpoint...")
161
+ start = time.monotonic_ns()
162
+ if os.path.isdir(args.vocoder_dir):
163
+ if os.path.isdir(os.path.join(args.vocoder_dir, "checkpoint")):
164
+ self._load_model(os.path.join(args.vocoder_dir, "checkpoint"))
165
+ else:
166
+ self._load_model(os.path.join(args.vocoder_dir))
167
+ else:
168
+ self._load_model(os.path.join(args.vocoder_dir))
169
+ end = time.monotonic_ns()
170
+ self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
171
+
172
+ self.model.eval()
173
+ self.accelerator.wait_for_everyone()
174
+
175
+ def _build_tmp_dataset_from_feature(self):
176
+ if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
177
+ shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
178
+
179
+ utts = []
180
+ mels = glob(os.path.join(self.args.feature_folder, "mels", "*.npy"))
181
+ for i, mel in enumerate(mels):
182
+ uid = mel.split("/")[-1].split(".")[0]
183
+ utt = {"Dataset": "tmp", "Uid": uid, "index": i}
184
+ utts.append(utt)
185
+
186
+ os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
187
+ with open(
188
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
189
+ ) as f:
190
+ json.dump(utts, f)
191
+
192
+ meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}
193
+
194
+ with open(
195
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
196
+ "w",
197
+ ) as f:
198
+ json.dump(meta_info, f)
199
+
200
+ features = glob(os.path.join(self.args.feature_folder, "*"))
201
+ for feature in features:
202
+ feature_name = feature.split("/")[-1]
203
+ if os.path.isfile(feature):
204
+ continue
205
+ shutil.copytree(
206
+ os.path.join(self.args.feature_folder, feature_name),
207
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", feature_name),
208
+ )
209
+
210
+ def _build_tmp_dataset_from_audio(self):
211
+ if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
212
+ shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
213
+
214
+ utts = []
215
+ audios = glob(os.path.join(self.args.audio_folder, "*"))
216
+ for i, audio in enumerate(audios):
217
+ uid = audio.split("/")[-1].split(".")[0]
218
+ utt = {"Dataset": "tmp", "Uid": uid, "index": i, "Path": audio}
219
+ utts.append(utt)
220
+
221
+ os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
222
+ with open(
223
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
224
+ ) as f:
225
+ json.dump(utts, f)
226
+
227
+ meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}
228
+
229
+ with open(
230
+ os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
231
+ "w",
232
+ ) as f:
233
+ json.dump(meta_info, f)
234
+
235
+ from processors import acoustic_extractor
236
+
237
+ acoustic_extractor.extract_utt_acoustic_features_serial(
238
+ utts, os.path.join(self.cfg.preprocess.processed_dir, "tmp"), self.cfg
239
+ )
240
+
241
+ def _build_test_dataset(self):
242
+ return VocoderDataset, VocoderCollator
243
+
244
+ def _build_model(self):
245
+ model = _vocoders[self.cfg.model.generator](self.cfg)
246
+ return model
247
+
248
+ def _build_dataloader(self):
249
+ """Build dataloader which merges a series of datasets."""
250
+ Dataset, Collator = self._build_test_dataset()
251
+
252
+ datasets_list = []
253
+ for dataset in self.cfg.dataset:
254
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
255
+ datasets_list.append(subdataset)
256
+ test_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=False)
257
+ test_collate = Collator(self.cfg)
258
+ test_batch_size = min(self.cfg.inference.batch_size, len(test_dataset))
259
+ test_dataloader = DataLoader(
260
+ test_dataset,
261
+ collate_fn=test_collate,
262
+ num_workers=1,
263
+ batch_size=test_batch_size,
264
+ shuffle=False,
265
+ )
266
+ self.test_batch_size = test_batch_size
267
+ self.test_dataset = test_dataset
268
+ return test_dataloader
269
+
270
+ def _load_model(self, checkpoint_dir, from_multi_gpu=False):
271
+ """Load model from checkpoint. If a folder is given, it will
272
+ load the latest checkpoint in checkpoint_dir. If a path is given
273
+ it will load the checkpoint specified by checkpoint_path.
274
+ **Only use this method after** ``accelerator.prepare()``.
275
+ """
276
+ if os.path.isdir(checkpoint_dir):
277
+ if "epoch" in checkpoint_dir and "step" in checkpoint_dir:
278
+ checkpoint_path = checkpoint_dir
279
+ else:
280
+ # Load the latest accelerator state dicts
281
+ ls = [
282
+ str(i)
283
+ for i in Path(checkpoint_dir).glob("*")
284
+ if not "audio" in str(i)
285
+ ]
286
+ ls.sort(
287
+ key=lambda x: int(x.split("/")[-1].split("_")[0].split("-")[-1]),
288
+ reverse=True,
289
+ )
290
+ checkpoint_path = ls[0]
291
+ accelerate.load_checkpoint_and_dispatch(
292
+ self.accelerator.unwrap_model(self.model),
293
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
294
+ )
295
+ return str(checkpoint_path)
296
+ else:
297
+ # Load old .pt checkpoints
298
+ if self.cfg.model.generator in [
299
+ "bigvgan",
300
+ "hifigan",
301
+ "melgan",
302
+ "nsfhifigan",
303
+ ]:
304
+ ckpt = torch.load(
305
+ checkpoint_dir,
306
+ map_location=(
307
+ torch.device("cuda")
308
+ if torch.cuda.is_available()
309
+ else torch.device("cpu")
310
+ ),
311
+ )
312
+ if from_multi_gpu:
313
+ pretrained_generator_dict = ckpt["generator_state_dict"]
314
+ generator_dict = self.model.state_dict()
315
+
316
+ new_generator_dict = {
317
+ k.split("module.")[-1]: v
318
+ for k, v in pretrained_generator_dict.items()
319
+ if (
320
+ k.split("module.")[-1] in generator_dict
321
+ and v.shape == generator_dict[k.split("module.")[-1]].shape
322
+ )
323
+ }
324
+
325
+ generator_dict.update(new_generator_dict)
326
+
327
+ self.model.load_state_dict(generator_dict)
328
+ else:
329
+ self.model.load_state_dict(ckpt["generator_state_dict"])
330
+ else:
331
+ self.model.load_state_dict(torch.load(checkpoint_dir)["state_dict"])
332
+ return str(checkpoint_dir)
333
+
334
+ def inference(self):
335
+ """Inference via batches"""
336
+ for i, batch in tqdm(enumerate(self.test_dataloader)):
337
+ if self.cfg.preprocess.use_frame_pitch:
338
+ audio_pred = _vocoder_forward_funcs[self.cfg.model.generator](
339
+ self.cfg,
340
+ self.model,
341
+ batch["mel"].transpose(-1, -2),
342
+ f0s=batch["frame_pitch"].float(),
343
+ device=next(self.model.parameters()).device,
344
+ )
345
+ else:
346
+ audio_pred = _vocoder_forward_funcs[self.cfg.model.generator](
347
+ self.cfg,
348
+ self.model,
349
+ batch["mel"].transpose(-1, -2),
350
+ device=next(self.model.parameters()).device,
351
+ )
352
+ audio_ls = audio_pred.chunk(self.test_batch_size)
353
+ audio_gt_ls = batch["audio"].cpu().chunk(self.test_batch_size)
354
+ length_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
355
+ j = 0
356
+ for it, it_gt, l in zip(audio_ls, audio_gt_ls, length_ls):
357
+ l = l.item()
358
+ it = it.squeeze(0).squeeze(0)[: l * self.cfg.preprocess.hop_size]
359
+ it_gt = it_gt.squeeze(0)[: l * self.cfg.preprocess.hop_size]
360
+ uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
361
+ save_audio(
362
+ os.path.join(self.args.output_dir, "pred", "{}.wav").format(uid),
363
+ it,
364
+ self.cfg.preprocess.sample_rate,
365
+ )
366
+ save_audio(
367
+ os.path.join(self.args.output_dir, "gt", "{}.wav").format(uid),
368
+ it_gt,
369
+ self.cfg.preprocess.sample_rate,
370
+ )
371
+ j += 1
372
+
373
+ if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
374
+ shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
375
+
376
+ def _set_random_seed(self, seed):
377
+ """Set random seed for all possible random modules."""
378
+ random.seed(seed)
379
+ np.random.seed(seed)
380
+ torch.random.manual_seed(seed)
381
+
382
+ def _count_parameters(self, model):
383
+ return sum(p.numel() for p in model.parameters())
384
+
385
+ def _dump_cfg(self, path):
386
+ os.makedirs(os.path.dirname(path), exist_ok=True)
387
+ json5.dump(
388
+ self.cfg,
389
+ open(path, "w"),
390
+ indent=4,
391
+ sort_keys=True,
392
+ ensure_ascii=False,
393
+ quote_keys=True,
394
+ )
395
+
396
+
397
+ def load_nnvocoder(
398
+ cfg,
399
+ vocoder_name,
400
+ weights_file,
401
+ from_multi_gpu=False,
402
+ ):
403
+ """Load the specified vocoder.
404
+ cfg: the vocoder config filer.
405
+ weights_file: a folder or a .pt path.
406
+ from_multi_gpu: automatically remove the "module" string in state dicts if "True".
407
+ """
408
+ print("Loading Vocoder from Weights file: {}".format(weights_file))
409
+
410
+ # Build model
411
+ model = _vocoders[vocoder_name](cfg)
412
+ if not os.path.isdir(weights_file):
413
+ # Load from .pt file
414
+ if vocoder_name in ["bigvgan", "hifigan", "melgan", "nsfhifigan"]:
415
+ ckpt = torch.load(
416
+ weights_file,
417
+ map_location=(
418
+ torch.device("cuda")
419
+ if torch.cuda.is_available()
420
+ else torch.device("cpu")
421
+ ),
422
+ )
423
+ if from_multi_gpu:
424
+ pretrained_generator_dict = ckpt["generator_state_dict"]
425
+ generator_dict = model.state_dict()
426
+
427
+ new_generator_dict = {
428
+ k.split("module.")[-1]: v
429
+ for k, v in pretrained_generator_dict.items()
430
+ if (
431
+ k.split("module.")[-1] in generator_dict
432
+ and v.shape == generator_dict[k.split("module.")[-1]].shape
433
+ )
434
+ }
435
+
436
+ generator_dict.update(new_generator_dict)
437
+
438
+ model.load_state_dict(generator_dict)
439
+ else:
440
+ model.load_state_dict(ckpt["generator_state_dict"])
441
+ else:
442
+ model.load_state_dict(torch.load(weights_file)["state_dict"])
443
+ else:
444
+ # Load from accelerator state dict
445
+ weights_file = os.path.join(weights_file, "checkpoint")
446
+ ls = [str(i) for i in Path(weights_file).glob("*") if not "audio" in str(i)]
447
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
448
+ checkpoint_path = ls[0]
449
+ accelerator = accelerate.Accelerator()
450
+ model = accelerator.prepare(model)
451
+ accelerator.load_state(checkpoint_path)
452
+
453
+ if torch.cuda.is_available():
454
+ model = model.cuda()
455
+
456
+ model = model.eval()
457
+ return model
458
+
459
+
460
+ def tensorize(data, device, n_samples):
461
+ """
462
+ data: a list of numpy array
463
+ """
464
+ assert type(data) == list
465
+ if n_samples:
466
+ data = data[:n_samples]
467
+ data = [torch.as_tensor(x, device=device) for x in data]
468
+ return data
469
+
470
+
471
+ def synthesis(
472
+ cfg,
473
+ vocoder_weight_file,
474
+ n_samples,
475
+ pred,
476
+ f0s=None,
477
+ batch_size=64,
478
+ fast_inference=False,
479
+ ):
480
+ """Synthesis audios from a given vocoder and series of given features.
481
+ cfg: vocoder config.
482
+ vocoder_weight_file: a folder of accelerator state dict or a path to the .pt file.
483
+ pred: a list of numpy arrays. [(seq_len1, acoustic_features_dim), (seq_len2, acoustic_features_dim), ...]
484
+ """
485
+
486
+ vocoder_name = cfg.model.generator
487
+
488
+ print("Synthesis audios using {} vocoder...".format(vocoder_name))
489
+
490
+ ###### TODO: World Vocoder Refactor ######
491
+ # if vocoder_name == "world":
492
+ # world_inference.synthesis_audios(
493
+ # cfg, dataset_name, split, n_samples, pred, save_dir, tag
494
+ # )
495
+ # return
496
+
497
+ # ====== Loading neural vocoder model ======
498
+ vocoder = load_nnvocoder(
499
+ cfg, vocoder_name, weights_file=vocoder_weight_file, from_multi_gpu=True
500
+ )
501
+ device = next(vocoder.parameters()).device
502
+
503
+ # ====== Inference for predicted acoustic features ======
504
+ # pred: (frame_len, n_mels) -> (n_mels, frame_len)
505
+ mels_pred = tensorize([p.T for p in pred], device, n_samples)
506
+ print("For predicted mels, #sample = {}...".format(len(mels_pred)))
507
+ audios_pred = _vocoder_infer_funcs[vocoder_name](
508
+ cfg,
509
+ vocoder,
510
+ mels_pred,
511
+ f0s=f0s,
512
+ batch_size=batch_size,
513
+ fast_inference=fast_inference,
514
+ )
515
+ return audios_pred
models/codec/codec_sampler.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import random
8
+
9
+ from torch.utils.data import ConcatDataset, Dataset
10
+ from torch.utils.data.sampler import (
11
+ BatchSampler,
12
+ RandomSampler,
13
+ Sampler,
14
+ SequentialSampler,
15
+ )
16
+
17
+
18
+ class ScheduledSampler(Sampler):
19
+ """A sampler that samples data from a given concat-dataset.
20
+
21
+ Args:
22
+ concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
23
+ batch_size (int): batch size
24
+ holistic_shuffle (bool): whether to shuffle the whole dataset or not
25
+ logger (logging.Logger): logger to print warning message
26
+
27
+ Usage:
28
+ For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
29
+ >>> list(ScheduledSampler(ConcatDataset([0, 1, 2], [3, 4, 5], [6, 7, 8]])))
30
+ [3, 4, 5, 0, 1, 2, 6, 7, 8]
31
+ """
32
+
33
+ def __init__(
34
+ self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
35
+ ):
36
+ if not isinstance(concat_dataset, ConcatDataset):
37
+ raise ValueError(
38
+ "concat_dataset must be an instance of ConcatDataset, but got {}".format(
39
+ type(concat_dataset)
40
+ )
41
+ )
42
+ if not isinstance(batch_size, int):
43
+ raise ValueError(
44
+ "batch_size must be an integer, but got {}".format(type(batch_size))
45
+ )
46
+ if not isinstance(holistic_shuffle, bool):
47
+ raise ValueError(
48
+ "holistic_shuffle must be a boolean, but got {}".format(
49
+ type(holistic_shuffle)
50
+ )
51
+ )
52
+
53
+ self.concat_dataset = concat_dataset
54
+ self.batch_size = batch_size
55
+ self.holistic_shuffle = holistic_shuffle
56
+
57
+ affected_dataset_name = []
58
+ affected_dataset_len = []
59
+ for dataset in concat_dataset.datasets:
60
+ dataset_len = len(dataset)
61
+ dataset_name = dataset.get_dataset_name()
62
+ if dataset_len < batch_size:
63
+ affected_dataset_name.append(dataset_name)
64
+ affected_dataset_len.append(dataset_len)
65
+
66
+ self.type = type
67
+ for dataset_name, dataset_len in zip(
68
+ affected_dataset_name, affected_dataset_len
69
+ ):
70
+ if not type == "valid":
71
+ logger.warning(
72
+ "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
73
+ type, dataset_name, dataset_len, batch_size
74
+ )
75
+ )
76
+
77
+ def __len__(self):
78
+ # the number of batches with drop last
79
+ num_of_batches = sum(
80
+ [
81
+ math.floor(len(dataset) / self.batch_size)
82
+ for dataset in self.concat_dataset.datasets
83
+ ]
84
+ )
85
+ return num_of_batches * self.batch_size
86
+
87
+ def __iter__(self):
88
+ iters = []
89
+ for dataset in self.concat_dataset.datasets:
90
+ iters.append(
91
+ SequentialSampler(dataset).__iter__()
92
+ if self.holistic_shuffle
93
+ else RandomSampler(dataset).__iter__()
94
+ )
95
+ init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
96
+ output_batches = []
97
+ for dataset_idx in range(len(self.concat_dataset.datasets)):
98
+ cur_batch = []
99
+ for idx in iters[dataset_idx]:
100
+ cur_batch.append(idx + init_indices[dataset_idx])
101
+ if len(cur_batch) == self.batch_size:
102
+ output_batches.append(cur_batch)
103
+ cur_batch = []
104
+ if self.type == "valid" and len(cur_batch) > 0:
105
+ output_batches.append(cur_batch)
106
+ cur_batch = []
107
+ # force drop last in training
108
+ random.shuffle(output_batches)
109
+ output_indices = [item for sublist in output_batches for item in sublist]
110
+ return iter(output_indices)
111
+
112
+
113
+ def build_samplers(concat_dataset: Dataset, cfg, logger, type):
114
+ sampler = ScheduledSampler(
115
+ concat_dataset,
116
+ cfg.train.batch_size,
117
+ cfg.train.sampler.holistic_shuffle,
118
+ logger,
119
+ type,
120
+ )
121
+ batch_sampler = BatchSampler(
122
+ sampler,
123
+ cfg.train.batch_size,
124
+ cfg.train.sampler.drop_last if not type == "valid" else False,
125
+ )
126
+ return sampler, batch_sampler
models/codec/codec_trainer.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import random
8
+ from pathlib import Path
9
+ import re
10
+
11
+ import accelerate
12
+ import json5
13
+ import numpy as np
14
+ import torch
15
+ from accelerate.utils import ProjectConfiguration
16
+ from torch.utils.data import DataLoader
17
+ from tqdm import tqdm
18
+
19
+ from models.codec.codec_sampler import build_samplers
20
+
21
+
22
+ class CodecTrainer:
23
+ def __init__(self):
24
+ super().__init__()
25
+
26
+ def _init_accelerator(self):
27
+ """Initialize the accelerator components."""
28
+ self.exp_dir = os.path.join(
29
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
30
+ )
31
+ project_config = ProjectConfiguration(
32
+ project_dir=self.exp_dir, logging_dir=os.path.join(self.exp_dir, "log")
33
+ )
34
+ self.accelerator = accelerate.Accelerator(
35
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
36
+ log_with=self.cfg.train.tracker,
37
+ project_config=project_config,
38
+ )
39
+ if self.accelerator.is_main_process:
40
+ os.makedirs(project_config.project_dir, exist_ok=True)
41
+ os.makedirs(project_config.logging_dir, exist_ok=True)
42
+ with self.accelerator.main_process_first():
43
+ self.accelerator.init_trackers(self.args.exp_name)
44
+
45
+ def _build_dataset(self):
46
+ pass
47
+
48
+ def _build_criterion(self):
49
+ pass
50
+
51
+ def _build_model(self):
52
+ pass
53
+
54
+ def _build_dataloader(self):
55
+ """Build dataloader which merges a series of datasets."""
56
+ # Build dataset instance for each dataset and combine them by ConcatDataset
57
+ Dataset, Collator = self._build_dataset()
58
+
59
+ # Build train set
60
+ train_dataset = Dataset(self.cfg, self.cfg.dataset, is_valid=False)
61
+ train_collate = Collator(self.cfg)
62
+ sampler = torch.utils.data.distributed.DistributedSampler(
63
+ train_dataset,
64
+ num_replicas=self.accelerator.num_processes,
65
+ rank=self.accelerator.local_process_index,
66
+ shuffle=True,
67
+ seed=self.cfg.train.random_seed,
68
+ )
69
+ train_loader = DataLoader(
70
+ train_dataset,
71
+ batch_size=self.cfg.train.batch_size,
72
+ collate_fn=train_collate,
73
+ sampler=sampler,
74
+ num_workers=self.cfg.train.dataloader.num_worker,
75
+ pin_memory=self.cfg.train.dataloader.pin_memory,
76
+ )
77
+ return train_loader, None
78
+
79
+ def _build_optimizer(self):
80
+ pass
81
+
82
+ def _build_scheduler(self):
83
+ pass
84
+
85
+ def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
86
+ """Load model from checkpoint. If a folder is given, it will
87
+ load the latest checkpoint in checkpoint_dir. If a path is given
88
+ it will load the checkpoint specified by checkpoint_path.
89
+ **Only use this method after** ``accelerator.prepare()``.
90
+ """
91
+ if checkpoint_path is None:
92
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
93
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
94
+ checkpoint_path = ls[0]
95
+ if resume_type == "resume":
96
+ self.accelerator.load_state(checkpoint_path)
97
+ elif resume_type == "finetune":
98
+ accelerate.load_checkpoint_and_dispatch(
99
+ self.accelerator.unwrap_model(self.model),
100
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
101
+ )
102
+ self.logger.info("Load model weights for finetune SUCCESS!")
103
+ else:
104
+ raise ValueError("Unsupported resume type: {}".format(resume_type))
105
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
106
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
107
+ return checkpoint_path
108
+
109
+ def train_loop(self):
110
+ pass
111
+
112
+ def _train_epoch(self):
113
+ pass
114
+
115
+ def _valid_epoch(self):
116
+ pass
117
+
118
+ def _train_step(self):
119
+ pass
120
+
121
+ def _valid_step(self):
122
+ pass
123
+
124
+ def _inference(self):
125
+ pass
126
+
127
+ def _set_random_seed(self, seed):
128
+ """Set random seed for all possible random modules."""
129
+ random.seed(seed)
130
+ np.random.seed(seed)
131
+ torch.random.manual_seed(seed)
132
+
133
+ def _check_nan(self, loss):
134
+ if torch.any(torch.isnan(loss)):
135
+ self.logger.fatal("Fatal Error: NaN!")
136
+ self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
137
+
138
+ def _check_basic_configs(self):
139
+ if self.cfg.train.gradient_accumulation_step <= 0:
140
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
141
+ self.logger.error(
142
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
143
+ )
144
+ self.accelerator.end_training()
145
+ raise ValueError(
146
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
147
+ )
148
+
149
+ def _count_parameters(self):
150
+ pass
151
+
152
+ def _dump_cfg(self, path):
153
+ os.makedirs(os.path.dirname(path), exist_ok=True)
154
+ json5.dump(
155
+ self.cfg,
156
+ open(path, "w"),
157
+ indent=4,
158
+ sort_keys=True,
159
+ ensure_ascii=False,
160
+ quote_keys=True,
161
+ )
162
+
163
+ def _is_valid_pattern(self, directory_name):
164
+ directory_name = str(directory_name)
165
+ pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}"
166
+ return re.match(pattern, directory_name) is not None
models/codec/facodec/__init__.py ADDED
File without changes
models/codec/facodec/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ from .filter import *
4
+ from .resample import *
5
+ from .act import *
models/codec/facodec/alias_free_torch/act.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from .resample import UpSample1d, DownSample1d
5
+
6
+
7
+ class Activation1d(nn.Module):
8
+ def __init__(
9
+ self,
10
+ activation,
11
+ up_ratio: int = 2,
12
+ down_ratio: int = 2,
13
+ up_kernel_size: int = 12,
14
+ down_kernel_size: int = 12,
15
+ ):
16
+ super().__init__()
17
+ self.up_ratio = up_ratio
18
+ self.down_ratio = down_ratio
19
+ self.act = activation
20
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
21
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
22
+
23
+ # x: [B,C,T]
24
+ def forward(self, x):
25
+ x = self.upsample(x)
26
+ x = self.act(x)
27
+ x = self.downsample(x)
28
+
29
+ return x
models/codec/facodec/alias_free_torch/filter.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+
8
+ if "sinc" in dir(torch):
9
+ sinc = torch.sinc
10
+ else:
11
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
12
+ # https://adefossez.github.io/julius/julius/core.html
13
+ def sinc(x: torch.Tensor):
14
+ """
15
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
16
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
17
+ """
18
+ return torch.where(
19
+ x == 0,
20
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
21
+ torch.sin(math.pi * x) / math.pi / x,
22
+ )
23
+
24
+
25
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26
+ # https://adefossez.github.io/julius/julius/lowpass.html
27
+ def kaiser_sinc_filter1d(
28
+ cutoff, half_width, kernel_size
29
+ ): # return filter [1,1,kernel_size]
30
+ even = kernel_size % 2 == 0
31
+ half_size = kernel_size // 2
32
+
33
+ # For kaiser window
34
+ delta_f = 4 * half_width
35
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
36
+ if A > 50.0:
37
+ beta = 0.1102 * (A - 8.7)
38
+ elif A >= 21.0:
39
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
40
+ else:
41
+ beta = 0.0
42
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
43
+
44
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
45
+ if even:
46
+ time = torch.arange(-half_size, half_size) + 0.5
47
+ else:
48
+ time = torch.arange(kernel_size) - half_size
49
+ if cutoff == 0:
50
+ filter_ = torch.zeros_like(time)
51
+ else:
52
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
53
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
54
+ # of the constant component in the input signal.
55
+ filter_ /= filter_.sum()
56
+ filter = filter_.view(1, 1, kernel_size)
57
+
58
+ return filter
59
+
60
+
61
+ class LowPassFilter1d(nn.Module):
62
+ def __init__(
63
+ self,
64
+ cutoff=0.5,
65
+ half_width=0.6,
66
+ stride: int = 1,
67
+ padding: bool = True,
68
+ padding_mode: str = "replicate",
69
+ kernel_size: int = 12,
70
+ ):
71
+ # kernel_size should be even number for stylegan3 setup,
72
+ # in this implementation, odd number is also possible.
73
+ super().__init__()
74
+ if cutoff < -0.0:
75
+ raise ValueError("Minimum cutoff must be larger than zero.")
76
+ if cutoff > 0.5:
77
+ raise ValueError("A cutoff above 0.5 does not make sense.")
78
+ self.kernel_size = kernel_size
79
+ self.even = kernel_size % 2 == 0
80
+ self.pad_left = kernel_size // 2 - int(self.even)
81
+ self.pad_right = kernel_size // 2
82
+ self.stride = stride
83
+ self.padding = padding
84
+ self.padding_mode = padding_mode
85
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
86
+ self.register_buffer("filter", filter)
87
+
88
+ # input [B, C, T]
89
+ def forward(self, x):
90
+ _, C, _ = x.shape
91
+
92
+ if self.padding:
93
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
94
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
95
+
96
+ return out
models/codec/facodec/alias_free_torch/resample.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from .filter import LowPassFilter1d
6
+ from .filter import kaiser_sinc_filter1d
7
+
8
+
9
+ class UpSample1d(nn.Module):
10
+ def __init__(self, ratio=2, kernel_size=None):
11
+ super().__init__()
12
+ self.ratio = ratio
13
+ self.kernel_size = (
14
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
+ )
16
+ self.stride = ratio
17
+ self.pad = self.kernel_size // ratio - 1
18
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
19
+ self.pad_right = (
20
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
21
+ )
22
+ filter = kaiser_sinc_filter1d(
23
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
24
+ )
25
+ self.register_buffer("filter", filter)
26
+
27
+ # x: [B, C, T]
28
+ def forward(self, x):
29
+ _, C, _ = x.shape
30
+
31
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
32
+ x = self.ratio * F.conv_transpose1d(
33
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
34
+ )
35
+ x = x[..., self.pad_left : -self.pad_right]
36
+
37
+ return x
38
+
39
+
40
+ class DownSample1d(nn.Module):
41
+ def __init__(self, ratio=2, kernel_size=None):
42
+ super().__init__()
43
+ self.ratio = ratio
44
+ self.kernel_size = (
45
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
46
+ )
47
+ self.lowpass = LowPassFilter1d(
48
+ cutoff=0.5 / ratio,
49
+ half_width=0.6 / ratio,
50
+ stride=ratio,
51
+ kernel_size=self.kernel_size,
52
+ )
53
+
54
+ def forward(self, x):
55
+ xx = self.lowpass(x)
56
+
57
+ return xx
models/codec/facodec/facodec_dataset.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import random
8
+
9
+ import numpy as np
10
+
11
+ import torchaudio
12
+ import librosa
13
+ from torch.nn import functional as F
14
+
15
+ from torch.nn.utils.rnn import pad_sequence
16
+ from utils.data_utils import *
17
+ from models.codec.codec_dataset import CodecDataset
18
+
19
+
20
+ class FAcodecDataset(torch.utils.data.Dataset):
21
+ def __init__(self, cfg, dataset, is_valid=False):
22
+ """
23
+ Args:
24
+ cfg: config
25
+ dataset: dataset name
26
+ is_valid: whether to use train or valid dataset
27
+ """
28
+ self.data_root_dir = cfg.dataset
29
+ self.data_list = []
30
+ # walk through the dataset directory recursively, save all files ends with .wav/.mp3/.opus/.flac/.m4a
31
+ for root, _, files in os.walk(self.data_root_dir):
32
+ for file in files:
33
+ if file.endswith((".wav", ".mp3", ".opus", ".flac", ".m4a")):
34
+ self.data_list.append(os.path.join(root, file))
35
+ self.sr = cfg.preprocess_params.sr
36
+ self.duration_range = cfg.preprocess_params.duration_range
37
+ self.to_mel = torchaudio.transforms.MelSpectrogram(
38
+ n_mels=cfg.preprocess_params.spect_params.n_mels,
39
+ n_fft=cfg.preprocess_params.spect_params.n_fft,
40
+ win_length=cfg.preprocess_params.spect_params.win_length,
41
+ hop_length=cfg.preprocess_params.spect_params.hop_length,
42
+ )
43
+ self.mean, self.std = -4, 4
44
+
45
+ def preprocess(self, wave):
46
+ wave_tensor = (
47
+ torch.from_numpy(wave).float() if isinstance(wave, np.ndarray) else wave
48
+ )
49
+ mel_tensor = self.to_mel(wave_tensor)
50
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - self.mean) / self.std
51
+ return mel_tensor
52
+
53
+ def __len__(self):
54
+ # return len(self.data_list)
55
+ return len(self.data_list) # return a fixed number for testing
56
+
57
+ def __getitem__(self, index):
58
+ wave, _ = librosa.load(self.data_list[index], sr=self.sr)
59
+ wave = np.random.randn(self.sr * random.randint(*self.duration_range))
60
+ wave = wave / np.max(np.abs(wave))
61
+ mel = self.preprocess(wave).squeeze(0)
62
+ wave = torch.from_numpy(wave).float()
63
+ return wave, mel
64
+
65
+
66
+ class FAcodecCollator(object):
67
+ """Zero-pads model inputs and targets based on number of frames per step"""
68
+
69
+ def __init__(self, cfg):
70
+ self.cfg = cfg
71
+
72
+ def __call__(self, batch):
73
+ # batch[0] = wave, mel, text, f0, speakerid
74
+ batch_size = len(batch)
75
+
76
+ # sort by mel length
77
+ lengths = [b[1].shape[1] for b in batch]
78
+ batch_indexes = np.argsort(lengths)[::-1]
79
+ batch = [batch[bid] for bid in batch_indexes]
80
+
81
+ nmels = batch[0][1].size(0)
82
+ max_mel_length = max([b[1].shape[1] for b in batch])
83
+ max_wave_length = max([b[0].size(0) for b in batch])
84
+
85
+ mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10
86
+ waves = torch.zeros((batch_size, max_wave_length)).float()
87
+
88
+ mel_lengths = torch.zeros(batch_size).long()
89
+ wave_lengths = torch.zeros(batch_size).long()
90
+
91
+ for bid, (wave, mel) in enumerate(batch):
92
+ mel_size = mel.size(1)
93
+ mels[bid, :, :mel_size] = mel
94
+ waves[bid, : wave.size(0)] = wave
95
+ mel_lengths[bid] = mel_size
96
+ wave_lengths[bid] = wave.size(0)
97
+
98
+ return waves, mels, wave_lengths, mel_lengths
models/codec/facodec/facodec_inference.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import shutil
7
+ import warnings
8
+ import argparse
9
+ import torch
10
+ import os
11
+ import yaml
12
+
13
+ warnings.simplefilter("ignore")
14
+
15
+ from .modules.commons import *
16
+ import time
17
+
18
+ import torchaudio
19
+ import librosa
20
+ from collections import OrderedDict
21
+
22
+
23
+ class FAcodecInference(object):
24
+ def __init__(self, args=None, cfg=None):
25
+ self.args = args
26
+ self.cfg = cfg
27
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ self.model = self._build_model()
29
+ self._load_checkpoint()
30
+
31
+ def _build_model(self):
32
+ model = build_model(self.cfg.model_params)
33
+ _ = [model[key].to(self.device) for key in model]
34
+ return model
35
+
36
+ def _load_checkpoint(self):
37
+ sd = torch.load(self.args.checkpoint_path, map_location="cpu")
38
+ sd = sd["net"] if "net" in sd else sd
39
+ new_params = dict()
40
+ for key, state_dict in sd.items():
41
+ new_state_dict = OrderedDict()
42
+ for k, v in state_dict.items():
43
+ if k.startswith("module."):
44
+ k = k[7:]
45
+ new_state_dict[k] = v
46
+ new_params[key] = new_state_dict
47
+ for key in new_params:
48
+ if key in self.model:
49
+ self.model[key].load_state_dict(new_params[key])
50
+ _ = [self.model[key].eval() for key in self.model]
51
+
52
+ @torch.no_grad()
53
+ def inference(self, source, output_dir):
54
+ source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
55
+ source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)
56
+
57
+ z = self.model.encoder(source_audio[None, ...].to(self.device).float())
58
+ (
59
+ z,
60
+ quantized,
61
+ commitment_loss,
62
+ codebook_loss,
63
+ timbre,
64
+ codes,
65
+ ) = self.model.quantizer(
66
+ z,
67
+ source_audio[None, ...].to(self.device).float(),
68
+ n_c=self.cfg.model_params.n_c_codebooks,
69
+ return_codes=True,
70
+ )
71
+
72
+ full_pred_wave = self.model.decoder(z)
73
+
74
+ os.makedirs(output_dir, exist_ok=True)
75
+ source_name = source.split("/")[-1].split(".")[0]
76
+ torchaudio.save(
77
+ f"{output_dir}/reconstructed_{source_name}.wav",
78
+ full_pred_wave[0].cpu(),
79
+ self.cfg.preprocess_params.sr,
80
+ )
81
+
82
+ print(
83
+ "Reconstructed audio saved as: ",
84
+ f"{output_dir}/reconstructed_{source_name}.wav",
85
+ )
86
+
87
+ return quantized, codes
88
+
89
+ @torch.no_grad()
90
+ def voice_conversion(self, source, reference, output_dir):
91
+ source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
92
+ source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)
93
+
94
+ reference_audio = librosa.load(reference, sr=self.cfg.preprocess_params.sr)[0]
95
+ reference_audio = (
96
+ torch.tensor(reference_audio).unsqueeze(0).float().to(self.device)
97
+ )
98
+
99
+ z = self.model.encoder(source_audio[None, ...].to(self.device).float())
100
+ z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
101
+ z,
102
+ source_audio[None, ...].to(self.device).float(),
103
+ n_c=self.cfg.model_params.n_c_codebooks,
104
+ )
105
+
106
+ z_ref = self.model.encoder(reference_audio[None, ...].to(self.device).float())
107
+ (
108
+ z_ref,
109
+ quantized_ref,
110
+ commitment_loss_ref,
111
+ codebook_loss_ref,
112
+ timbre_ref,
113
+ ) = self.model.quantizer(
114
+ z_ref,
115
+ reference_audio[None, ...].to(self.device).float(),
116
+ n_c=self.cfg.model_params.n_c_codebooks,
117
+ )
118
+
119
+ z_conv = self.model.quantizer.voice_conversion(
120
+ quantized[0] + quantized[1],
121
+ reference_audio[None, ...].to(self.device).float(),
122
+ )
123
+ full_pred_wave = self.model.decoder(z_conv)
124
+
125
+ os.makedirs(output_dir, exist_ok=True)
126
+ source_name = source.split("/")[-1].split(".")[0]
127
+ reference_name = reference.split("/")[-1].split(".")[0]
128
+ torchaudio.save(
129
+ f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
130
+ full_pred_wave[0].cpu(),
131
+ self.cfg.preprocess_params.sr,
132
+ )
133
+
134
+ print(
135
+ "Voice conversion results saved as: ",
136
+ f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
137
+ )
models/codec/facodec/facodec_trainer.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import time
8
+ import random
9
+ from pathlib import Path
10
+ import re
11
+ import glob
12
+
13
+ import accelerate
14
+ import json
15
+ import numpy as np
16
+ import torch
17
+ from accelerate.utils import ProjectConfiguration
18
+ from torch.utils.data import DataLoader
19
+ from tqdm import tqdm
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torchaudio
24
+
25
+ from accelerate.logging import get_logger
26
+
27
+ from models.codec.facodec.facodec_dataset import FAcodecDataset, FAcodecCollator
28
+ from models.codec.codec_sampler import build_samplers
29
+ from models.codec.codec_trainer import CodecTrainer
30
+
31
+ from modules.dac.nn.loss import (
32
+ MultiScaleSTFTLoss,
33
+ MelSpectrogramLoss,
34
+ GANLoss,
35
+ L1Loss,
36
+ FocalLoss,
37
+ )
38
+ from audiotools import AudioSignal
39
+
40
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
41
+
42
+ try:
43
+ import nemo.collections.asr as nemo_asr
44
+ except ImportError:
45
+ print(
46
+ "Unable to import nemo_asr, titanet outputs will be set to random values, you may only run debugging mode. DO NOT USE THIS FOR TRAINING"
47
+ )
48
+ nemo_asr = None
49
+
50
+ from models.codec.facodec.modules.commons import (
51
+ build_model,
52
+ load_checkpoint,
53
+ load_F0_models,
54
+ log_norm,
55
+ )
56
+ from models.codec.facodec.optimizer import build_optimizer
57
+
58
+
59
+ class FAcodecTrainer(CodecTrainer):
60
+ def __init__(self, args, cfg):
61
+ super().__init__()
62
+
63
+ self.args = args
64
+ self.cfg = cfg
65
+
66
+ cfg.exp_name = args.exp_name
67
+
68
+ # Init accelerator
69
+ self._init_accelerator()
70
+ self.accelerator.wait_for_everyone()
71
+
72
+ # Init logger
73
+ with self.accelerator.main_process_first():
74
+ self.logger = get_logger(args.exp_name, log_level=args.log_level)
75
+
76
+ self.logger.info("=" * 56)
77
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
78
+ self.logger.info("=" * 56)
79
+ self.logger.info("\n")
80
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
81
+ self.logger.info(f"Experiment name: {args.exp_name}")
82
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
83
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
84
+ if self.accelerator.is_main_process:
85
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
86
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
87
+
88
+ # Init training status
89
+ self.batch_count: int = 0
90
+ self.step: int = 0
91
+ self.epoch: int = 0
92
+
93
+ self.max_epoch = (
94
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
95
+ )
96
+ self.logger.info(
97
+ "Max epoch: {}".format(
98
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
99
+ )
100
+ )
101
+
102
+ # Check potential erorrs
103
+ if self.accelerator.is_main_process:
104
+ self._check_basic_configs()
105
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
106
+ self.checkpoints_path = [
107
+ [] for _ in range(len(self.save_checkpoint_stride))
108
+ ]
109
+ self.run_eval = self.cfg.train.run_eval
110
+
111
+ # Set random seed
112
+ with self.accelerator.main_process_first():
113
+ start = time.monotonic_ns()
114
+ self._set_random_seed(self.cfg.train.random_seed)
115
+ end = time.monotonic_ns()
116
+ self.logger.debug(
117
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
118
+ )
119
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
120
+
121
+ # Build dataloader
122
+ with self.accelerator.main_process_first():
123
+ self.logger.info("Building dataset...")
124
+ start = time.monotonic_ns()
125
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
126
+ end = time.monotonic_ns()
127
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
128
+
129
+ # Build model
130
+ with self.accelerator.main_process_first():
131
+ self.logger.info("Building model...")
132
+ start = time.monotonic_ns()
133
+ self.model = self._build_model()
134
+ end = time.monotonic_ns()
135
+ for _, model in self.model.items():
136
+ self.logger.debug(model)
137
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
138
+ self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M")
139
+
140
+ # Build optimizers and schedulers
141
+ with self.accelerator.main_process_first():
142
+ self.logger.info("Building optimizer and scheduler...")
143
+ start = time.monotonic_ns()
144
+ self.optimizer = self._build_optimizer()
145
+ end = time.monotonic_ns()
146
+ self.logger.info(
147
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
148
+ )
149
+
150
+ # Build helper models
151
+ with self.accelerator.main_process_first():
152
+ self.logger.info("Building helper models...")
153
+ start = time.monotonic_ns()
154
+ self._built_helper_model()
155
+ end = time.monotonic_ns()
156
+ self.logger.info(
157
+ f"Building helper models done in {(end - start) / 1e6:.2f}ms"
158
+ )
159
+
160
+ # Accelerator preparing
161
+ self.logger.info("Initializing accelerate...")
162
+ start = time.monotonic_ns()
163
+ for k in self.model:
164
+ self.model[k] = self.accelerator.prepare(self.model[k])
165
+ for k, v in self.optimizer.optimizers.items():
166
+ self.optimizer.optimizers[k] = self.accelerator.prepare(
167
+ self.optimizer.optimizers[k]
168
+ )
169
+ self.optimizer.schedulers[k] = self.accelerator.prepare(
170
+ self.optimizer.schedulers[k]
171
+ )
172
+ end = time.monotonic_ns()
173
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
174
+
175
+ # Build criterions
176
+ with self.accelerator.main_process_first():
177
+ self.logger.info("Building criterion...")
178
+ start = time.monotonic_ns()
179
+ self.criterions = self._build_criterion()
180
+ end = time.monotonic_ns()
181
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
182
+
183
+ # Resume checkpoints
184
+ with self.accelerator.main_process_first():
185
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
186
+ if args.resume_type:
187
+ self.logger.info("Resuming from checkpoint...")
188
+ start = time.monotonic_ns()
189
+ ckpt_path = Path(args.checkpoint)
190
+ if self._is_valid_pattern(ckpt_path.parts[-1]):
191
+ ckpt_path = self._load_model(args.checkpoint, args.resume_type)
192
+ else:
193
+ ckpt_path = self._load_model(
194
+ args.checkpoint, resume_type=args.resume_type
195
+ )
196
+ end = time.monotonic_ns()
197
+ self.logger.info(
198
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
199
+ )
200
+ self.checkpoints_path = json.load(
201
+ open(os.path.join(ckpt_path, "ckpts.json"), "r")
202
+ )
203
+
204
+ if self.accelerator.is_main_process:
205
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
206
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
207
+
208
+ # Save config
209
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
210
+
211
+ def _build_dataset(self):
212
+ return FAcodecDataset, FAcodecCollator
213
+
214
+ def _build_criterion(self):
215
+ criterions = dict()
216
+ stft_criterion = MultiScaleSTFTLoss()
217
+ mel_criterion = MelSpectrogramLoss(
218
+ n_mels=[5, 10, 20, 40, 80, 160, 320],
219
+ window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
220
+ mel_fmin=[0, 0, 0, 0, 0, 0, 0],
221
+ mel_fmax=[None, None, None, None, None, None, None],
222
+ pow=1.0,
223
+ mag_weight=0.0,
224
+ clamp_eps=1e-5,
225
+ )
226
+ content_criterion = FocalLoss(gamma=2)
227
+ l1_criterion = L1Loss()
228
+ criterions["stft"] = stft_criterion
229
+ criterions["mel"] = mel_criterion
230
+ criterions["l1"] = l1_criterion
231
+ criterions["content"] = content_criterion
232
+
233
+ return criterions
234
+
235
+ def _build_model(self):
236
+ model = build_model(self.cfg.model_params)
237
+ _ = [model[key].to(self.accelerator.device) for key in model]
238
+ return model
239
+
240
+ def _built_helper_model(self):
241
+ device = self.accelerator.device
242
+ self.pitch_extractor = load_F0_models(self.cfg.F0_path).to(device)
243
+
244
+ # load model and processor
245
+ self.w2v_processor = Wav2Vec2Processor.from_pretrained(
246
+ "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
247
+ )
248
+ self.w2v_model = Wav2Vec2ForCTC.from_pretrained(
249
+ "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
250
+ ).to(device)
251
+ self.w2v_model.eval()
252
+
253
+ if nemo_asr is None:
254
+ self.speaker_model = None
255
+ else:
256
+ self.speaker_model = (
257
+ nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
258
+ "nvidia/speakerverification_en_titanet_large"
259
+ )
260
+ )
261
+ self.speaker_model = self.speaker_model.to(device)
262
+ self.speaker_model.eval()
263
+
264
+ def _build_optimizer(self):
265
+ scheduler_params = {
266
+ "warmup_steps": self.cfg.loss_params.warmup_steps,
267
+ "base_lr": self.cfg.loss_params.base_lr,
268
+ }
269
+ optimizer = build_optimizer(
270
+ {key: self.model[key] for key in self.model},
271
+ scheduler_params_dict={key: scheduler_params.copy() for key in self.model},
272
+ lr=float(scheduler_params["base_lr"]),
273
+ )
274
+
275
+ return optimizer
276
+
277
+ def train_loop(self):
278
+ """Training process"""
279
+ self.accelerator.wait_for_everyone()
280
+
281
+ # Dump config
282
+ if self.accelerator.is_main_process:
283
+ self._dump_cfg(self.config_save_path)
284
+ _ = [self.model[key].train() for key in self.model]
285
+ self.optimizer.zero_grad()
286
+
287
+ # Sync and start training
288
+ self.accelerator.wait_for_everyone()
289
+ while self.epoch < self.max_epoch:
290
+ self.logger.info("\n")
291
+ self.logger.info("-" * 32)
292
+ self.logger.info("Epoch {}: ".format(self.epoch))
293
+
294
+ # Train and Validate
295
+ train_total_loss, train_losses = self._train_epoch()
296
+ for key, loss in train_losses.items():
297
+ self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
298
+ self.accelerator.log(
299
+ {"Epoch/Train {} Loss".format(key): loss},
300
+ step=self.epoch,
301
+ )
302
+ self.accelerator.log(
303
+ {
304
+ "Epoch/Train Total Loss": train_total_loss,
305
+ },
306
+ step=self.epoch,
307
+ )
308
+
309
+ # Update scheduler
310
+ self.accelerator.wait_for_everyone()
311
+
312
+ # Check save checkpoint interval
313
+ run_eval = False
314
+ if self.accelerator.is_main_process:
315
+ save_checkpoint = False
316
+ for i, num in enumerate(self.save_checkpoint_stride):
317
+ if self.epoch % num == 0:
318
+ save_checkpoint = True
319
+ run_eval |= self.run_eval[i]
320
+
321
+ # Save checkpoints
322
+ self.accelerator.wait_for_everyone()
323
+ if self.accelerator.is_main_process and save_checkpoint:
324
+ print("Saving..")
325
+ state = {
326
+ "net": {key: self.model[key].state_dict() for key in self.model},
327
+ "optimizer": self.optimizer.state_dict(),
328
+ "scheduler": self.optimizer.scheduler_state_dict(),
329
+ "iters": self.step,
330
+ "epoch": self.epoch,
331
+ }
332
+ save_path = os.path.join(
333
+ self.checkpoint_dir,
334
+ "FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.iters),
335
+ )
336
+ torch.save(state, save_path)
337
+ json.dump(
338
+ self.checkpoints_path,
339
+ open(os.path.join(self.checkpoint_dir, "ckpts.json"), "w"),
340
+ ensure_ascii=False,
341
+ indent=4,
342
+ )
343
+
344
+ self.accelerator.wait_for_everyone()
345
+
346
+ self.epoch += 1
347
+
348
+ # Finish training
349
+ self.accelerator.wait_for_everyone()
350
+ if self.accelerator.is_main_process:
351
+ path = os.path.join(
352
+ self.checkpoint_dir,
353
+ "epoch-{:04d}_step-{:07d}".format(
354
+ self.epoch,
355
+ self.step,
356
+ ),
357
+ )
358
+ print("Saving..")
359
+ state = {
360
+ "net": {key: self.model[key].state_dict() for key in self.model},
361
+ "optimizer": self.optimizer.state_dict(),
362
+ "scheduler": self.optimizer.scheduler_state_dict(),
363
+ "iters": self.step,
364
+ "epoch": self.epoch,
365
+ }
366
+ save_path = os.path.join(
367
+ self.checkpoint_dir,
368
+ "FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.iters),
369
+ )
370
+ torch.save(state, save_path)
371
+
372
+ def _train_epoch(self):
373
+ """Training epoch. Should return average loss of a batch (sample) over
374
+ one epoch. See ``train_loop`` for usage.
375
+ """
376
+ _ = [self.model[key].train() for key in self.model]
377
+
378
+ epoch_losses: dict = {}
379
+ epoch_total_loss: int = 0
380
+
381
+ for batch in tqdm(
382
+ self.train_dataloader,
383
+ desc=f"Training Epoch {self.epoch}",
384
+ unit="batch",
385
+ colour="GREEN",
386
+ leave=False,
387
+ dynamic_ncols=True,
388
+ smoothing=0.04,
389
+ disable=not self.accelerator.is_main_process,
390
+ ):
391
+ # Get losses
392
+ total_loss, losses = self._train_step(batch)
393
+ self.batch_count += 1
394
+
395
+ # Log info
396
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
397
+ self.accelerator.log(
398
+ {
399
+ "Step/Learning Rate": (
400
+ self.optimizer.schedulers["encoder"].get_last_lr()[0]
401
+ if self.step != 0
402
+ else 0
403
+ )
404
+ },
405
+ step=self.step,
406
+ )
407
+ for key, _ in losses.items():
408
+ self.accelerator.log(
409
+ {
410
+ "Step/Train {} Loss".format(key): losses[key],
411
+ },
412
+ step=self.step,
413
+ )
414
+
415
+ if not epoch_losses:
416
+ epoch_losses = losses
417
+ else:
418
+ for key, value in losses.items():
419
+ epoch_losses[key] += value
420
+ epoch_total_loss += total_loss
421
+ self.step += 1
422
+
423
+ # Get and log total losses
424
+ self.accelerator.wait_for_everyone()
425
+ epoch_total_loss = (
426
+ epoch_total_loss
427
+ / len(self.train_dataloader)
428
+ * self.cfg.train.gradient_accumulation_step
429
+ )
430
+ for key in epoch_losses.keys():
431
+ epoch_losses[key] = (
432
+ epoch_losses[key]
433
+ / len(self.train_dataloader)
434
+ * self.cfg.train.gradient_accumulation_step
435
+ )
436
+ return epoch_total_loss, epoch_losses
437
+
438
+ def _train_step(self, data):
439
+ """Training forward step. Should return average loss of a sample over
440
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
441
+ See ``_train_epoch`` for usage.
442
+ """
443
+ # Init losses
444
+ train_losses = {}
445
+ total_loss = 0
446
+
447
+ # Use input feature to get predictions
448
+ data = [b.to(self.accelerator.device, non_blocking=True) for b in data]
449
+ waves, mels, wave_lengths, mel_input_length = data
450
+
451
+ # extract semantic latent with w2v model
452
+ waves_16k = torchaudio.functional.resample(waves, 24000, 16000)
453
+ w2v_input = self.w2v_processor(
454
+ waves_16k, sampling_rate=16000, return_tensors="pt"
455
+ ).input_values.to(self.accelerator.device)
456
+ with torch.no_grad():
457
+ w2v_outputs = self.w2v_model(w2v_input.squeeze(0)).logits
458
+ predicted_ids = torch.argmax(w2v_outputs, dim=-1)
459
+ phone_ids = (
460
+ F.interpolate(
461
+ predicted_ids.unsqueeze(0).float(), mels.size(-1), mode="nearest"
462
+ )
463
+ .long()
464
+ .squeeze(0)
465
+ )
466
+
467
+ # get clips
468
+ mel_seg_len = min(
469
+ [int(mel_input_length.min().item()), self.cfg.train.max_frame_len]
470
+ )
471
+
472
+ gt_mel_seg = []
473
+ wav_seg = []
474
+ w2v_seg = []
475
+
476
+ for bib in range(len(mel_input_length)):
477
+ mel_length = int(mel_input_length[bib].item())
478
+
479
+ random_start = (
480
+ np.random.randint(0, mel_length - mel_seg_len)
481
+ if mel_length != mel_seg_len
482
+ else 0
483
+ )
484
+ gt_mel_seg.append(mels[bib, :, random_start : random_start + mel_seg_len])
485
+
486
+ # w2v_seg.append(w2v_latent[bib, :, random_start:random_start + mel_seg_len])
487
+ w2v_seg.append(phone_ids[bib, random_start : random_start + mel_seg_len])
488
+
489
+ y = waves[bib][random_start * 300 : (random_start + mel_seg_len) * 300]
490
+
491
+ wav_seg.append(y.to(self.accelerator.device))
492
+
493
+ gt_mel_seg = torch.stack(gt_mel_seg).detach()
494
+
495
+ wav_seg = torch.stack(wav_seg).float().detach().unsqueeze(1)
496
+ w2v_seg = torch.stack(w2v_seg).float().detach()
497
+
498
+ with torch.no_grad():
499
+ real_norm = log_norm(gt_mel_seg.unsqueeze(1)).squeeze(1).detach()
500
+ F0_real, _, _ = self.pitch_extractor(gt_mel_seg.unsqueeze(1))
501
+
502
+ # normalize f0
503
+ # Remove unvoiced frames (replace with -1)
504
+ gt_glob_f0s = []
505
+ f0_targets = []
506
+ for bib in range(len(F0_real)):
507
+ voiced_indices = F0_real[bib] > 5.0
508
+ f0_voiced = F0_real[bib][voiced_indices]
509
+
510
+ if len(f0_voiced) != 0:
511
+ # Convert to log scale
512
+ log_f0 = f0_voiced.log2()
513
+
514
+ # Calculate mean and standard deviation
515
+ mean_f0 = log_f0.mean()
516
+ std_f0 = log_f0.std()
517
+
518
+ # Normalize the F0 sequence
519
+ normalized_f0 = (log_f0 - mean_f0) / std_f0
520
+
521
+ # Create the normalized F0 sequence with unvoiced frames
522
+ normalized_sequence = torch.zeros_like(F0_real[bib])
523
+ normalized_sequence[voiced_indices] = normalized_f0
524
+ normalized_sequence[~voiced_indices] = (
525
+ -10
526
+ ) # Assign -10 to unvoiced frames
527
+
528
+ gt_glob_f0s.append(mean_f0)
529
+ else:
530
+ normalized_sequence = torch.zeros_like(F0_real[bib]) - 10.0
531
+ gt_glob_f0s.append(torch.tensor(0.0).to(self.accelerator.device))
532
+
533
+ # f0_targets.append(normalized_sequence[single_side_context // 200:-single_side_context // 200])
534
+ f0_targets.append(normalized_sequence)
535
+ f0_targets = torch.stack(f0_targets).to(self.accelerator.device)
536
+ # fill nan with -10
537
+ f0_targets[torch.isnan(f0_targets)] = -10.0
538
+ # fill inf with -10
539
+ f0_targets[torch.isinf(f0_targets)] = -10.0
540
+ # if frame_rate not equal to 80, interpolate f0 from frame rate of 80 to target frame rate
541
+ if self.cfg.preprocess_params.frame_rate != 80:
542
+ f0_targets = F.interpolate(
543
+ f0_targets.unsqueeze(1),
544
+ mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate,
545
+ mode="nearest",
546
+ ).squeeze(1)
547
+ w2v_seg = F.interpolate(
548
+ w2v_seg,
549
+ mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate,
550
+ mode="nearest",
551
+ )
552
+
553
+ wav_seg_input = wav_seg
554
+ wav_seg_target = wav_seg
555
+
556
+ z = self.model.encoder(wav_seg_input)
557
+ z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
558
+ z, wav_seg_input, n_c=2, full_waves=waves, wave_lens=wave_lengths
559
+ )
560
+ preds, rev_preds = self.model.fa_predictors(quantized, timbre)
561
+
562
+ pred_wave = self.model.decoder(z)
563
+
564
+ len_diff = wav_seg_target.size(-1) - pred_wave.size(-1)
565
+ if len_diff > 0:
566
+ wav_seg_target = wav_seg_target[..., len_diff // 2 : -len_diff // 2]
567
+
568
+ # discriminator loss
569
+ d_fake = self.model.discriminator(pred_wave.detach())
570
+ d_real = self.model.discriminator(wav_seg_target)
571
+ loss_d = 0
572
+ for x_fake, x_real in zip(d_fake, d_real):
573
+ loss_d += torch.mean(x_fake[-1] ** 2)
574
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
575
+
576
+ self.optimizer.zero_grad()
577
+ self.accelerator.backward(loss_d)
578
+ grad_norm_d = torch.nn.utils.clip_grad_norm_(
579
+ self.model.discriminator.parameters(), 10.0
580
+ )
581
+ self.optimizer.step("discriminator")
582
+ self.optimizer.scheduler(key="discriminator")
583
+
584
+ # generator loss
585
+ signal = AudioSignal(wav_seg_target, sample_rate=24000)
586
+ recons = AudioSignal(pred_wave, sample_rate=24000)
587
+ stft_loss = self.criterions["stft"](recons, signal)
588
+ mel_loss = self.criterions["mel"](recons, signal)
589
+ waveform_loss = self.criterions["l1"](recons, signal)
590
+
591
+ d_fake = self.model.discriminator(pred_wave)
592
+ d_real = self.model.discriminator(wav_seg_target)
593
+
594
+ loss_g = 0
595
+ for x_fake in d_fake:
596
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
597
+
598
+ loss_feature = 0
599
+
600
+ for i in range(len(d_fake)):
601
+ for j in range(len(d_fake[i]) - 1):
602
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
603
+
604
+ pred_f0, pred_uv = preds["f0"], preds["uv"]
605
+ rev_pred_f0, rev_pred_uv = rev_preds["rev_f0"], rev_preds["rev_uv"]
606
+
607
+ common_min_size = min(pred_f0.size(-2), f0_targets.size(-1))
608
+ f0_targets = f0_targets[..., :common_min_size]
609
+ real_norm = real_norm[..., :common_min_size]
610
+
611
+ f0_loss = F.smooth_l1_loss(
612
+ f0_targets, pred_f0.squeeze(-1)[..., :common_min_size]
613
+ )
614
+ uv_loss = F.smooth_l1_loss(
615
+ real_norm, pred_uv.squeeze(-1)[..., :common_min_size]
616
+ )
617
+ rev_f0_loss = (
618
+ F.smooth_l1_loss(f0_targets, rev_pred_f0.squeeze(-1)[..., :common_min_size])
619
+ if rev_pred_f0 is not None
620
+ else torch.FloatTensor([0]).to(self.accelerator.device)
621
+ )
622
+ rev_uv_loss = (
623
+ F.smooth_l1_loss(real_norm, rev_pred_uv.squeeze(-1)[..., :common_min_size])
624
+ if rev_pred_uv is not None
625
+ else torch.FloatTensor([0]).to(self.accelerator.device)
626
+ )
627
+
628
+ tot_f0_loss = f0_loss + rev_f0_loss
629
+ tot_uv_loss = uv_loss + rev_uv_loss
630
+
631
+ pred_content = preds["content"]
632
+ rev_pred_content = rev_preds["rev_content"]
633
+
634
+ target_content_latents = w2v_seg[..., :common_min_size]
635
+
636
+ content_loss = self.criterions["content"](
637
+ pred_content.transpose(1, 2)[..., :common_min_size],
638
+ target_content_latents.long(),
639
+ )
640
+ rev_content_loss = (
641
+ self.criterions["content"](
642
+ rev_pred_content.transpose(1, 2)[..., :common_min_size],
643
+ target_content_latents.long(),
644
+ )
645
+ if rev_pred_content is not None
646
+ else torch.FloatTensor([0]).to(self.accelerator.device)
647
+ )
648
+
649
+ tot_content_loss = content_loss + rev_content_loss
650
+
651
+ if self.speaker_model is not None:
652
+ spk_logits = torch.cat(
653
+ [
654
+ self.speaker_model.infer_segment(w16.cpu()[..., :wl])[1]
655
+ for w16, wl in zip(waves_16k, wave_lengths)
656
+ ],
657
+ dim=0,
658
+ )
659
+ spk_labels = spk_logits.argmax(dim=-1)
660
+ else:
661
+ spk_labels = torch.zeros([len(waves_16k)], dtype=torch.long).to(
662
+ self.accelerator.device
663
+ )
664
+
665
+ spk_pred_logits = preds["timbre"]
666
+ spk_loss = F.cross_entropy(spk_pred_logits, spk_labels)
667
+ x_spk_pred_logits = rev_preds["x_timbre"]
668
+
669
+ x_spk_loss = (
670
+ F.cross_entropy(x_spk_pred_logits, spk_labels)
671
+ if x_spk_pred_logits is not None
672
+ else torch.FloatTensor([0]).to(self.accelerator.device)
673
+ )
674
+
675
+ tot_spk_loss = spk_loss + x_spk_loss
676
+
677
+ loss_gen_all = (
678
+ mel_loss * 15.0
679
+ + loss_feature * 1.0
680
+ + loss_g * 1.0
681
+ + commitment_loss * 0.25
682
+ + codebook_loss * 1.0
683
+ + tot_f0_loss * 1.0
684
+ + tot_uv_loss * 1.0
685
+ + tot_content_loss * 5.0
686
+ + tot_spk_loss * 5.0
687
+ )
688
+
689
+ self.optimizer.zero_grad()
690
+ self.accelerator.backward(loss_gen_all)
691
+
692
+ with torch.no_grad():
693
+ total_loss = loss_gen_all.item()
694
+ train_losses["stft"] = stft_loss.item()
695
+ train_losses["mel"] = mel_loss.item()
696
+ train_losses["l1"] = waveform_loss.item()
697
+ train_losses["f0"] = f0_loss.item()
698
+ train_losses["uv"] = uv_loss.item()
699
+ train_losses["content"] = content_loss.item()
700
+ train_losses["speaker"] = spk_loss.item()
701
+ train_losses["rev_f0"] = rev_f0_loss.item()
702
+ train_losses["rev_uv"] = rev_uv_loss.item()
703
+ train_losses["rev_content"] = rev_content_loss.item()
704
+ train_losses["rev_speaker"] = x_spk_loss.item()
705
+
706
+ train_losses["feature"] = loss_feature.item()
707
+ train_losses["generator"] = loss_g.item()
708
+ train_losses["commitment"] = commitment_loss.item()
709
+ train_losses["codebook"] = codebook_loss.item()
710
+
711
+ # discriminators
712
+ train_losses["discriminator"] = loss_d.item()
713
+
714
+ return total_loss, train_losses
715
+
716
+ def _inference(self, eval_wave):
717
+ """Inference during training for test audios."""
718
+ z = self.model.encoder(
719
+ eval_wave[None, None, ...].to(self.accelerator.device).float()
720
+ )
721
+ z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
722
+ z, eval_wave[None, None, ...], n_c=self.cfg.model_params.n_c_codebooks
723
+ )
724
+ full_pred_wave = self.model.decoder(z)
725
+ return full_pred_wave[0]
726
+
727
+ def _load_model(self, checkpoint_path=None, resume_type="resume"):
728
+ """Load model from checkpoint. If checkpoint_path is None, it will
729
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
730
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
731
+ method after** ``accelerator.prepare()``.
732
+ """
733
+ if resume_type == "resume":
734
+ if checkpoint_path is None:
735
+ available_checkpoints = glob.glob(
736
+ os.path.join(self.checkpoint_dir, "FAcodc_epoch_*_step_*.pth")
737
+ )
738
+ # find the checkpoint that has the highest step number
739
+ latest_checkpoint = max(
740
+ available_checkpoints,
741
+ key=lambda x: int(x.split("_")[-1].split(".")[0]),
742
+ )
743
+ earliest_checkpoint = min(
744
+ available_checkpoints,
745
+ key=lambda x: int(x.split("_")[-1].split(".")[0]),
746
+ )
747
+ # delete the earliest checkpoint
748
+ if (
749
+ earliest_checkpoint != latest_checkpoint
750
+ and self.accelerator.is_main_process
751
+ and len(available_checkpoints) > 4
752
+ ):
753
+ os.remove(earliest_checkpoint)
754
+ print(f"Removed {earliest_checkpoint}")
755
+ else:
756
+ latest_checkpoint = checkpoint_path
757
+
758
+ self.model, self.optimizer, self.epoch, self.step = load_checkpoint(
759
+ self.model,
760
+ self.optimizer,
761
+ latest_checkpoint,
762
+ load_only_params=False,
763
+ ignore_modules=[],
764
+ is_distributed=self.accelerator.num_processes > 1,
765
+ )
766
+
767
+ else:
768
+ raise ValueError("Invalid resume type")
769
+ return checkpoint_path
770
+
771
+ def _count_parameters(self):
772
+ total_num = sum(
773
+ sum(p.numel() for p in self.model[key].parameters()) for key in self.model
774
+ )
775
+ # trainable_num = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
776
+ return total_num
models/codec/facodec/modules/JDC/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
models/codec/facodec/modules/JDC/bst.t7 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
3
+ size 21029926
models/codec/facodec/modules/JDC/model.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # This code is borrowed from https://github.com/yl4579/PitchExtractor/blob/main/model.py
7
+
8
+ """
9
+ Implementation of model from:
10
+ Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
11
+ Convolutional Recurrent Neural Networks" (2019)
12
+ Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
13
+ """
14
+ import torch
15
+ from torch import nn
16
+
17
+
18
+ class JDCNet(nn.Module):
19
+ """
20
+ Joint Detection and Classification Network model for singing voice melody.
21
+ """
22
+
23
+ def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
24
+ super().__init__()
25
+ self.num_class = num_class
26
+
27
+ # input = (b, 1, 31, 513), b = batch size
28
+ self.conv_block = nn.Sequential(
29
+ nn.Conv2d(
30
+ in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False
31
+ ), # out: (b, 64, 31, 513)
32
+ nn.BatchNorm2d(num_features=64),
33
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
34
+ nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513)
35
+ )
36
+
37
+ # res blocks
38
+ self.res_block1 = ResBlock(
39
+ in_channels=64, out_channels=128
40
+ ) # (b, 128, 31, 128)
41
+ self.res_block2 = ResBlock(
42
+ in_channels=128, out_channels=192
43
+ ) # (b, 192, 31, 32)
44
+ self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8)
45
+
46
+ # pool block
47
+ self.pool_block = nn.Sequential(
48
+ nn.BatchNorm2d(num_features=256),
49
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
50
+ nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2)
51
+ nn.Dropout(p=0.2),
52
+ )
53
+
54
+ # maxpool layers (for auxiliary network inputs)
55
+ # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
56
+ self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
57
+ # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
58
+ self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
59
+ # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
60
+ self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
61
+
62
+ # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
63
+ self.detector_conv = nn.Sequential(
64
+ nn.Conv2d(640, 256, 1, bias=False),
65
+ nn.BatchNorm2d(256),
66
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
67
+ nn.Dropout(p=0.2),
68
+ )
69
+
70
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
71
+ self.bilstm_classifier = nn.LSTM(
72
+ input_size=512, hidden_size=256, batch_first=True, bidirectional=True
73
+ ) # (b, 31, 512)
74
+
75
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
76
+ self.bilstm_detector = nn.LSTM(
77
+ input_size=512, hidden_size=256, batch_first=True, bidirectional=True
78
+ ) # (b, 31, 512)
79
+
80
+ # input: (b * 31, 512)
81
+ self.classifier = nn.Linear(
82
+ in_features=512, out_features=self.num_class
83
+ ) # (b * 31, num_class)
84
+
85
+ # input: (b * 31, 512)
86
+ self.detector = nn.Linear(
87
+ in_features=512, out_features=2
88
+ ) # (b * 31, 2) - binary classifier
89
+
90
+ # initialize weights
91
+ self.apply(self.init_weights)
92
+
93
+ def get_feature_GAN(self, x):
94
+ seq_len = x.shape[-2]
95
+ x = x.float().transpose(-1, -2)
96
+
97
+ convblock_out = self.conv_block(x)
98
+
99
+ resblock1_out = self.res_block1(convblock_out)
100
+ resblock2_out = self.res_block2(resblock1_out)
101
+ resblock3_out = self.res_block3(resblock2_out)
102
+ poolblock_out = self.pool_block[0](resblock3_out)
103
+ poolblock_out = self.pool_block[1](poolblock_out)
104
+
105
+ return poolblock_out.transpose(-1, -2)
106
+
107
+ def get_feature(self, x):
108
+ seq_len = x.shape[-2]
109
+ x = x.float().transpose(-1, -2)
110
+
111
+ convblock_out = self.conv_block(x)
112
+
113
+ resblock1_out = self.res_block1(convblock_out)
114
+ resblock2_out = self.res_block2(resblock1_out)
115
+ resblock3_out = self.res_block3(resblock2_out)
116
+ poolblock_out = self.pool_block[0](resblock3_out)
117
+ poolblock_out = self.pool_block[1](poolblock_out)
118
+
119
+ return self.pool_block[2](poolblock_out)
120
+
121
+ def forward(self, x):
122
+ """
123
+ Returns:
124
+ classification_prediction, detection_prediction
125
+ sizes: (b, 31, 722), (b, 31, 2)
126
+ """
127
+ ###############################
128
+ # forward pass for classifier #
129
+ ###############################
130
+ seq_len = x.shape[-1]
131
+ x = x.float().transpose(-1, -2)
132
+
133
+ convblock_out = self.conv_block(x)
134
+
135
+ resblock1_out = self.res_block1(convblock_out)
136
+ resblock2_out = self.res_block2(resblock1_out)
137
+ resblock3_out = self.res_block3(resblock2_out)
138
+
139
+ poolblock_out = self.pool_block[0](resblock3_out)
140
+ poolblock_out = self.pool_block[1](poolblock_out)
141
+ GAN_feature = poolblock_out.transpose(-1, -2)
142
+ poolblock_out = self.pool_block[2](poolblock_out)
143
+
144
+ # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
145
+ classifier_out = (
146
+ poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
147
+ )
148
+ classifier_out, _ = self.bilstm_classifier(
149
+ classifier_out
150
+ ) # ignore the hidden states
151
+
152
+ classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512)
153
+ classifier_out = self.classifier(classifier_out)
154
+ classifier_out = classifier_out.view(
155
+ (-1, seq_len, self.num_class)
156
+ ) # (b, 31, num_class)
157
+
158
+ # sizes: (b, 31, 722), (b, 31, 2)
159
+ # classifier output consists of predicted pitch classes per frame
160
+ # detector output consists of: (isvoice, notvoice) estimates per frame
161
+ return torch.abs(classifier_out.squeeze(-1)), GAN_feature, poolblock_out
162
+
163
+ @staticmethod
164
+ def init_weights(m):
165
+ if isinstance(m, nn.Linear):
166
+ nn.init.kaiming_uniform_(m.weight)
167
+ if m.bias is not None:
168
+ nn.init.constant_(m.bias, 0)
169
+ elif isinstance(m, nn.Conv2d):
170
+ nn.init.xavier_normal_(m.weight)
171
+ elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
172
+ for p in m.parameters():
173
+ if p.data is None:
174
+ continue
175
+
176
+ if len(p.shape) >= 2:
177
+ nn.init.orthogonal_(p.data)
178
+ else:
179
+ nn.init.normal_(p.data)
180
+
181
+
182
+ class ResBlock(nn.Module):
183
+ def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
184
+ super().__init__()
185
+ self.downsample = in_channels != out_channels
186
+
187
+ # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
188
+ self.pre_conv = nn.Sequential(
189
+ nn.BatchNorm2d(num_features=in_channels),
190
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
191
+ nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only
192
+ )
193
+
194
+ # conv layers
195
+ self.conv = nn.Sequential(
196
+ nn.Conv2d(
197
+ in_channels=in_channels,
198
+ out_channels=out_channels,
199
+ kernel_size=3,
200
+ padding=1,
201
+ bias=False,
202
+ ),
203
+ nn.BatchNorm2d(out_channels),
204
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
205
+ nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
206
+ )
207
+
208
+ # 1 x 1 convolution layer to match the feature dimensions
209
+ self.conv1by1 = None
210
+ if self.downsample:
211
+ self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
212
+
213
+ def forward(self, x):
214
+ x = self.pre_conv(x)
215
+ if self.downsample:
216
+ x = self.conv(x) + self.conv1by1(x)
217
+ else:
218
+ x = self.conv(x) + x
219
+ return x
models/codec/facodec/modules/attentions.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/attentions.py
7
+
8
+ import copy
9
+ import math
10
+ import numpy as np
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+ from . import commons
16
+
17
+
18
+ class LayerNorm(nn.Module):
19
+ def __init__(self, channels, eps=1e-5):
20
+ super().__init__()
21
+ self.channels = channels
22
+ self.eps = eps
23
+
24
+ self.gamma = nn.Parameter(torch.ones(channels))
25
+ self.beta = nn.Parameter(torch.zeros(channels))
26
+
27
+ def forward(self, x):
28
+ x = x.transpose(1, -1)
29
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
30
+ return x.transpose(1, -1)
31
+
32
+
33
+ class Encoder(nn.Module):
34
+ def __init__(
35
+ self,
36
+ hidden_channels,
37
+ filter_channels,
38
+ n_heads,
39
+ n_layers,
40
+ kernel_size=1,
41
+ p_dropout=0.0,
42
+ window_size=4,
43
+ **kwargs
44
+ ):
45
+ super().__init__()
46
+ self.hidden_channels = hidden_channels
47
+ self.filter_channels = filter_channels
48
+ self.n_heads = n_heads
49
+ self.n_layers = n_layers
50
+ self.kernel_size = kernel_size
51
+ self.p_dropout = p_dropout
52
+ self.window_size = window_size
53
+
54
+ self.drop = nn.Dropout(p_dropout)
55
+ self.attn_layers = nn.ModuleList()
56
+ self.norm_layers_1 = nn.ModuleList()
57
+ self.ffn_layers = nn.ModuleList()
58
+ self.norm_layers_2 = nn.ModuleList()
59
+ for i in range(self.n_layers):
60
+ self.attn_layers.append(
61
+ MultiHeadAttention(
62
+ hidden_channels,
63
+ hidden_channels,
64
+ n_heads,
65
+ p_dropout=p_dropout,
66
+ window_size=window_size,
67
+ )
68
+ )
69
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
70
+ self.ffn_layers.append(
71
+ FFN(
72
+ hidden_channels,
73
+ hidden_channels,
74
+ filter_channels,
75
+ kernel_size,
76
+ p_dropout=p_dropout,
77
+ )
78
+ )
79
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
80
+
81
+ def forward(self, x, x_mask):
82
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
83
+ x = x * x_mask
84
+ for i in range(self.n_layers):
85
+ y = self.attn_layers[i](x, x, attn_mask)
86
+ y = self.drop(y)
87
+ x = self.norm_layers_1[i](x + y)
88
+
89
+ y = self.ffn_layers[i](x, x_mask)
90
+ y = self.drop(y)
91
+ x = self.norm_layers_2[i](x + y)
92
+ x = x * x_mask
93
+ return x
94
+
95
+
96
+ class Decoder(nn.Module):
97
+ def __init__(
98
+ self,
99
+ hidden_channels,
100
+ filter_channels,
101
+ n_heads,
102
+ n_layers,
103
+ kernel_size=1,
104
+ p_dropout=0.0,
105
+ proximal_bias=False,
106
+ proximal_init=True,
107
+ **kwargs
108
+ ):
109
+ super().__init__()
110
+ self.hidden_channels = hidden_channels
111
+ self.filter_channels = filter_channels
112
+ self.n_heads = n_heads
113
+ self.n_layers = n_layers
114
+ self.kernel_size = kernel_size
115
+ self.p_dropout = p_dropout
116
+ self.proximal_bias = proximal_bias
117
+ self.proximal_init = proximal_init
118
+
119
+ self.drop = nn.Dropout(p_dropout)
120
+ self.self_attn_layers = nn.ModuleList()
121
+ self.norm_layers_0 = nn.ModuleList()
122
+ self.encdec_attn_layers = nn.ModuleList()
123
+ self.norm_layers_1 = nn.ModuleList()
124
+ self.ffn_layers = nn.ModuleList()
125
+ self.norm_layers_2 = nn.ModuleList()
126
+ for i in range(self.n_layers):
127
+ self.self_attn_layers.append(
128
+ MultiHeadAttention(
129
+ hidden_channels,
130
+ hidden_channels,
131
+ n_heads,
132
+ p_dropout=p_dropout,
133
+ proximal_bias=proximal_bias,
134
+ proximal_init=proximal_init,
135
+ )
136
+ )
137
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
138
+ self.encdec_attn_layers.append(
139
+ MultiHeadAttention(
140
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
141
+ )
142
+ )
143
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
144
+ self.ffn_layers.append(
145
+ FFN(
146
+ hidden_channels,
147
+ hidden_channels,
148
+ filter_channels,
149
+ kernel_size,
150
+ p_dropout=p_dropout,
151
+ causal=True,
152
+ )
153
+ )
154
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
155
+
156
+ def forward(self, x, x_mask, h, h_mask):
157
+ """
158
+ x: decoder input
159
+ h: encoder output
160
+ """
161
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
162
+ device=x.device, dtype=x.dtype
163
+ )
164
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
165
+ x = x * x_mask
166
+ for i in range(self.n_layers):
167
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
168
+ y = self.drop(y)
169
+ x = self.norm_layers_0[i](x + y)
170
+
171
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
172
+ y = self.drop(y)
173
+ x = self.norm_layers_1[i](x + y)
174
+
175
+ y = self.ffn_layers[i](x, x_mask)
176
+ y = self.drop(y)
177
+ x = self.norm_layers_2[i](x + y)
178
+ x = x * x_mask
179
+ return x
180
+
181
+
182
+ class MultiHeadAttention(nn.Module):
183
+ def __init__(
184
+ self,
185
+ channels,
186
+ out_channels,
187
+ n_heads,
188
+ p_dropout=0.0,
189
+ window_size=None,
190
+ heads_share=True,
191
+ block_length=None,
192
+ proximal_bias=False,
193
+ proximal_init=False,
194
+ ):
195
+ super().__init__()
196
+ assert channels % n_heads == 0
197
+
198
+ self.channels = channels
199
+ self.out_channels = out_channels
200
+ self.n_heads = n_heads
201
+ self.p_dropout = p_dropout
202
+ self.window_size = window_size
203
+ self.heads_share = heads_share
204
+ self.block_length = block_length
205
+ self.proximal_bias = proximal_bias
206
+ self.proximal_init = proximal_init
207
+ self.attn = None
208
+
209
+ self.k_channels = channels // n_heads
210
+ self.conv_q = nn.Conv1d(channels, channels, 1)
211
+ self.conv_k = nn.Conv1d(channels, channels, 1)
212
+ self.conv_v = nn.Conv1d(channels, channels, 1)
213
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
214
+ self.drop = nn.Dropout(p_dropout)
215
+
216
+ if window_size is not None:
217
+ n_heads_rel = 1 if heads_share else n_heads
218
+ rel_stddev = self.k_channels**-0.5
219
+ self.emb_rel_k = nn.Parameter(
220
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
221
+ * rel_stddev
222
+ )
223
+ self.emb_rel_v = nn.Parameter(
224
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
225
+ * rel_stddev
226
+ )
227
+
228
+ nn.init.xavier_uniform_(self.conv_q.weight)
229
+ nn.init.xavier_uniform_(self.conv_k.weight)
230
+ nn.init.xavier_uniform_(self.conv_v.weight)
231
+ if proximal_init:
232
+ with torch.no_grad():
233
+ self.conv_k.weight.copy_(self.conv_q.weight)
234
+ self.conv_k.bias.copy_(self.conv_q.bias)
235
+
236
+ def forward(self, x, c, attn_mask=None):
237
+ q = self.conv_q(x)
238
+ k = self.conv_k(c)
239
+ v = self.conv_v(c)
240
+
241
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
242
+
243
+ x = self.conv_o(x)
244
+ return x
245
+
246
+ def attention(self, query, key, value, mask=None):
247
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
248
+ b, d, t_s, t_t = (*key.size(), query.size(2))
249
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
250
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
251
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
252
+
253
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
254
+ if self.window_size is not None:
255
+ assert (
256
+ t_s == t_t
257
+ ), "Relative attention is only available for self-attention."
258
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
259
+ rel_logits = self._matmul_with_relative_keys(
260
+ query / math.sqrt(self.k_channels), key_relative_embeddings
261
+ )
262
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
263
+ scores = scores + scores_local
264
+ if self.proximal_bias:
265
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
266
+ scores = scores + self._attention_bias_proximal(t_s).to(
267
+ device=scores.device, dtype=scores.dtype
268
+ )
269
+ if mask is not None:
270
+ scores = scores.masked_fill(mask == 0, -1e4)
271
+ if self.block_length is not None:
272
+ assert (
273
+ t_s == t_t
274
+ ), "Local attention is only available for self-attention."
275
+ block_mask = (
276
+ torch.ones_like(scores)
277
+ .triu(-self.block_length)
278
+ .tril(self.block_length)
279
+ )
280
+ scores = scores.masked_fill(block_mask == 0, -1e4)
281
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
282
+ p_attn = self.drop(p_attn)
283
+ output = torch.matmul(p_attn, value)
284
+ if self.window_size is not None:
285
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
286
+ value_relative_embeddings = self._get_relative_embeddings(
287
+ self.emb_rel_v, t_s
288
+ )
289
+ output = output + self._matmul_with_relative_values(
290
+ relative_weights, value_relative_embeddings
291
+ )
292
+ output = (
293
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
294
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
295
+ return output, p_attn
296
+
297
+ def _matmul_with_relative_values(self, x, y):
298
+ """
299
+ x: [b, h, l, m]
300
+ y: [h or 1, m, d]
301
+ ret: [b, h, l, d]
302
+ """
303
+ ret = torch.matmul(x, y.unsqueeze(0))
304
+ return ret
305
+
306
+ def _matmul_with_relative_keys(self, x, y):
307
+ """
308
+ x: [b, h, l, d]
309
+ y: [h or 1, m, d]
310
+ ret: [b, h, l, m]
311
+ """
312
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
313
+ return ret
314
+
315
+ def _get_relative_embeddings(self, relative_embeddings, length):
316
+ max_relative_position = 2 * self.window_size + 1
317
+ # Pad first before slice to avoid using cond ops.
318
+ pad_length = max(length - (self.window_size + 1), 0)
319
+ slice_start_position = max((self.window_size + 1) - length, 0)
320
+ slice_end_position = slice_start_position + 2 * length - 1
321
+ if pad_length > 0:
322
+ padded_relative_embeddings = F.pad(
323
+ relative_embeddings,
324
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
325
+ )
326
+ else:
327
+ padded_relative_embeddings = relative_embeddings
328
+ used_relative_embeddings = padded_relative_embeddings[
329
+ :, slice_start_position:slice_end_position
330
+ ]
331
+ return used_relative_embeddings
332
+
333
+ def _relative_position_to_absolute_position(self, x):
334
+ """
335
+ x: [b, h, l, 2*l-1]
336
+ ret: [b, h, l, l]
337
+ """
338
+ batch, heads, length, _ = x.size()
339
+ # Concat columns of pad to shift from relative to absolute indexing.
340
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
341
+
342
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
343
+ x_flat = x.view([batch, heads, length * 2 * length])
344
+ x_flat = F.pad(
345
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
346
+ )
347
+
348
+ # Reshape and slice out the padded elements.
349
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
350
+ :, :, :length, length - 1 :
351
+ ]
352
+ return x_final
353
+
354
+ def _absolute_position_to_relative_position(self, x):
355
+ """
356
+ x: [b, h, l, l]
357
+ ret: [b, h, l, 2*l-1]
358
+ """
359
+ batch, heads, length, _ = x.size()
360
+ # padd along column
361
+ x = F.pad(
362
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
363
+ )
364
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
365
+ # add 0's in the beginning that will skew the elements after reshape
366
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
367
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
368
+ return x_final
369
+
370
+ def _attention_bias_proximal(self, length):
371
+ """Bias for self-attention to encourage attention to close positions.
372
+ Args:
373
+ length: an integer scalar.
374
+ Returns:
375
+ a Tensor with shape [1, 1, length, length]
376
+ """
377
+ r = torch.arange(length, dtype=torch.float32)
378
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
379
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
380
+
381
+
382
+ class FFN(nn.Module):
383
+ def __init__(
384
+ self,
385
+ in_channels,
386
+ out_channels,
387
+ filter_channels,
388
+ kernel_size,
389
+ p_dropout=0.0,
390
+ activation=None,
391
+ causal=False,
392
+ ):
393
+ super().__init__()
394
+ self.in_channels = in_channels
395
+ self.out_channels = out_channels
396
+ self.filter_channels = filter_channels
397
+ self.kernel_size = kernel_size
398
+ self.p_dropout = p_dropout
399
+ self.activation = activation
400
+ self.causal = causal
401
+
402
+ if causal:
403
+ self.padding = self._causal_padding
404
+ else:
405
+ self.padding = self._same_padding
406
+
407
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
408
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
409
+ self.drop = nn.Dropout(p_dropout)
410
+
411
+ def forward(self, x, x_mask):
412
+ x = self.conv_1(self.padding(x * x_mask))
413
+ if self.activation == "gelu":
414
+ x = x * torch.sigmoid(1.702 * x)
415
+ else:
416
+ x = torch.relu(x)
417
+ x = self.drop(x)
418
+ x = self.conv_2(self.padding(x * x_mask))
419
+ return x * x_mask
420
+
421
+ def _causal_padding(self, x):
422
+ if self.kernel_size == 1:
423
+ return x
424
+ pad_l = self.kernel_size - 1
425
+ pad_r = 0
426
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
427
+ x = F.pad(x, commons.convert_pad_shape(padding))
428
+ return x
429
+
430
+ def _same_padding(self, x):
431
+ if self.kernel_size == 1:
432
+ return x
433
+ pad_l = (self.kernel_size - 1) // 2
434
+ pad_r = self.kernel_size // 2
435
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
436
+ x = F.pad(x, commons.convert_pad_shape(padding))
437
+ return x
models/codec/facodec/modules/commons.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import math
8
+ import os.path
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+ from munch import Munch
15
+ import json
16
+
17
+
18
+ class AttrDict(dict):
19
+ def __init__(self, *args, **kwargs):
20
+ super(AttrDict, self).__init__(*args, **kwargs)
21
+ self.__dict__ = self
22
+
23
+
24
+ def init_weights(m, mean=0.0, std=0.01):
25
+ classname = m.__class__.__name__
26
+ if classname.find("Conv") != -1:
27
+ m.weight.data.normal_(mean, std)
28
+
29
+
30
+ def get_padding(kernel_size, dilation=1):
31
+ return int((kernel_size * dilation - dilation) / 2)
32
+
33
+
34
+ def convert_pad_shape(pad_shape):
35
+ l = pad_shape[::-1]
36
+ pad_shape = [item for sublist in l for item in sublist]
37
+ return pad_shape
38
+
39
+
40
+ def intersperse(lst, item):
41
+ result = [item] * (len(lst) * 2 + 1)
42
+ result[1::2] = lst
43
+ return result
44
+
45
+
46
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
47
+ """KL(P||Q)"""
48
+ kl = (logs_q - logs_p) - 0.5
49
+ kl += (
50
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
51
+ )
52
+ return kl
53
+
54
+
55
+ def rand_gumbel(shape):
56
+ """Sample from the Gumbel distribution, protect from overflows."""
57
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
58
+ return -torch.log(-torch.log(uniform_samples))
59
+
60
+
61
+ def rand_gumbel_like(x):
62
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
63
+ return g
64
+
65
+
66
+ def slice_segments(x, ids_str, segment_size=4):
67
+ ret = torch.zeros_like(x[:, :, :segment_size])
68
+ for i in range(x.size(0)):
69
+ idx_str = ids_str[i]
70
+ idx_end = idx_str + segment_size
71
+ ret[i] = x[i, :, idx_str:idx_end]
72
+ return ret
73
+
74
+
75
+ def slice_segments_audio(x, ids_str, segment_size=4):
76
+ ret = torch.zeros_like(x[:, :segment_size])
77
+ for i in range(x.size(0)):
78
+ idx_str = ids_str[i]
79
+ idx_end = idx_str + segment_size
80
+ ret[i] = x[i, idx_str:idx_end]
81
+ return ret
82
+
83
+
84
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
85
+ b, d, t = x.size()
86
+ if x_lengths is None:
87
+ x_lengths = t
88
+ ids_str_max = x_lengths - segment_size + 1
89
+ ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
90
+ dtype=torch.long
91
+ )
92
+ ret = slice_segments(x, ids_str, segment_size)
93
+ return ret, ids_str
94
+
95
+
96
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
97
+ position = torch.arange(length, dtype=torch.float)
98
+ num_timescales = channels // 2
99
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
100
+ num_timescales - 1
101
+ )
102
+ inv_timescales = min_timescale * torch.exp(
103
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
104
+ )
105
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
106
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
107
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
108
+ signal = signal.view(1, channels, length)
109
+ return signal
110
+
111
+
112
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
113
+ b, channels, length = x.size()
114
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
115
+ return x + signal.to(dtype=x.dtype, device=x.device)
116
+
117
+
118
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
119
+ b, channels, length = x.size()
120
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
121
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
122
+
123
+
124
+ def subsequent_mask(length):
125
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
126
+ return mask
127
+
128
+
129
+ @torch.jit.script
130
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
131
+ n_channels_int = n_channels[0]
132
+ in_act = input_a + input_b
133
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
134
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
135
+ acts = t_act * s_act
136
+ return acts
137
+
138
+
139
+ def convert_pad_shape(pad_shape):
140
+ l = pad_shape[::-1]
141
+ pad_shape = [item for sublist in l for item in sublist]
142
+ return pad_shape
143
+
144
+
145
+ def shift_1d(x):
146
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
147
+ return x
148
+
149
+
150
+ def sequence_mask(length, max_length=None):
151
+ if max_length is None:
152
+ max_length = length.max()
153
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
154
+ return x.unsqueeze(0) < length.unsqueeze(1)
155
+
156
+
157
+ def generate_path(duration, mask):
158
+ """
159
+ duration: [b, 1, t_x]
160
+ mask: [b, 1, t_y, t_x]
161
+ """
162
+ device = duration.device
163
+
164
+ b, _, t_y, t_x = mask.shape
165
+ cum_duration = torch.cumsum(duration, -1)
166
+
167
+ cum_duration_flat = cum_duration.view(b * t_x)
168
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
169
+ path = path.view(b, t_x, t_y)
170
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
171
+ path = path.unsqueeze(1).transpose(2, 3) * mask
172
+ return path
173
+
174
+
175
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
176
+ if isinstance(parameters, torch.Tensor):
177
+ parameters = [parameters]
178
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
179
+ norm_type = float(norm_type)
180
+ if clip_value is not None:
181
+ clip_value = float(clip_value)
182
+
183
+ total_norm = 0
184
+ for p in parameters:
185
+ param_norm = p.grad.data.norm(norm_type)
186
+ total_norm += param_norm.item() ** norm_type
187
+ if clip_value is not None:
188
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
189
+ total_norm = total_norm ** (1.0 / norm_type)
190
+ return total_norm
191
+
192
+
193
+ def log_norm(x, mean=-4, std=4, dim=2):
194
+ """
195
+ normalized log mel -> mel -> norm -> log(norm)
196
+ """
197
+ x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
198
+ return x
199
+
200
+
201
+ from huggingface_hub import hf_hub_download
202
+
203
+
204
+ def load_F0_models(path):
205
+ # load F0 model
206
+ from .JDC.model import JDCNet
207
+
208
+ F0_model = JDCNet(num_class=1, seq_len=192)
209
+ if not os.path.exists(path):
210
+ path = hf_hub_download(repo_id="Plachta/JDCnet", filename="bst.t7")
211
+ params = torch.load(path, map_location="cpu")["net"]
212
+ F0_model.load_state_dict(params)
213
+ _ = F0_model.train()
214
+
215
+ return F0_model
216
+
217
+
218
+ # Generators
219
+ from modules.dac.model.dac import Encoder, Decoder
220
+ from .quantize import FAquantizer, FApredictors
221
+
222
+ # Discriminators
223
+ from modules.dac.model.discriminator import Discriminator
224
+
225
+
226
+ def build_model(args):
227
+ encoder = Encoder(
228
+ d_model=args.DAC.encoder_dim,
229
+ strides=args.DAC.encoder_rates,
230
+ d_latent=1024,
231
+ causal=args.causal,
232
+ lstm=args.lstm,
233
+ )
234
+
235
+ quantizer = FAquantizer(
236
+ in_dim=1024,
237
+ n_p_codebooks=1,
238
+ n_c_codebooks=args.n_c_codebooks,
239
+ n_t_codebooks=2,
240
+ n_r_codebooks=3,
241
+ codebook_size=1024,
242
+ codebook_dim=8,
243
+ quantizer_dropout=0.5,
244
+ causal=args.causal,
245
+ separate_prosody_encoder=args.separate_prosody_encoder,
246
+ timbre_norm=args.timbre_norm,
247
+ )
248
+
249
+ fa_predictors = FApredictors(
250
+ in_dim=1024,
251
+ use_gr_content_f0=args.use_gr_content_f0,
252
+ use_gr_prosody_phone=args.use_gr_prosody_phone,
253
+ use_gr_residual_f0=True,
254
+ use_gr_residual_phone=True,
255
+ use_gr_timbre_content=True,
256
+ use_gr_timbre_prosody=args.use_gr_timbre_prosody,
257
+ use_gr_x_timbre=True,
258
+ norm_f0=args.norm_f0,
259
+ timbre_norm=args.timbre_norm,
260
+ use_gr_content_global_f0=args.use_gr_content_global_f0,
261
+ )
262
+
263
+ decoder = Decoder(
264
+ input_channel=1024,
265
+ channels=args.DAC.decoder_dim,
266
+ rates=args.DAC.decoder_rates,
267
+ causal=args.causal,
268
+ lstm=args.lstm,
269
+ )
270
+
271
+ discriminator = Discriminator(
272
+ rates=[],
273
+ periods=[2, 3, 5, 7, 11],
274
+ fft_sizes=[2048, 1024, 512],
275
+ sample_rate=args.DAC.sr,
276
+ bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
277
+ )
278
+
279
+ nets = Munch(
280
+ encoder=encoder,
281
+ quantizer=quantizer,
282
+ decoder=decoder,
283
+ discriminator=discriminator,
284
+ fa_predictors=fa_predictors,
285
+ )
286
+
287
+ return nets
288
+
289
+
290
+ def load_checkpoint(
291
+ model,
292
+ optimizer,
293
+ path,
294
+ load_only_params=True,
295
+ ignore_modules=[],
296
+ is_distributed=False,
297
+ ):
298
+ state = torch.load(path, map_location="cpu")
299
+ params = state["net"]
300
+ for key in model:
301
+ if key in params and key not in ignore_modules:
302
+ if not is_distributed:
303
+ # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
304
+ for k in list(params[key].keys()):
305
+ if k.startswith("module."):
306
+ params[key][k[len("module.") :]] = params[key][k]
307
+ del params[key][k]
308
+ print("%s loaded" % key)
309
+ model[key].load_state_dict(params[key], strict=True)
310
+ _ = [model[key].eval() for key in model]
311
+
312
+ if not load_only_params:
313
+ epoch = state["epoch"] + 1
314
+ iters = state["iters"]
315
+ optimizer.load_state_dict(state["optimizer"])
316
+ optimizer.load_scheduler_state_dict(state["scheduler"])
317
+
318
+ else:
319
+ epoch = state["epoch"] + 1
320
+ iters = state["iters"]
321
+
322
+ return model, optimizer, epoch, iters
323
+
324
+
325
+ def recursive_munch(d):
326
+ if isinstance(d, dict):
327
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
328
+ elif isinstance(d, list):
329
+ return [recursive_munch(v) for v in d]
330
+ else:
331
+ return d
models/codec/facodec/modules/gradient_reversal.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from torch.autograd import Function
7
+ import torch
8
+ from torch import nn
9
+
10
+
11
+ class GradientReversal(Function):
12
+ @staticmethod
13
+ def forward(ctx, x, alpha):
14
+ ctx.save_for_backward(x, alpha)
15
+ return x
16
+
17
+ @staticmethod
18
+ def backward(ctx, grad_output):
19
+ grad_input = None
20
+ _, alpha = ctx.saved_tensors
21
+ if ctx.needs_input_grad[0]:
22
+ grad_input = -alpha * grad_output
23
+ return grad_input, None
24
+
25
+
26
+ revgrad = GradientReversal.apply
27
+
28
+
29
+ class GradientReversal(nn.Module):
30
+ def __init__(self, alpha):
31
+ super().__init__()
32
+ self.alpha = torch.tensor(alpha, requires_grad=False)
33
+
34
+ def forward(self, x):
35
+ return revgrad(x, self.alpha)
models/codec/facodec/modules/layers.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import torch
8
+ from torch import nn
9
+ from typing import Optional, Any
10
+ from torch import Tensor
11
+ import torch.nn.functional as F
12
+ import torchaudio
13
+ import torchaudio.functional as audio_F
14
+
15
+ import random
16
+
17
+ random.seed(0)
18
+
19
+
20
+ def _get_activation_fn(activ):
21
+ if activ == "relu":
22
+ return nn.ReLU()
23
+ elif activ == "lrelu":
24
+ return nn.LeakyReLU(0.2)
25
+ elif activ == "swish":
26
+ return lambda x: x * torch.sigmoid(x)
27
+ else:
28
+ raise RuntimeError(
29
+ "Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
30
+ )
31
+
32
+
33
+ class LinearNorm(torch.nn.Module):
34
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
35
+ super(LinearNorm, self).__init__()
36
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
37
+
38
+ torch.nn.init.xavier_uniform_(
39
+ self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
40
+ )
41
+
42
+ def forward(self, x):
43
+ return self.linear_layer(x)
44
+
45
+
46
+ class ConvNorm(torch.nn.Module):
47
+ def __init__(
48
+ self,
49
+ in_channels,
50
+ out_channels,
51
+ kernel_size=1,
52
+ stride=1,
53
+ padding=None,
54
+ dilation=1,
55
+ bias=True,
56
+ w_init_gain="linear",
57
+ param=None,
58
+ ):
59
+ super(ConvNorm, self).__init__()
60
+ if padding is None:
61
+ assert kernel_size % 2 == 1
62
+ padding = int(dilation * (kernel_size - 1) / 2)
63
+
64
+ self.conv = torch.nn.Conv1d(
65
+ in_channels,
66
+ out_channels,
67
+ kernel_size=kernel_size,
68
+ stride=stride,
69
+ padding=padding,
70
+ dilation=dilation,
71
+ bias=bias,
72
+ )
73
+
74
+ torch.nn.init.xavier_uniform_(
75
+ self.conv.weight,
76
+ gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
77
+ )
78
+
79
+ def forward(self, signal):
80
+ conv_signal = self.conv(signal)
81
+ return conv_signal
82
+
83
+
84
+ class CausualConv(nn.Module):
85
+ def __init__(
86
+ self,
87
+ in_channels,
88
+ out_channels,
89
+ kernel_size=1,
90
+ stride=1,
91
+ padding=1,
92
+ dilation=1,
93
+ bias=True,
94
+ w_init_gain="linear",
95
+ param=None,
96
+ ):
97
+ super(CausualConv, self).__init__()
98
+ if padding is None:
99
+ assert kernel_size % 2 == 1
100
+ padding = int(dilation * (kernel_size - 1) / 2) * 2
101
+ else:
102
+ self.padding = padding * 2
103
+ self.conv = nn.Conv1d(
104
+ in_channels,
105
+ out_channels,
106
+ kernel_size=kernel_size,
107
+ stride=stride,
108
+ padding=self.padding,
109
+ dilation=dilation,
110
+ bias=bias,
111
+ )
112
+
113
+ torch.nn.init.xavier_uniform_(
114
+ self.conv.weight,
115
+ gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
116
+ )
117
+
118
+ def forward(self, x):
119
+ x = self.conv(x)
120
+ x = x[:, :, : -self.padding]
121
+ return x
122
+
123
+
124
+ class CausualBlock(nn.Module):
125
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
126
+ super(CausualBlock, self).__init__()
127
+ self.blocks = nn.ModuleList(
128
+ [
129
+ self._get_conv(
130
+ hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
131
+ )
132
+ for i in range(n_conv)
133
+ ]
134
+ )
135
+
136
+ def forward(self, x):
137
+ for block in self.blocks:
138
+ res = x
139
+ x = block(x)
140
+ x += res
141
+ return x
142
+
143
+ def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
144
+ layers = [
145
+ CausualConv(
146
+ hidden_dim,
147
+ hidden_dim,
148
+ kernel_size=3,
149
+ padding=dilation,
150
+ dilation=dilation,
151
+ ),
152
+ _get_activation_fn(activ),
153
+ nn.BatchNorm1d(hidden_dim),
154
+ nn.Dropout(p=dropout_p),
155
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
156
+ _get_activation_fn(activ),
157
+ nn.Dropout(p=dropout_p),
158
+ ]
159
+ return nn.Sequential(*layers)
160
+
161
+
162
+ class ConvBlock(nn.Module):
163
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
164
+ super().__init__()
165
+ self._n_groups = 8
166
+ self.blocks = nn.ModuleList(
167
+ [
168
+ self._get_conv(
169
+ hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
170
+ )
171
+ for i in range(n_conv)
172
+ ]
173
+ )
174
+
175
+ def forward(self, x):
176
+ for block in self.blocks:
177
+ res = x
178
+ x = block(x)
179
+ x += res
180
+ return x
181
+
182
+ def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
183
+ layers = [
184
+ ConvNorm(
185
+ hidden_dim,
186
+ hidden_dim,
187
+ kernel_size=3,
188
+ padding=dilation,
189
+ dilation=dilation,
190
+ ),
191
+ _get_activation_fn(activ),
192
+ nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
193
+ nn.Dropout(p=dropout_p),
194
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
195
+ _get_activation_fn(activ),
196
+ nn.Dropout(p=dropout_p),
197
+ ]
198
+ return nn.Sequential(*layers)
199
+
200
+
201
+ class LocationLayer(nn.Module):
202
+ def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
203
+ super(LocationLayer, self).__init__()
204
+ padding = int((attention_kernel_size - 1) / 2)
205
+ self.location_conv = ConvNorm(
206
+ 2,
207
+ attention_n_filters,
208
+ kernel_size=attention_kernel_size,
209
+ padding=padding,
210
+ bias=False,
211
+ stride=1,
212
+ dilation=1,
213
+ )
214
+ self.location_dense = LinearNorm(
215
+ attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
216
+ )
217
+
218
+ def forward(self, attention_weights_cat):
219
+ processed_attention = self.location_conv(attention_weights_cat)
220
+ processed_attention = processed_attention.transpose(1, 2)
221
+ processed_attention = self.location_dense(processed_attention)
222
+ return processed_attention
223
+
224
+
225
+ class Attention(nn.Module):
226
+ def __init__(
227
+ self,
228
+ attention_rnn_dim,
229
+ embedding_dim,
230
+ attention_dim,
231
+ attention_location_n_filters,
232
+ attention_location_kernel_size,
233
+ ):
234
+ super(Attention, self).__init__()
235
+ self.query_layer = LinearNorm(
236
+ attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
237
+ )
238
+ self.memory_layer = LinearNorm(
239
+ embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
240
+ )
241
+ self.v = LinearNorm(attention_dim, 1, bias=False)
242
+ self.location_layer = LocationLayer(
243
+ attention_location_n_filters, attention_location_kernel_size, attention_dim
244
+ )
245
+ self.score_mask_value = -float("inf")
246
+
247
+ def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
248
+ """
249
+ PARAMS
250
+ ------
251
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
252
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
253
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
254
+ RETURNS
255
+ -------
256
+ alignment (batch, max_time)
257
+ """
258
+
259
+ processed_query = self.query_layer(query.unsqueeze(1))
260
+ processed_attention_weights = self.location_layer(attention_weights_cat)
261
+ energies = self.v(
262
+ torch.tanh(processed_query + processed_attention_weights + processed_memory)
263
+ )
264
+
265
+ energies = energies.squeeze(-1)
266
+ return energies
267
+
268
+ def forward(
269
+ self,
270
+ attention_hidden_state,
271
+ memory,
272
+ processed_memory,
273
+ attention_weights_cat,
274
+ mask,
275
+ ):
276
+ """
277
+ PARAMS
278
+ ------
279
+ attention_hidden_state: attention rnn last output
280
+ memory: encoder outputs
281
+ processed_memory: processed encoder outputs
282
+ attention_weights_cat: previous and cummulative attention weights
283
+ mask: binary mask for padded data
284
+ """
285
+ alignment = self.get_alignment_energies(
286
+ attention_hidden_state, processed_memory, attention_weights_cat
287
+ )
288
+
289
+ if mask is not None:
290
+ alignment.data.masked_fill_(mask, self.score_mask_value)
291
+
292
+ attention_weights = F.softmax(alignment, dim=1)
293
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
294
+ attention_context = attention_context.squeeze(1)
295
+
296
+ return attention_context, attention_weights
297
+
298
+
299
+ class ForwardAttentionV2(nn.Module):
300
+ def __init__(
301
+ self,
302
+ attention_rnn_dim,
303
+ embedding_dim,
304
+ attention_dim,
305
+ attention_location_n_filters,
306
+ attention_location_kernel_size,
307
+ ):
308
+ super(ForwardAttentionV2, self).__init__()
309
+ self.query_layer = LinearNorm(
310
+ attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
311
+ )
312
+ self.memory_layer = LinearNorm(
313
+ embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
314
+ )
315
+ self.v = LinearNorm(attention_dim, 1, bias=False)
316
+ self.location_layer = LocationLayer(
317
+ attention_location_n_filters, attention_location_kernel_size, attention_dim
318
+ )
319
+ self.score_mask_value = -float(1e20)
320
+
321
+ def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
322
+ """
323
+ PARAMS
324
+ ------
325
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
326
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
327
+ attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
328
+ RETURNS
329
+ -------
330
+ alignment (batch, max_time)
331
+ """
332
+
333
+ processed_query = self.query_layer(query.unsqueeze(1))
334
+ processed_attention_weights = self.location_layer(attention_weights_cat)
335
+ energies = self.v(
336
+ torch.tanh(processed_query + processed_attention_weights + processed_memory)
337
+ )
338
+
339
+ energies = energies.squeeze(-1)
340
+ return energies
341
+
342
+ def forward(
343
+ self,
344
+ attention_hidden_state,
345
+ memory,
346
+ processed_memory,
347
+ attention_weights_cat,
348
+ mask,
349
+ log_alpha,
350
+ ):
351
+ """
352
+ PARAMS
353
+ ------
354
+ attention_hidden_state: attention rnn last output
355
+ memory: encoder outputs
356
+ processed_memory: processed encoder outputs
357
+ attention_weights_cat: previous and cummulative attention weights
358
+ mask: binary mask for padded data
359
+ """
360
+ log_energy = self.get_alignment_energies(
361
+ attention_hidden_state, processed_memory, attention_weights_cat
362
+ )
363
+
364
+ # log_energy =
365
+
366
+ if mask is not None:
367
+ log_energy.data.masked_fill_(mask, self.score_mask_value)
368
+
369
+ # attention_weights = F.softmax(alignment, dim=1)
370
+
371
+ # content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
372
+ # log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
373
+
374
+ # log_total_score = log_alpha + content_score
375
+
376
+ # previous_attention_weights = attention_weights_cat[:,0,:]
377
+
378
+ log_alpha_shift_padded = []
379
+ max_time = log_energy.size(1)
380
+ for sft in range(2):
381
+ shifted = log_alpha[:, : max_time - sft]
382
+ shift_padded = F.pad(shifted, (sft, 0), "constant", self.score_mask_value)
383
+ log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
384
+
385
+ biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)
386
+
387
+ log_alpha_new = biased + log_energy
388
+
389
+ attention_weights = F.softmax(log_alpha_new, dim=1)
390
+
391
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
392
+ attention_context = attention_context.squeeze(1)
393
+
394
+ return attention_context, attention_weights, log_alpha_new
395
+
396
+
397
+ class PhaseShuffle2d(nn.Module):
398
+ def __init__(self, n=2):
399
+ super(PhaseShuffle2d, self).__init__()
400
+ self.n = n
401
+ self.random = random.Random(1)
402
+
403
+ def forward(self, x, move=None):
404
+ # x.size = (B, C, M, L)
405
+ if move is None:
406
+ move = self.random.randint(-self.n, self.n)
407
+
408
+ if move == 0:
409
+ return x
410
+ else:
411
+ left = x[:, :, :, :move]
412
+ right = x[:, :, :, move:]
413
+ shuffled = torch.cat([right, left], dim=3)
414
+ return shuffled
415
+
416
+
417
+ class PhaseShuffle1d(nn.Module):
418
+ def __init__(self, n=2):
419
+ super(PhaseShuffle1d, self).__init__()
420
+ self.n = n
421
+ self.random = random.Random(1)
422
+
423
+ def forward(self, x, move=None):
424
+ # x.size = (B, C, M, L)
425
+ if move is None:
426
+ move = self.random.randint(-self.n, self.n)
427
+
428
+ if move == 0:
429
+ return x
430
+ else:
431
+ left = x[:, :, :move]
432
+ right = x[:, :, move:]
433
+ shuffled = torch.cat([right, left], dim=2)
434
+
435
+ return shuffled
436
+
437
+
438
+ class MFCC(nn.Module):
439
+ def __init__(self, n_mfcc=40, n_mels=80):
440
+ super(MFCC, self).__init__()
441
+ self.n_mfcc = n_mfcc
442
+ self.n_mels = n_mels
443
+ self.norm = "ortho"
444
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
445
+ self.register_buffer("dct_mat", dct_mat)
446
+
447
+ def forward(self, mel_specgram):
448
+ if len(mel_specgram.shape) == 2:
449
+ mel_specgram = mel_specgram.unsqueeze(0)
450
+ unsqueezed = True
451
+ else:
452
+ unsqueezed = False
453
+ # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
454
+ # -> (channel, time, n_mfcc).tranpose(...)
455
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
456
+
457
+ # unpack batch
458
+ if unsqueezed:
459
+ mfcc = mfcc.squeeze(0)
460
+ return mfcc
models/codec/facodec/modules/quantize.py ADDED
@@ -0,0 +1,741 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from modules.dac.nn.quantize import ResidualVectorQuantize
7
+ from torch import nn
8
+ from .wavenet import WN
9
+ from .style_encoder import StyleEncoder
10
+ from .gradient_reversal import GradientReversal
11
+ import torch
12
+ import torchaudio
13
+ import torchaudio.functional as audio_F
14
+ import numpy as np
15
+ from ..alias_free_torch import *
16
+ from torch.nn.utils import weight_norm
17
+ from torch import nn, sin, pow
18
+ from einops.layers.torch import Rearrange
19
+ from modules.dac.model.encodec import SConv1d
20
+
21
+
22
+ def init_weights(m):
23
+ if isinstance(m, nn.Conv1d):
24
+ nn.init.trunc_normal_(m.weight, std=0.02)
25
+ nn.init.constant_(m.bias, 0)
26
+
27
+
28
+ def WNConv1d(*args, **kwargs):
29
+ return weight_norm(nn.Conv1d(*args, **kwargs))
30
+
31
+
32
+ def WNConvTranspose1d(*args, **kwargs):
33
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
34
+
35
+
36
+ class SnakeBeta(nn.Module):
37
+ """
38
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
39
+ Shape:
40
+ - Input: (B, C, T)
41
+ - Output: (B, C, T), same shape as the input
42
+ Parameters:
43
+ - alpha - trainable parameter that controls frequency
44
+ - beta - trainable parameter that controls magnitude
45
+ References:
46
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
47
+ https://arxiv.org/abs/2006.08195
48
+ Examples:
49
+ >>> a1 = snakebeta(256)
50
+ >>> x = torch.randn(256)
51
+ >>> x = a1(x)
52
+ """
53
+
54
+ def __init__(
55
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
56
+ ):
57
+ """
58
+ Initialization.
59
+ INPUT:
60
+ - in_features: shape of the input
61
+ - alpha - trainable parameter that controls frequency
62
+ - beta - trainable parameter that controls magnitude
63
+ alpha is initialized to 1 by default, higher values = higher-frequency.
64
+ beta is initialized to 1 by default, higher values = higher-magnitude.
65
+ alpha will be trained along with the rest of your model.
66
+ """
67
+ super(SnakeBeta, self).__init__()
68
+ self.in_features = in_features
69
+
70
+ # initialize alpha
71
+ self.alpha_logscale = alpha_logscale
72
+ if self.alpha_logscale: # log scale alphas initialized to zeros
73
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
74
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
75
+ else: # linear scale alphas initialized to ones
76
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
77
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
78
+
79
+ self.alpha.requires_grad = alpha_trainable
80
+ self.beta.requires_grad = alpha_trainable
81
+
82
+ self.no_div_by_zero = 0.000000001
83
+
84
+ def forward(self, x):
85
+ """
86
+ Forward pass of the function.
87
+ Applies the function to the input elementwise.
88
+ SnakeBeta := x + 1/b * sin^2 (xa)
89
+ """
90
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
91
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
92
+ if self.alpha_logscale:
93
+ alpha = torch.exp(alpha)
94
+ beta = torch.exp(beta)
95
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
96
+
97
+ return x
98
+
99
+
100
+ class ResidualUnit(nn.Module):
101
+ def __init__(self, dim: int = 16, dilation: int = 1):
102
+ super().__init__()
103
+ pad = ((7 - 1) * dilation) // 2
104
+ self.block = nn.Sequential(
105
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
106
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
107
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
108
+ WNConv1d(dim, dim, kernel_size=1),
109
+ )
110
+
111
+ def forward(self, x):
112
+ return x + self.block(x)
113
+
114
+
115
+ class CNNLSTM(nn.Module):
116
+ def __init__(self, indim, outdim, head, global_pred=False):
117
+ super().__init__()
118
+ self.global_pred = global_pred
119
+ self.model = nn.Sequential(
120
+ ResidualUnit(indim, dilation=1),
121
+ ResidualUnit(indim, dilation=2),
122
+ ResidualUnit(indim, dilation=3),
123
+ Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
124
+ Rearrange("b c t -> b t c"),
125
+ )
126
+ self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
127
+
128
+ def forward(self, x):
129
+ # x: [B, C, T]
130
+ x = self.model(x)
131
+ if self.global_pred:
132
+ x = torch.mean(x, dim=1, keepdim=False)
133
+ outs = [head(x) for head in self.heads]
134
+ return outs
135
+
136
+
137
+ def sequence_mask(length, max_length=None):
138
+ if max_length is None:
139
+ max_length = length.max()
140
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
141
+ return x.unsqueeze(0) < length.unsqueeze(1)
142
+
143
+
144
+ class MFCC(nn.Module):
145
+ def __init__(self, n_mfcc=40, n_mels=80):
146
+ super(MFCC, self).__init__()
147
+ self.n_mfcc = n_mfcc
148
+ self.n_mels = n_mels
149
+ self.norm = "ortho"
150
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
151
+ self.register_buffer("dct_mat", dct_mat)
152
+
153
+ def forward(self, mel_specgram):
154
+ if len(mel_specgram.shape) == 2:
155
+ mel_specgram = mel_specgram.unsqueeze(0)
156
+ unsqueezed = True
157
+ else:
158
+ unsqueezed = False
159
+ # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
160
+ # -> (channel, time, n_mfcc).tranpose(...)
161
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
162
+
163
+ # unpack batch
164
+ if unsqueezed:
165
+ mfcc = mfcc.squeeze(0)
166
+ return mfcc
167
+
168
+
169
+ class FAquantizer(nn.Module):
170
+ def __init__(
171
+ self,
172
+ in_dim=1024,
173
+ n_p_codebooks=1,
174
+ n_c_codebooks=2,
175
+ n_t_codebooks=2,
176
+ n_r_codebooks=3,
177
+ codebook_size=1024,
178
+ codebook_dim=8,
179
+ quantizer_dropout=0.5,
180
+ causal=False,
181
+ separate_prosody_encoder=False,
182
+ timbre_norm=False,
183
+ ):
184
+ super(FAquantizer, self).__init__()
185
+ conv1d_type = SConv1d # if causal else nn.Conv1d
186
+ self.prosody_quantizer = ResidualVectorQuantize(
187
+ input_dim=in_dim,
188
+ n_codebooks=n_p_codebooks,
189
+ codebook_size=codebook_size,
190
+ codebook_dim=codebook_dim,
191
+ quantizer_dropout=quantizer_dropout,
192
+ )
193
+
194
+ self.content_quantizer = ResidualVectorQuantize(
195
+ input_dim=in_dim,
196
+ n_codebooks=n_c_codebooks,
197
+ codebook_size=codebook_size,
198
+ codebook_dim=codebook_dim,
199
+ quantizer_dropout=quantizer_dropout,
200
+ )
201
+
202
+ if not timbre_norm:
203
+ self.timbre_quantizer = ResidualVectorQuantize(
204
+ input_dim=in_dim,
205
+ n_codebooks=n_t_codebooks,
206
+ codebook_size=codebook_size,
207
+ codebook_dim=codebook_dim,
208
+ quantizer_dropout=quantizer_dropout,
209
+ )
210
+ else:
211
+ self.timbre_encoder = StyleEncoder(
212
+ in_dim=80, hidden_dim=512, out_dim=in_dim
213
+ )
214
+ self.timbre_linear = nn.Linear(1024, 1024 * 2)
215
+ self.timbre_linear.bias.data[:1024] = 1
216
+ self.timbre_linear.bias.data[1024:] = 0
217
+ self.timbre_norm = nn.LayerNorm(1024, elementwise_affine=False)
218
+
219
+ self.residual_quantizer = ResidualVectorQuantize(
220
+ input_dim=in_dim,
221
+ n_codebooks=n_r_codebooks,
222
+ codebook_size=codebook_size,
223
+ codebook_dim=codebook_dim,
224
+ quantizer_dropout=quantizer_dropout,
225
+ )
226
+
227
+ if separate_prosody_encoder:
228
+ self.melspec_linear = conv1d_type(
229
+ in_channels=20, out_channels=256, kernel_size=1, causal=causal
230
+ )
231
+ self.melspec_encoder = WN(
232
+ hidden_channels=256,
233
+ kernel_size=5,
234
+ dilation_rate=1,
235
+ n_layers=8,
236
+ gin_channels=0,
237
+ p_dropout=0.2,
238
+ causal=causal,
239
+ )
240
+ self.melspec_linear2 = conv1d_type(
241
+ in_channels=256, out_channels=1024, kernel_size=1, causal=causal
242
+ )
243
+ else:
244
+ pass
245
+ self.separate_prosody_encoder = separate_prosody_encoder
246
+
247
+ self.prob_random_mask_residual = 0.75
248
+
249
+ SPECT_PARAMS = {
250
+ "n_fft": 2048,
251
+ "win_length": 1200,
252
+ "hop_length": 300,
253
+ }
254
+ MEL_PARAMS = {
255
+ "n_mels": 80,
256
+ }
257
+
258
+ self.to_mel = torchaudio.transforms.MelSpectrogram(
259
+ n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
260
+ )
261
+ self.mel_mean, self.mel_std = -4, 4
262
+ self.frame_rate = 24000 / 300
263
+ self.hop_length = 300
264
+
265
+ self.is_timbre_norm = timbre_norm
266
+ if timbre_norm:
267
+ self.forward = self.forward_v2
268
+
269
+ def preprocess(self, wave_tensor, n_bins=20):
270
+ mel_tensor = self.to_mel(wave_tensor.squeeze(1))
271
+ mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
272
+ return mel_tensor[:, :n_bins, : int(wave_tensor.size(-1) / self.hop_length)]
273
+
274
+ @torch.no_grad()
275
+ def decode(self, codes):
276
+ code_c, code_p, code_t = codes.split([1, 1, 2], dim=1)
277
+
278
+ z_c = self.content_quantizer.from_codes(code_c)[0]
279
+ z_p = self.prosody_quantizer.from_codes(code_p)[0]
280
+ z_t = self.timbre_quantizer.from_codes(code_t)[0]
281
+
282
+ z = z_c + z_p + z_t
283
+
284
+ return z, [z_c, z_p, z_t]
285
+
286
+ @torch.no_grad()
287
+ def encode(self, x, wave_segments, n_c=1):
288
+ outs = 0
289
+ if self.separate_prosody_encoder:
290
+ prosody_feature = self.preprocess(wave_segments)
291
+
292
+ f0_input = prosody_feature # (B, T, 20)
293
+ f0_input = self.melspec_linear(f0_input)
294
+ f0_input = self.melspec_encoder(
295
+ f0_input,
296
+ torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
297
+ .to(f0_input.device)
298
+ .bool(),
299
+ )
300
+ f0_input = self.melspec_linear2(f0_input)
301
+
302
+ common_min_size = min(f0_input.size(2), x.size(2))
303
+ f0_input = f0_input[:, :, :common_min_size]
304
+
305
+ x = x[:, :, :common_min_size]
306
+
307
+ (
308
+ z_p,
309
+ codes_p,
310
+ latents_p,
311
+ commitment_loss_p,
312
+ codebook_loss_p,
313
+ ) = self.prosody_quantizer(f0_input, 1)
314
+ outs += z_p.detach()
315
+ else:
316
+ (
317
+ z_p,
318
+ codes_p,
319
+ latents_p,
320
+ commitment_loss_p,
321
+ codebook_loss_p,
322
+ ) = self.prosody_quantizer(x, 1)
323
+ outs += z_p.detach()
324
+
325
+ (
326
+ z_c,
327
+ codes_c,
328
+ latents_c,
329
+ commitment_loss_c,
330
+ codebook_loss_c,
331
+ ) = self.content_quantizer(x, n_c)
332
+ outs += z_c.detach()
333
+
334
+ timbre_residual_feature = x - z_p.detach() - z_c.detach()
335
+
336
+ (
337
+ z_t,
338
+ codes_t,
339
+ latents_t,
340
+ commitment_loss_t,
341
+ codebook_loss_t,
342
+ ) = self.timbre_quantizer(timbre_residual_feature, 2)
343
+ outs += z_t # we should not detach timbre
344
+
345
+ residual_feature = timbre_residual_feature - z_t
346
+
347
+ (
348
+ z_r,
349
+ codes_r,
350
+ latents_r,
351
+ commitment_loss_r,
352
+ codebook_loss_r,
353
+ ) = self.residual_quantizer(residual_feature, 3)
354
+
355
+ return [codes_c, codes_p, codes_t, codes_r], [z_c, z_p, z_t, z_r]
356
+
357
+ def forward(
358
+ self, x, wave_segments, noise_added_flags, recon_noisy_flags, n_c=2, n_t=2
359
+ ):
360
+ # timbre = self.timbre_encoder(mels, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
361
+ # timbre = self.timbre_encoder(mel_segments, torch.ones(mel_segments.size(0), 1, mel_segments.size(2)).bool().to(mel_segments.device))
362
+ outs = 0
363
+ if self.separate_prosody_encoder:
364
+ prosody_feature = self.preprocess(wave_segments)
365
+
366
+ f0_input = prosody_feature # (B, T, 20)
367
+ f0_input = self.melspec_linear(f0_input)
368
+ f0_input = self.melspec_encoder(
369
+ f0_input,
370
+ torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
371
+ .to(f0_input.device)
372
+ .bool(),
373
+ )
374
+ f0_input = self.melspec_linear2(f0_input)
375
+
376
+ common_min_size = min(f0_input.size(2), x.size(2))
377
+ f0_input = f0_input[:, :, :common_min_size]
378
+
379
+ x = x[:, :, :common_min_size]
380
+
381
+ (
382
+ z_p,
383
+ codes_p,
384
+ latents_p,
385
+ commitment_loss_p,
386
+ codebook_loss_p,
387
+ ) = self.prosody_quantizer(f0_input, 1)
388
+ outs += z_p.detach()
389
+ else:
390
+ (
391
+ z_p,
392
+ codes_p,
393
+ latents_p,
394
+ commitment_loss_p,
395
+ codebook_loss_p,
396
+ ) = self.prosody_quantizer(x, 1)
397
+ outs += z_p.detach()
398
+
399
+ (
400
+ z_c,
401
+ codes_c,
402
+ latents_c,
403
+ commitment_loss_c,
404
+ codebook_loss_c,
405
+ ) = self.content_quantizer(x, n_c)
406
+ outs += z_c.detach()
407
+
408
+ timbre_residual_feature = x - z_p.detach() - z_c.detach()
409
+
410
+ (
411
+ z_t,
412
+ codes_t,
413
+ latents_t,
414
+ commitment_loss_t,
415
+ codebook_loss_t,
416
+ ) = self.timbre_quantizer(timbre_residual_feature, n_t)
417
+ outs += z_t # we should not detach timbre
418
+
419
+ residual_feature = timbre_residual_feature - z_t
420
+
421
+ (
422
+ z_r,
423
+ codes_r,
424
+ latents_r,
425
+ commitment_loss_r,
426
+ codebook_loss_r,
427
+ ) = self.residual_quantizer(residual_feature, 3)
428
+
429
+ bsz = z_r.shape[0]
430
+ res_mask = np.random.choice(
431
+ [0, 1],
432
+ size=bsz,
433
+ p=[
434
+ self.prob_random_mask_residual,
435
+ 1 - self.prob_random_mask_residual,
436
+ ],
437
+ )
438
+ res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1)
439
+ res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)
440
+ noise_must_on = noise_added_flags * recon_noisy_flags
441
+ noise_must_off = noise_added_flags * (~recon_noisy_flags)
442
+ res_mask[noise_must_on] = 1
443
+ res_mask[noise_must_off] = 0
444
+
445
+ outs += z_r * res_mask
446
+
447
+ quantized = [z_p, z_c, z_t, z_r]
448
+ commitment_losses = (
449
+ commitment_loss_p
450
+ + commitment_loss_c
451
+ + commitment_loss_t
452
+ + commitment_loss_r
453
+ )
454
+ codebook_losses = (
455
+ codebook_loss_p + codebook_loss_c + codebook_loss_t + codebook_loss_r
456
+ )
457
+
458
+ return outs, quantized, commitment_losses, codebook_losses
459
+
460
+ def forward_v2(
461
+ self,
462
+ x,
463
+ wave_segments,
464
+ n_c=1,
465
+ n_t=2,
466
+ full_waves=None,
467
+ wave_lens=None,
468
+ return_codes=False,
469
+ ):
470
+ # timbre = self.timbre_encoder(x, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
471
+ if full_waves is None:
472
+ mel = self.preprocess(wave_segments, n_bins=80)
473
+ timbre = self.timbre_encoder(
474
+ mel, torch.ones(mel.size(0), 1, mel.size(2)).bool().to(mel.device)
475
+ )
476
+ else:
477
+ mel = self.preprocess(full_waves, n_bins=80)
478
+ timbre = self.timbre_encoder(
479
+ mel,
480
+ sequence_mask(wave_lens // self.hop_length, mel.size(-1)).unsqueeze(1),
481
+ )
482
+ outs = 0
483
+ if self.separate_prosody_encoder:
484
+ prosody_feature = self.preprocess(wave_segments)
485
+
486
+ f0_input = prosody_feature # (B, T, 20)
487
+ f0_input = self.melspec_linear(f0_input)
488
+ f0_input = self.melspec_encoder(
489
+ f0_input,
490
+ torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
491
+ .to(f0_input.device)
492
+ .bool(),
493
+ )
494
+ f0_input = self.melspec_linear2(f0_input)
495
+
496
+ common_min_size = min(f0_input.size(2), x.size(2))
497
+ f0_input = f0_input[:, :, :common_min_size]
498
+
499
+ x = x[:, :, :common_min_size]
500
+
501
+ (
502
+ z_p,
503
+ codes_p,
504
+ latents_p,
505
+ commitment_loss_p,
506
+ codebook_loss_p,
507
+ ) = self.prosody_quantizer(f0_input, 1)
508
+ outs += z_p.detach()
509
+ else:
510
+ (
511
+ z_p,
512
+ codes_p,
513
+ latents_p,
514
+ commitment_loss_p,
515
+ codebook_loss_p,
516
+ ) = self.prosody_quantizer(x, 1)
517
+ outs += z_p.detach()
518
+
519
+ (
520
+ z_c,
521
+ codes_c,
522
+ latents_c,
523
+ commitment_loss_c,
524
+ codebook_loss_c,
525
+ ) = self.content_quantizer(x, n_c)
526
+ outs += z_c.detach()
527
+
528
+ residual_feature = x - z_p.detach() - z_c.detach()
529
+
530
+ (
531
+ z_r,
532
+ codes_r,
533
+ latents_r,
534
+ commitment_loss_r,
535
+ codebook_loss_r,
536
+ ) = self.residual_quantizer(residual_feature, 3)
537
+
538
+ bsz = z_r.shape[0]
539
+ res_mask = np.random.choice(
540
+ [0, 1],
541
+ size=bsz,
542
+ p=[
543
+ self.prob_random_mask_residual,
544
+ 1 - self.prob_random_mask_residual,
545
+ ],
546
+ )
547
+ res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1)
548
+ res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)
549
+
550
+ if not self.training:
551
+ res_mask = torch.ones_like(res_mask)
552
+ outs += z_r * res_mask
553
+
554
+ quantized = [z_p, z_c, z_r]
555
+ codes = [codes_p, codes_c, codes_r]
556
+ commitment_losses = commitment_loss_p + commitment_loss_c + commitment_loss_r
557
+ codebook_losses = codebook_loss_p + codebook_loss_c + codebook_loss_r
558
+
559
+ style = self.timbre_linear(timbre).unsqueeze(2) # (B, 2d, 1)
560
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
561
+ outs = outs.transpose(1, 2)
562
+ outs = self.timbre_norm(outs)
563
+ outs = outs.transpose(1, 2)
564
+ outs = outs * gamma + beta
565
+
566
+ if return_codes:
567
+ return outs, quantized, commitment_losses, codebook_losses, timbre, codes
568
+ else:
569
+ return outs, quantized, commitment_losses, codebook_losses, timbre
570
+
571
+ def voice_conversion(self, z, ref_wave):
572
+ ref_mel = self.preprocess(ref_wave, n_bins=80)
573
+ ref_timbre = self.timbre_encoder(
574
+ ref_mel,
575
+ sequence_mask(
576
+ torch.LongTensor([ref_wave.size(-1)]).to(z.device) // self.hop_length,
577
+ ref_mel.size(-1),
578
+ ).unsqueeze(1),
579
+ )
580
+ style = self.timbre_linear(ref_timbre).unsqueeze(2) # (B, 2d, 1)
581
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
582
+ outs = z.transpose(1, 2)
583
+ outs = self.timbre_norm(outs)
584
+ outs = outs.transpose(1, 2)
585
+ outs = outs * gamma + beta
586
+
587
+ return outs
588
+
589
+
590
+ class FApredictors(nn.Module):
591
+ def __init__(
592
+ self,
593
+ in_dim=1024,
594
+ use_gr_content_f0=False,
595
+ use_gr_prosody_phone=False,
596
+ use_gr_residual_f0=False,
597
+ use_gr_residual_phone=False,
598
+ use_gr_timbre_content=True,
599
+ use_gr_timbre_prosody=True,
600
+ use_gr_x_timbre=False,
601
+ norm_f0=True,
602
+ timbre_norm=False,
603
+ use_gr_content_global_f0=False,
604
+ ):
605
+ super(FApredictors, self).__init__()
606
+ self.f0_predictor = CNNLSTM(in_dim, 1, 2)
607
+ self.phone_predictor = CNNLSTM(in_dim, 1024, 1)
608
+ if timbre_norm:
609
+ self.timbre_predictor = nn.Linear(in_dim, 20000)
610
+ else:
611
+ self.timbre_predictor = CNNLSTM(in_dim, 20000, 1, global_pred=True)
612
+
613
+ self.use_gr_content_f0 = use_gr_content_f0
614
+ self.use_gr_prosody_phone = use_gr_prosody_phone
615
+ self.use_gr_residual_f0 = use_gr_residual_f0
616
+ self.use_gr_residual_phone = use_gr_residual_phone
617
+ self.use_gr_timbre_content = use_gr_timbre_content
618
+ self.use_gr_timbre_prosody = use_gr_timbre_prosody
619
+ self.use_gr_x_timbre = use_gr_x_timbre
620
+
621
+ self.rev_f0_predictor = nn.Sequential(
622
+ GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 2)
623
+ )
624
+ self.rev_content_predictor = nn.Sequential(
625
+ GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1024, 1)
626
+ )
627
+ self.rev_timbre_predictor = nn.Sequential(
628
+ GradientReversal(alpha=1.0), CNNLSTM(in_dim, 20000, 1, global_pred=True)
629
+ )
630
+
631
+ self.norm_f0 = norm_f0
632
+ self.timbre_norm = timbre_norm
633
+ if timbre_norm:
634
+ self.forward = self.forward_v2
635
+ self.global_f0_predictor = nn.Linear(in_dim, 1)
636
+
637
+ self.use_gr_content_global_f0 = use_gr_content_global_f0
638
+ if use_gr_content_global_f0:
639
+ self.rev_global_f0_predictor = nn.Sequential(
640
+ GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 1, global_pred=True)
641
+ )
642
+
643
+ def forward(self, quantized):
644
+ prosody_latent = quantized[0]
645
+ content_latent = quantized[1]
646
+ timbre_latent = quantized[2]
647
+ residual_latent = quantized[3]
648
+ content_pred = self.phone_predictor(content_latent)[0]
649
+
650
+ if self.norm_f0:
651
+ spk_pred = self.timbre_predictor(timbre_latent)[0]
652
+ f0_pred, uv_pred = self.f0_predictor(prosody_latent)
653
+ else:
654
+ spk_pred = self.timbre_predictor(timbre_latent + prosody_latent)[0]
655
+ f0_pred, uv_pred = self.f0_predictor(prosody_latent + timbre_latent)
656
+
657
+ prosody_rev_latent = torch.zeros_like(quantized[0])
658
+ if self.use_gr_content_f0:
659
+ prosody_rev_latent += quantized[1]
660
+ if self.use_gr_timbre_prosody:
661
+ prosody_rev_latent += quantized[2]
662
+ if self.use_gr_residual_f0:
663
+ prosody_rev_latent += quantized[3]
664
+ rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
665
+
666
+ content_rev_latent = torch.zeros_like(quantized[1])
667
+ if self.use_gr_prosody_phone:
668
+ content_rev_latent += quantized[0]
669
+ if self.use_gr_timbre_content:
670
+ content_rev_latent += quantized[2]
671
+ if self.use_gr_residual_phone:
672
+ content_rev_latent += quantized[3]
673
+ rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
674
+
675
+ if self.norm_f0:
676
+ timbre_rev_latent = quantized[0] + quantized[1] + quantized[3]
677
+ else:
678
+ timbre_rev_latent = quantized[1] + quantized[3]
679
+ if self.use_gr_x_timbre:
680
+ x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
681
+ else:
682
+ x_spk_pred = None
683
+
684
+ preds = {
685
+ "f0": f0_pred,
686
+ "uv": uv_pred,
687
+ "content": content_pred,
688
+ "timbre": spk_pred,
689
+ }
690
+
691
+ rev_preds = {
692
+ "rev_f0": rev_f0_pred,
693
+ "rev_uv": rev_uv_pred,
694
+ "rev_content": rev_content_pred,
695
+ "x_timbre": x_spk_pred,
696
+ }
697
+ return preds, rev_preds
698
+
699
+ def forward_v2(self, quantized, timbre):
700
+ prosody_latent = quantized[0]
701
+ content_latent = quantized[1]
702
+ residual_latent = quantized[2]
703
+ content_pred = self.phone_predictor(content_latent)[0]
704
+
705
+ spk_pred = self.timbre_predictor(timbre)
706
+ f0_pred, uv_pred = self.f0_predictor(prosody_latent)
707
+
708
+ prosody_rev_latent = torch.zeros_like(prosody_latent)
709
+ if self.use_gr_content_f0:
710
+ prosody_rev_latent += content_latent
711
+ if self.use_gr_residual_f0:
712
+ prosody_rev_latent += residual_latent
713
+ rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
714
+
715
+ content_rev_latent = torch.zeros_like(content_latent)
716
+ if self.use_gr_prosody_phone:
717
+ content_rev_latent += prosody_latent
718
+ if self.use_gr_residual_phone:
719
+ content_rev_latent += residual_latent
720
+ rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
721
+
722
+ timbre_rev_latent = prosody_latent + content_latent + residual_latent
723
+ if self.use_gr_x_timbre:
724
+ x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
725
+ else:
726
+ x_spk_pred = None
727
+
728
+ preds = {
729
+ "f0": f0_pred,
730
+ "uv": uv_pred,
731
+ "content": content_pred,
732
+ "timbre": spk_pred,
733
+ }
734
+
735
+ rev_preds = {
736
+ "rev_f0": rev_f0_pred,
737
+ "rev_uv": rev_uv_pred,
738
+ "rev_content": rev_content_pred,
739
+ "x_timbre": x_spk_pred,
740
+ }
741
+ return preds, rev_preds
models/codec/facodec/modules/style_encoder.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/styleencoder.py
7
+
8
+ from . import attentions
9
+ from torch import nn
10
+ import torch
11
+ from torch.nn import functional as F
12
+
13
+
14
+ class Mish(nn.Module):
15
+ def __init__(self):
16
+ super(Mish, self).__init__()
17
+
18
+ def forward(self, x):
19
+ return x * torch.tanh(F.softplus(x))
20
+
21
+
22
+ class Conv1dGLU(nn.Module):
23
+ """
24
+ Conv1d + GLU(Gated Linear Unit) with residual connection.
25
+ For GLU refer to https://arxiv.org/abs/1612.08083 paper.
26
+ """
27
+
28
+ def __init__(self, in_channels, out_channels, kernel_size, dropout):
29
+ super(Conv1dGLU, self).__init__()
30
+ self.out_channels = out_channels
31
+ self.conv1 = nn.Conv1d(
32
+ in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2
33
+ )
34
+ self.dropout = nn.Dropout(dropout)
35
+
36
+ def forward(self, x):
37
+ residual = x
38
+ x = self.conv1(x)
39
+ x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
40
+ x = x1 * torch.sigmoid(x2)
41
+ x = residual + self.dropout(x)
42
+ return x
43
+
44
+
45
+ class StyleEncoder(torch.nn.Module):
46
+ def __init__(self, in_dim=513, hidden_dim=128, out_dim=256):
47
+
48
+ super().__init__()
49
+
50
+ self.in_dim = in_dim # Linear 513 wav2vec 2.0 1024
51
+ self.hidden_dim = hidden_dim
52
+ self.out_dim = out_dim
53
+ self.kernel_size = 5
54
+ self.n_head = 2
55
+ self.dropout = 0.1
56
+
57
+ self.spectral = nn.Sequential(
58
+ nn.Conv1d(self.in_dim, self.hidden_dim, 1),
59
+ Mish(),
60
+ nn.Dropout(self.dropout),
61
+ nn.Conv1d(self.hidden_dim, self.hidden_dim, 1),
62
+ Mish(),
63
+ nn.Dropout(self.dropout),
64
+ )
65
+
66
+ self.temporal = nn.Sequential(
67
+ Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
68
+ Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
69
+ )
70
+
71
+ self.slf_attn = attentions.MultiHeadAttention(
72
+ self.hidden_dim,
73
+ self.hidden_dim,
74
+ self.n_head,
75
+ p_dropout=self.dropout,
76
+ proximal_bias=False,
77
+ proximal_init=True,
78
+ )
79
+ self.atten_drop = nn.Dropout(self.dropout)
80
+ self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1)
81
+
82
+ def forward(self, x, mask=None):
83
+
84
+ # spectral
85
+ x = self.spectral(x) * mask
86
+ # temporal
87
+ x = self.temporal(x) * mask
88
+
89
+ # self-attention
90
+ attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1)
91
+ y = self.slf_attn(x, x, attn_mask=attn_mask)
92
+ x = x + self.atten_drop(y)
93
+
94
+ # fc
95
+ x = self.fc(x)
96
+
97
+ # temoral average pooling
98
+ w = self.temporal_avg_pool(x, mask=mask)
99
+
100
+ return w
101
+
102
+ def temporal_avg_pool(self, x, mask=None):
103
+ if mask is None:
104
+ out = torch.mean(x, dim=2)
105
+ else:
106
+ len_ = mask.sum(dim=2)
107
+ x = x.sum(dim=2)
108
+
109
+ out = torch.div(x, len_)
110
+ return out
models/codec/facodec/modules/wavenet.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/modules.py
7
+
8
+ import math
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+
13
+ from modules.dac.model.encodec import SConv1d
14
+
15
+ from . import commons
16
+
17
+ LRELU_SLOPE = 0.1
18
+
19
+
20
+ class LayerNorm(nn.Module):
21
+ def __init__(self, channels, eps=1e-5):
22
+ super().__init__()
23
+ self.channels = channels
24
+ self.eps = eps
25
+
26
+ self.gamma = nn.Parameter(torch.ones(channels))
27
+ self.beta = nn.Parameter(torch.zeros(channels))
28
+
29
+ def forward(self, x):
30
+ x = x.transpose(1, -1)
31
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32
+ return x.transpose(1, -1)
33
+
34
+
35
+ class ConvReluNorm(nn.Module):
36
+ def __init__(
37
+ self,
38
+ in_channels,
39
+ hidden_channels,
40
+ out_channels,
41
+ kernel_size,
42
+ n_layers,
43
+ p_dropout,
44
+ ):
45
+ super().__init__()
46
+ self.in_channels = in_channels
47
+ self.hidden_channels = hidden_channels
48
+ self.out_channels = out_channels
49
+ self.kernel_size = kernel_size
50
+ self.n_layers = n_layers
51
+ self.p_dropout = p_dropout
52
+ assert n_layers > 1, "Number of layers should be larger than 0."
53
+
54
+ self.conv_layers = nn.ModuleList()
55
+ self.norm_layers = nn.ModuleList()
56
+ self.conv_layers.append(
57
+ nn.Conv1d(
58
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
59
+ )
60
+ )
61
+ self.norm_layers.append(LayerNorm(hidden_channels))
62
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
63
+ for _ in range(n_layers - 1):
64
+ self.conv_layers.append(
65
+ nn.Conv1d(
66
+ hidden_channels,
67
+ hidden_channels,
68
+ kernel_size,
69
+ padding=kernel_size // 2,
70
+ )
71
+ )
72
+ self.norm_layers.append(LayerNorm(hidden_channels))
73
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
74
+ self.proj.weight.data.zero_()
75
+ self.proj.bias.data.zero_()
76
+
77
+ def forward(self, x, x_mask):
78
+ x_org = x
79
+ for i in range(self.n_layers):
80
+ x = self.conv_layers[i](x * x_mask)
81
+ x = self.norm_layers[i](x)
82
+ x = self.relu_drop(x)
83
+ x = x_org + self.proj(x)
84
+ return x * x_mask
85
+
86
+
87
+ class DDSConv(nn.Module):
88
+ """
89
+ Dialted and Depth-Separable Convolution
90
+ """
91
+
92
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
93
+ super().__init__()
94
+ self.channels = channels
95
+ self.kernel_size = kernel_size
96
+ self.n_layers = n_layers
97
+ self.p_dropout = p_dropout
98
+
99
+ self.drop = nn.Dropout(p_dropout)
100
+ self.convs_sep = nn.ModuleList()
101
+ self.convs_1x1 = nn.ModuleList()
102
+ self.norms_1 = nn.ModuleList()
103
+ self.norms_2 = nn.ModuleList()
104
+ for i in range(n_layers):
105
+ dilation = kernel_size**i
106
+ padding = (kernel_size * dilation - dilation) // 2
107
+ self.convs_sep.append(
108
+ nn.Conv1d(
109
+ channels,
110
+ channels,
111
+ kernel_size,
112
+ groups=channels,
113
+ dilation=dilation,
114
+ padding=padding,
115
+ )
116
+ )
117
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
118
+ self.norms_1.append(LayerNorm(channels))
119
+ self.norms_2.append(LayerNorm(channels))
120
+
121
+ def forward(self, x, x_mask, g=None):
122
+ if g is not None:
123
+ x = x + g
124
+ for i in range(self.n_layers):
125
+ y = self.convs_sep[i](x * x_mask)
126
+ y = self.norms_1[i](y)
127
+ y = F.gelu(y)
128
+ y = self.convs_1x1[i](y)
129
+ y = self.norms_2[i](y)
130
+ y = F.gelu(y)
131
+ y = self.drop(y)
132
+ x = x + y
133
+ return x * x_mask
134
+
135
+
136
+ class WN(torch.nn.Module):
137
+ def __init__(
138
+ self,
139
+ hidden_channels,
140
+ kernel_size,
141
+ dilation_rate,
142
+ n_layers,
143
+ gin_channels=0,
144
+ p_dropout=0,
145
+ causal=False,
146
+ ):
147
+ super(WN, self).__init__()
148
+ conv1d_type = SConv1d
149
+ assert kernel_size % 2 == 1
150
+ self.hidden_channels = hidden_channels
151
+ self.kernel_size = (kernel_size,)
152
+ self.dilation_rate = dilation_rate
153
+ self.n_layers = n_layers
154
+ self.gin_channels = gin_channels
155
+ self.p_dropout = p_dropout
156
+
157
+ self.in_layers = torch.nn.ModuleList()
158
+ self.res_skip_layers = torch.nn.ModuleList()
159
+ self.drop = nn.Dropout(p_dropout)
160
+
161
+ if gin_channels != 0:
162
+ self.cond_layer = conv1d_type(
163
+ gin_channels, 2 * hidden_channels * n_layers, 1, norm="weight_norm"
164
+ )
165
+
166
+ for i in range(n_layers):
167
+ dilation = dilation_rate**i
168
+ padding = int((kernel_size * dilation - dilation) / 2)
169
+ in_layer = conv1d_type(
170
+ hidden_channels,
171
+ 2 * hidden_channels,
172
+ kernel_size,
173
+ dilation=dilation,
174
+ padding=padding,
175
+ norm="weight_norm",
176
+ causal=causal,
177
+ )
178
+ self.in_layers.append(in_layer)
179
+
180
+ # last one is not necessary
181
+ if i < n_layers - 1:
182
+ res_skip_channels = 2 * hidden_channels
183
+ else:
184
+ res_skip_channels = hidden_channels
185
+
186
+ res_skip_layer = conv1d_type(
187
+ hidden_channels, res_skip_channels, 1, norm="weight_norm", causal=causal
188
+ )
189
+ self.res_skip_layers.append(res_skip_layer)
190
+
191
+ def forward(self, x, x_mask, g=None, **kwargs):
192
+ output = torch.zeros_like(x)
193
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
194
+
195
+ if g is not None:
196
+ g = self.cond_layer(g)
197
+
198
+ for i in range(self.n_layers):
199
+ x_in = self.in_layers[i](x)
200
+ if g is not None:
201
+ cond_offset = i * 2 * self.hidden_channels
202
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
203
+ else:
204
+ g_l = torch.zeros_like(x_in)
205
+
206
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
207
+ acts = self.drop(acts)
208
+
209
+ res_skip_acts = self.res_skip_layers[i](acts)
210
+ if i < self.n_layers - 1:
211
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
212
+ x = (x + res_acts) * x_mask
213
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
214
+ else:
215
+ output = output + res_skip_acts
216
+ return output * x_mask
217
+
218
+ def remove_weight_norm(self):
219
+ if self.gin_channels != 0:
220
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
221
+ for l in self.in_layers:
222
+ torch.nn.utils.remove_weight_norm(l)
223
+ for l in self.res_skip_layers:
224
+ torch.nn.utils.remove_weight_norm(l)
models/codec/facodec/optimizer.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os, sys
7
+ import os.path as osp
8
+ import numpy as np
9
+ import torch
10
+ from torch import nn
11
+ from torch.optim import Optimizer
12
+ from functools import reduce
13
+ from torch.optim import AdamW
14
+
15
+
16
+ class MultiOptimizer:
17
+ def __init__(self, optimizers={}, schedulers={}):
18
+ self.optimizers = optimizers
19
+ self.schedulers = schedulers
20
+ self.keys = list(optimizers.keys())
21
+ self.param_groups = reduce(
22
+ lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()]
23
+ )
24
+
25
+ def state_dict(self):
26
+ state_dicts = [(key, self.optimizers[key].state_dict()) for key in self.keys]
27
+ return state_dicts
28
+
29
+ def scheduler_state_dict(self):
30
+ state_dicts = [(key, self.schedulers[key].state_dict()) for key in self.keys]
31
+ return state_dicts
32
+
33
+ def load_state_dict(self, state_dict):
34
+ for key, val in state_dict:
35
+ try:
36
+ self.optimizers[key].load_state_dict(val)
37
+ except:
38
+ print("Unloaded %s" % key)
39
+
40
+ def load_scheduler_state_dict(self, state_dict):
41
+ for key, val in state_dict:
42
+ try:
43
+ self.schedulers[key].load_state_dict(val)
44
+ except:
45
+ print("Unloaded %s" % key)
46
+
47
+ def step(self, key=None, scaler=None):
48
+ keys = [key] if key is not None else self.keys
49
+ _ = [self._step(key, scaler) for key in keys]
50
+
51
+ def _step(self, key, scaler=None):
52
+ if scaler is not None:
53
+ scaler.step(self.optimizers[key])
54
+ scaler.update()
55
+ else:
56
+ self.optimizers[key].step()
57
+
58
+ def zero_grad(self, key=None):
59
+ if key is not None:
60
+ self.optimizers[key].zero_grad()
61
+ else:
62
+ _ = [self.optimizers[key].zero_grad() for key in self.keys]
63
+
64
+ def scheduler(self, *args, key=None):
65
+ if key is not None:
66
+ self.schedulers[key].step(*args)
67
+ else:
68
+ _ = [self.schedulers[key].step_batch(*args) for key in self.keys]
69
+
70
+
71
+ def define_scheduler(optimizer, params):
72
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params["gamma"])
73
+
74
+ return scheduler
75
+
76
+
77
+ def build_optimizer(model_dict, scheduler_params_dict, lr, type="AdamW"):
78
+ optim = {}
79
+ for key, model in model_dict.items():
80
+ model_parameters = model.parameters()
81
+ parameters_names = []
82
+ parameters_names.append(
83
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
84
+ )
85
+ if type == "AdamW":
86
+ optim[key] = AdamW(
87
+ model_parameters,
88
+ lr=lr,
89
+ betas=(0.9, 0.98),
90
+ eps=1e-9,
91
+ weight_decay=0.1,
92
+ )
93
+ else:
94
+ raise ValueError("Unknown optimizer type: %s" % type)
95
+
96
+ schedulers = dict(
97
+ [
98
+ (key, torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999996))
99
+ for key, opt in optim.items()
100
+ ]
101
+ )
102
+
103
+ multi_optim = MultiOptimizer(optim, schedulers)
104
+ return multi_optim
models/codec/kmeans/repcodec_model.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from concurrent.futures import ALL_COMPLETED
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from torch.nn import functional as F
12
+ from einops import rearrange, repeat
13
+
14
+ from models.codec.amphion_codec.quantize import ResidualVQ
15
+ from models.codec.kmeans.vocos import VocosBackbone
16
+
17
+
18
+ def init_weights(m):
19
+ if isinstance(m, nn.Conv1d):
20
+ nn.init.trunc_normal_(m.weight, std=0.02)
21
+ nn.init.constant_(m.bias, 0)
22
+ if isinstance(m, nn.Linear):
23
+ nn.init.trunc_normal_(m.weight, std=0.02)
24
+ nn.init.constant_(m.bias, 0)
25
+
26
+
27
+ def compute_codebook_perplexity(indices, codebook_size):
28
+ indices = indices.flatten()
29
+ prob = torch.bincount(indices, minlength=codebook_size).float() / indices.size(0)
30
+ perp = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10)))
31
+ return perp
32
+
33
+
34
+ class RepCodec(nn.Module):
35
+ def __init__(
36
+ self,
37
+ codebook_size=8192,
38
+ hidden_size=1024,
39
+ codebook_dim=8,
40
+ vocos_dim=384,
41
+ vocos_intermediate_dim=2048,
42
+ vocos_num_layers=12,
43
+ num_quantizers=1,
44
+ downsample_scale=1,
45
+ cfg=None,
46
+ ):
47
+ super().__init__()
48
+ codebook_size = (
49
+ cfg.codebook_size
50
+ if cfg is not None and hasattr(cfg, "codebook_size")
51
+ else codebook_size
52
+ )
53
+ codebook_dim = (
54
+ cfg.codebook_dim
55
+ if cfg is not None and hasattr(cfg, "codebook_dim")
56
+ else codebook_dim
57
+ )
58
+ hidden_size = (
59
+ cfg.hidden_size
60
+ if cfg is not None and hasattr(cfg, "hidden_size")
61
+ else hidden_size
62
+ )
63
+ vocos_dim = (
64
+ cfg.vocos_dim
65
+ if cfg is not None and hasattr(cfg, "vocos_dim")
66
+ else vocos_dim
67
+ )
68
+ vocos_intermediate_dim = (
69
+ cfg.vocos_intermediate_dim
70
+ if cfg is not None and hasattr(cfg, "vocos_dim")
71
+ else vocos_intermediate_dim
72
+ )
73
+ vocos_num_layers = (
74
+ cfg.vocos_num_layers
75
+ if cfg is not None and hasattr(cfg, "vocos_dim")
76
+ else vocos_num_layers
77
+ )
78
+ num_quantizers = (
79
+ cfg.num_quantizers
80
+ if cfg is not None and hasattr(cfg, "num_quantizers")
81
+ else num_quantizers
82
+ )
83
+ downsample_scale = (
84
+ cfg.downsample_scale
85
+ if cfg is not None and hasattr(cfg, "downsample_scale")
86
+ else downsample_scale
87
+ )
88
+
89
+ self.codebook_size = codebook_size
90
+ self.codebook_dim = codebook_dim
91
+ self.hidden_size = hidden_size
92
+ self.vocos_dim = vocos_dim
93
+ self.vocos_intermediate_dim = vocos_intermediate_dim
94
+ self.vocos_num_layers = vocos_num_layers
95
+ self.num_quantizers = num_quantizers
96
+ self.downsample_scale = downsample_scale
97
+
98
+ if self.downsample_scale != None and self.downsample_scale > 1:
99
+ self.down = nn.Conv1d(
100
+ self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
101
+ )
102
+ self.up = nn.Conv1d(
103
+ self.hidden_size, self.hidden_size, kernel_size=3, stride=1, padding=1
104
+ )
105
+
106
+ self.encoder = nn.Sequential(
107
+ VocosBackbone(
108
+ input_channels=self.hidden_size,
109
+ dim=self.vocos_dim,
110
+ intermediate_dim=self.vocos_intermediate_dim,
111
+ num_layers=self.vocos_num_layers,
112
+ adanorm_num_embeddings=None,
113
+ ),
114
+ nn.Linear(self.vocos_dim, self.hidden_size),
115
+ )
116
+ self.decoder = nn.Sequential(
117
+ VocosBackbone(
118
+ input_channels=self.hidden_size,
119
+ dim=self.vocos_dim,
120
+ intermediate_dim=self.vocos_intermediate_dim,
121
+ num_layers=self.vocos_num_layers,
122
+ adanorm_num_embeddings=None,
123
+ ),
124
+ nn.Linear(self.vocos_dim, self.hidden_size),
125
+ )
126
+
127
+ self.quantizer = ResidualVQ(
128
+ input_dim=hidden_size,
129
+ num_quantizers=num_quantizers,
130
+ codebook_size=codebook_size,
131
+ codebook_dim=codebook_dim,
132
+ quantizer_type="fvq",
133
+ quantizer_dropout=0.0,
134
+ commitment=0.15,
135
+ codebook_loss_weight=1.0,
136
+ use_l2_normlize=True,
137
+ )
138
+
139
+ self.reset_parameters()
140
+
141
+ def forward(self, x):
142
+
143
+ # downsample
144
+ if self.downsample_scale != None and self.downsample_scale > 1:
145
+ x = x.transpose(1, 2)
146
+ x = self.down(x)
147
+ x = F.gelu(x)
148
+ x = x.transpose(1, 2)
149
+
150
+ # encoder
151
+ x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
152
+
153
+ # vq
154
+ (
155
+ quantized_out,
156
+ all_indices,
157
+ all_commit_losses,
158
+ all_codebook_losses,
159
+ _,
160
+ ) = self.quantizer(x)
161
+
162
+ # decoder
163
+ x = self.decoder(quantized_out)
164
+
165
+ # up
166
+ if self.downsample_scale != None and self.downsample_scale > 1:
167
+ x = x.transpose(1, 2)
168
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
169
+ x_rec = self.up(x).transpose(1, 2)
170
+
171
+ codebook_loss = (all_codebook_losses + all_commit_losses).mean()
172
+ all_indices = all_indices
173
+
174
+ return x_rec, codebook_loss, all_indices
175
+
176
+ def quantize(self, x):
177
+
178
+ if self.downsample_scale != None and self.downsample_scale > 1:
179
+ x = x.transpose(1, 2)
180
+ x = self.down(x)
181
+ x = F.gelu(x)
182
+ x = x.transpose(1, 2)
183
+
184
+ x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
185
+
186
+ (
187
+ quantized_out,
188
+ all_indices,
189
+ all_commit_losses,
190
+ all_codebook_losses,
191
+ _,
192
+ ) = self.quantizer(x)
193
+
194
+ if all_indices.shape[0] == 1:
195
+ return all_indices.squeeze(0), quantized_out.transpose(1, 2)
196
+ return all_indices, quantized_out.transpose(1, 2)
197
+
198
+ def reset_parameters(self):
199
+ self.apply(init_weights)
200
+
201
+
202
+ if __name__ == "__main__":
203
+ repcodec = RepCodec(vocos_dim=1024, downsample_scale=2)
204
+ print(repcodec)
205
+ print(sum(p.numel() for p in repcodec.parameters()) / 1e6)
206
+ x = torch.randn(5, 10, 1024)
207
+ x_rec, codebook_loss, all_indices = repcodec(x)
208
+ print(x_rec.shape, codebook_loss, all_indices.shape)
209
+ vq_id, emb = repcodec.quantize(x)
210
+ print(vq_id.shape, emb.shape)
models/codec/kmeans/vocos.py ADDED
@@ -0,0 +1,850 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Optional, Tuple
7
+
8
+ import numpy as np
9
+ import scipy
10
+ import torch
11
+ from torch import nn, view_as_real, view_as_complex
12
+ from torch import nn
13
+ from torch.nn.utils import weight_norm, remove_weight_norm
14
+ from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
15
+
16
+
17
+ def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
18
+ """
19
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
20
+
21
+ Args:
22
+ x (Tensor): Input tensor.
23
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
24
+
25
+ Returns:
26
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
27
+ """
28
+ return torch.log(torch.clip(x, min=clip_val))
29
+
30
+
31
+ def symlog(x: torch.Tensor) -> torch.Tensor:
32
+ return torch.sign(x) * torch.log1p(x.abs())
33
+
34
+
35
+ def symexp(x: torch.Tensor) -> torch.Tensor:
36
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
37
+
38
+
39
+ class STFT(nn.Module):
40
+ def __init__(
41
+ self,
42
+ n_fft: int,
43
+ hop_length: int,
44
+ win_length: int,
45
+ center=True,
46
+ ):
47
+ super().__init__()
48
+ self.center = center
49
+ self.n_fft = n_fft
50
+ self.hop_length = hop_length
51
+ self.win_length = win_length
52
+ window = torch.hann_window(win_length)
53
+ self.register_buffer("window", window)
54
+
55
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
56
+ # x: (B, T * hop_length)
57
+
58
+ if not self.center:
59
+ pad = self.win_length - self.hop_length
60
+ x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
61
+
62
+ stft_spec = torch.stft(
63
+ x,
64
+ self.n_fft,
65
+ hop_length=self.hop_length,
66
+ win_length=self.win_length,
67
+ window=self.window,
68
+ center=self.center,
69
+ return_complex=False,
70
+ ) # (B, n_fft // 2 + 1, T, 2)
71
+
72
+ rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2)
73
+ imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2)
74
+
75
+ log_mag = torch.log(
76
+ torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
77
+ ) # (B, n_fft // 2 + 1, T)
78
+ phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
79
+
80
+ return log_mag, phase
81
+
82
+
83
+ class ISTFT(nn.Module):
84
+ """
85
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
86
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
87
+ See issue: https://github.com/pytorch/pytorch/issues/62323
88
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
89
+ The NOLA constraint is met as we trim padded samples anyway.
90
+
91
+ Args:
92
+ n_fft (int): Size of Fourier transform.
93
+ hop_length (int): The distance between neighboring sliding window frames.
94
+ win_length (int): The size of window frame and STFT filter.
95
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
96
+ """
97
+
98
+ def __init__(
99
+ self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
100
+ ):
101
+ super().__init__()
102
+ if padding not in ["center", "same"]:
103
+ raise ValueError("Padding must be 'center' or 'same'.")
104
+ self.padding = padding
105
+ self.n_fft = n_fft
106
+ self.hop_length = hop_length
107
+ self.win_length = win_length
108
+ window = torch.hann_window(win_length)
109
+ self.register_buffer("window", window)
110
+
111
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
112
+ """
113
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
114
+
115
+ Args:
116
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
117
+ N is the number of frequency bins, and T is the number of time frames.
118
+
119
+ Returns:
120
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
121
+ """
122
+ if self.padding == "center":
123
+ # Fallback to pytorch native implementation
124
+ return torch.istft(
125
+ spec,
126
+ self.n_fft,
127
+ self.hop_length,
128
+ self.win_length,
129
+ self.window,
130
+ center=True,
131
+ )
132
+ elif self.padding == "same":
133
+ pad = (self.win_length - self.hop_length) // 2
134
+ else:
135
+ raise ValueError("Padding must be 'center' or 'same'.")
136
+
137
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
138
+ B, N, T = spec.shape
139
+
140
+ # Inverse FFT
141
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
142
+ ifft = ifft * self.window[None, :, None]
143
+
144
+ # Overlap and Add
145
+ output_size = (T - 1) * self.hop_length + self.win_length
146
+ y = torch.nn.functional.fold(
147
+ ifft,
148
+ output_size=(1, output_size),
149
+ kernel_size=(1, self.win_length),
150
+ stride=(1, self.hop_length),
151
+ )[:, 0, 0, pad:-pad]
152
+
153
+ # Window envelope
154
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
155
+ window_envelope = torch.nn.functional.fold(
156
+ window_sq,
157
+ output_size=(1, output_size),
158
+ kernel_size=(1, self.win_length),
159
+ stride=(1, self.hop_length),
160
+ ).squeeze()[pad:-pad]
161
+
162
+ # Normalize
163
+ assert (window_envelope > 1e-11).all()
164
+ y = y / window_envelope
165
+
166
+ return y
167
+
168
+
169
+ class MDCT(nn.Module):
170
+ """
171
+ Modified Discrete Cosine Transform (MDCT) module.
172
+
173
+ Args:
174
+ frame_len (int): Length of the MDCT frame.
175
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
176
+ """
177
+
178
+ def __init__(self, frame_len: int, padding: str = "same"):
179
+ super().__init__()
180
+ if padding not in ["center", "same"]:
181
+ raise ValueError("Padding must be 'center' or 'same'.")
182
+ self.padding = padding
183
+ self.frame_len = frame_len
184
+ N = frame_len // 2
185
+ n0 = (N + 1) / 2
186
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
187
+ self.register_buffer("window", window)
188
+
189
+ pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
190
+ post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
191
+ # view_as_real: NCCL Backend does not support ComplexFloat data type
192
+ # https://github.com/pytorch/pytorch/issues/71613
193
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
194
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
195
+
196
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
197
+ """
198
+ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
199
+
200
+ Args:
201
+ audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
202
+ and T is the length of the audio.
203
+
204
+ Returns:
205
+ Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
206
+ and N is the number of frequency bins.
207
+ """
208
+ if self.padding == "center":
209
+ audio = torch.nn.functional.pad(
210
+ audio, (self.frame_len // 2, self.frame_len // 2)
211
+ )
212
+ elif self.padding == "same":
213
+ # hop_length is 1/2 frame_len
214
+ audio = torch.nn.functional.pad(
215
+ audio, (self.frame_len // 4, self.frame_len // 4)
216
+ )
217
+ else:
218
+ raise ValueError("Padding must be 'center' or 'same'.")
219
+
220
+ x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
221
+ N = self.frame_len // 2
222
+ x = x * self.window.expand(x.shape)
223
+ X = torch.fft.fft(
224
+ x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
225
+ )[..., :N]
226
+ res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
227
+ return torch.real(res) * np.sqrt(2)
228
+
229
+
230
+ class IMDCT(nn.Module):
231
+ """
232
+ Inverse Modified Discrete Cosine Transform (IMDCT) module.
233
+
234
+ Args:
235
+ frame_len (int): Length of the MDCT frame.
236
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
237
+ """
238
+
239
+ def __init__(self, frame_len: int, padding: str = "same"):
240
+ super().__init__()
241
+ if padding not in ["center", "same"]:
242
+ raise ValueError("Padding must be 'center' or 'same'.")
243
+ self.padding = padding
244
+ self.frame_len = frame_len
245
+ N = frame_len // 2
246
+ n0 = (N + 1) / 2
247
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
248
+ self.register_buffer("window", window)
249
+
250
+ pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
251
+ post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
252
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
253
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
254
+
255
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
256
+ """
257
+ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
258
+
259
+ Args:
260
+ X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
261
+ L is the number of frames, and N is the number of frequency bins.
262
+
263
+ Returns:
264
+ Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
265
+ """
266
+ B, L, N = X.shape
267
+ Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
268
+ Y[..., :N] = X
269
+ Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
270
+ y = torch.fft.ifft(
271
+ Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
272
+ )
273
+ y = (
274
+ torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
275
+ * np.sqrt(N)
276
+ * np.sqrt(2)
277
+ )
278
+ result = y * self.window.expand(y.shape)
279
+ output_size = (1, (L + 1) * N)
280
+ audio = torch.nn.functional.fold(
281
+ result.transpose(1, 2),
282
+ output_size=output_size,
283
+ kernel_size=(1, self.frame_len),
284
+ stride=(1, self.frame_len // 2),
285
+ )[:, 0, 0, :]
286
+
287
+ if self.padding == "center":
288
+ pad = self.frame_len // 2
289
+ elif self.padding == "same":
290
+ pad = self.frame_len // 4
291
+ else:
292
+ raise ValueError("Padding must be 'center' or 'same'.")
293
+
294
+ audio = audio[:, pad:-pad]
295
+ return audio
296
+
297
+
298
+ class FourierHead(nn.Module):
299
+ """Base class for inverse fourier modules."""
300
+
301
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
302
+ """
303
+ Args:
304
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
305
+ L is the sequence length, and H denotes the model dimension.
306
+
307
+ Returns:
308
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
309
+ """
310
+ raise NotImplementedError("Subclasses must implement the forward method.")
311
+
312
+
313
+ class ISTFTHead(FourierHead):
314
+ """
315
+ ISTFT Head module for predicting STFT complex coefficients.
316
+
317
+ Args:
318
+ dim (int): Hidden dimension of the model.
319
+ n_fft (int): Size of Fourier transform.
320
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
321
+ the resolution of the input features.
322
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
323
+ """
324
+
325
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
326
+ super().__init__()
327
+ out_dim = n_fft + 2
328
+ self.out = torch.nn.Linear(dim, out_dim)
329
+ self.istft = ISTFT(
330
+ n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
331
+ )
332
+
333
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
334
+ """
335
+ Forward pass of the ISTFTHead module.
336
+
337
+ Args:
338
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
339
+ L is the sequence length, and H denotes the model dimension.
340
+
341
+ Returns:
342
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
343
+ """
344
+ x = self.out(x).transpose(1, 2)
345
+ mag, p = x.chunk(2, dim=1)
346
+ mag = torch.exp(mag)
347
+ mag = torch.clip(
348
+ mag, max=1e2
349
+ ) # safeguard to prevent excessively large magnitudes
350
+ # wrapping happens here. These two lines produce real and imaginary value
351
+ x = torch.cos(p)
352
+ y = torch.sin(p)
353
+ # recalculating phase here does not produce anything new
354
+ # only costs time
355
+ # phase = torch.atan2(y, x)
356
+ # S = mag * torch.exp(phase * 1j)
357
+ # better directly produce the complex value
358
+ S = mag * (x + 1j * y)
359
+ audio = self.istft(S)
360
+ return audio
361
+
362
+
363
+ class IMDCTSymExpHead(FourierHead):
364
+ """
365
+ IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
366
+
367
+ Args:
368
+ dim (int): Hidden dimension of the model.
369
+ mdct_frame_len (int): Length of the MDCT frame.
370
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
371
+ sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
372
+ based on perceptual scaling. Defaults to None.
373
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
374
+ """
375
+
376
+ def __init__(
377
+ self,
378
+ dim: int,
379
+ mdct_frame_len: int,
380
+ padding: str = "same",
381
+ sample_rate: Optional[int] = None,
382
+ clip_audio: bool = False,
383
+ ):
384
+ super().__init__()
385
+ out_dim = mdct_frame_len // 2
386
+ self.out = nn.Linear(dim, out_dim)
387
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
388
+ self.clip_audio = clip_audio
389
+
390
+ if sample_rate is not None:
391
+ # optionally init the last layer following mel-scale
392
+ m_max = _hz_to_mel(sample_rate // 2)
393
+ m_pts = torch.linspace(0, m_max, out_dim)
394
+ f_pts = _mel_to_hz(m_pts)
395
+ scale = 1 - (f_pts / f_pts.max())
396
+
397
+ with torch.no_grad():
398
+ self.out.weight.mul_(scale.view(-1, 1))
399
+
400
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
401
+ """
402
+ Forward pass of the IMDCTSymExpHead module.
403
+
404
+ Args:
405
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
406
+ L is the sequence length, and H denotes the model dimension.
407
+
408
+ Returns:
409
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
410
+ """
411
+ x = self.out(x)
412
+ x = symexp(x)
413
+ x = torch.clip(
414
+ x, min=-1e2, max=1e2
415
+ ) # safeguard to prevent excessively large magnitudes
416
+ audio = self.imdct(x)
417
+ if self.clip_audio:
418
+ audio = torch.clip(x, min=-1.0, max=1.0)
419
+
420
+ return audio
421
+
422
+
423
+ class IMDCTCosHead(FourierHead):
424
+ """
425
+ IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
426
+
427
+ Args:
428
+ dim (int): Hidden dimension of the model.
429
+ mdct_frame_len (int): Length of the MDCT frame.
430
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
431
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
432
+ """
433
+
434
+ def __init__(
435
+ self,
436
+ dim: int,
437
+ mdct_frame_len: int,
438
+ padding: str = "same",
439
+ clip_audio: bool = False,
440
+ ):
441
+ super().__init__()
442
+ self.clip_audio = clip_audio
443
+ self.out = nn.Linear(dim, mdct_frame_len)
444
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
445
+
446
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
447
+ """
448
+ Forward pass of the IMDCTCosHead module.
449
+
450
+ Args:
451
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
452
+ L is the sequence length, and H denotes the model dimension.
453
+
454
+ Returns:
455
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
456
+ """
457
+ x = self.out(x)
458
+ m, p = x.chunk(2, dim=2)
459
+ m = torch.exp(m).clip(
460
+ max=1e2
461
+ ) # safeguard to prevent excessively large magnitudes
462
+ audio = self.imdct(m * torch.cos(p))
463
+ if self.clip_audio:
464
+ audio = torch.clip(x, min=-1.0, max=1.0)
465
+ return audio
466
+
467
+
468
+ class ConvNeXtBlock(nn.Module):
469
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
470
+
471
+ Args:
472
+ dim (int): Number of input channels.
473
+ intermediate_dim (int): Dimensionality of the intermediate layer.
474
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
475
+ Defaults to None.
476
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
477
+ None means non-conditional LayerNorm. Defaults to None.
478
+ """
479
+
480
+ def __init__(
481
+ self,
482
+ dim: int,
483
+ intermediate_dim: int,
484
+ layer_scale_init_value: float,
485
+ adanorm_num_embeddings: Optional[int] = None,
486
+ ):
487
+ super().__init__()
488
+ self.dwconv = nn.Conv1d(
489
+ dim, dim, kernel_size=7, padding=3, groups=dim
490
+ ) # depthwise conv
491
+ self.adanorm = adanorm_num_embeddings is not None
492
+ if adanorm_num_embeddings:
493
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
494
+ else:
495
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
496
+ self.pwconv1 = nn.Linear(
497
+ dim, intermediate_dim
498
+ ) # pointwise/1x1 convs, implemented with linear layers
499
+ self.act = nn.GELU()
500
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
501
+ self.gamma = (
502
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
503
+ if layer_scale_init_value > 0
504
+ else None
505
+ )
506
+
507
+ def forward(
508
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
509
+ ) -> torch.Tensor:
510
+ residual = x
511
+ x = self.dwconv(x)
512
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
513
+ if self.adanorm:
514
+ assert cond_embedding_id is not None
515
+ x = self.norm(x, cond_embedding_id)
516
+ else:
517
+ x = self.norm(x)
518
+ x = self.pwconv1(x)
519
+ x = self.act(x)
520
+ x = self.pwconv2(x)
521
+ if self.gamma is not None:
522
+ x = self.gamma * x
523
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
524
+
525
+ x = residual + x
526
+ return x
527
+
528
+
529
+ class AdaLayerNorm(nn.Module):
530
+ """
531
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
532
+
533
+ Args:
534
+ num_embeddings (int): Number of embeddings.
535
+ embedding_dim (int): Dimension of the embeddings.
536
+ """
537
+
538
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
539
+ super().__init__()
540
+ self.eps = eps
541
+ self.dim = embedding_dim
542
+ self.scale = nn.Embedding(
543
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
544
+ )
545
+ self.shift = nn.Embedding(
546
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
547
+ )
548
+ torch.nn.init.ones_(self.scale.weight)
549
+ torch.nn.init.zeros_(self.shift.weight)
550
+
551
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
552
+ scale = self.scale(cond_embedding_id)
553
+ shift = self.shift(cond_embedding_id)
554
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
555
+ x = x * scale + shift
556
+ return x
557
+
558
+
559
+ class ResBlock1(nn.Module):
560
+ """
561
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
562
+ but without upsampling layers.
563
+
564
+ Args:
565
+ dim (int): Number of input channels.
566
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
567
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
568
+ Defaults to (1, 3, 5).
569
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
570
+ Defaults to 0.1.
571
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
572
+ Defaults to None.
573
+ """
574
+
575
+ def __init__(
576
+ self,
577
+ dim: int,
578
+ kernel_size: int = 3,
579
+ dilation: Tuple[int, int, int] = (1, 3, 5),
580
+ lrelu_slope: float = 0.1,
581
+ layer_scale_init_value: Optional[float] = None,
582
+ ):
583
+ super().__init__()
584
+ self.lrelu_slope = lrelu_slope
585
+ self.convs1 = nn.ModuleList(
586
+ [
587
+ weight_norm(
588
+ nn.Conv1d(
589
+ dim,
590
+ dim,
591
+ kernel_size,
592
+ 1,
593
+ dilation=dilation[0],
594
+ padding=self.get_padding(kernel_size, dilation[0]),
595
+ )
596
+ ),
597
+ weight_norm(
598
+ nn.Conv1d(
599
+ dim,
600
+ dim,
601
+ kernel_size,
602
+ 1,
603
+ dilation=dilation[1],
604
+ padding=self.get_padding(kernel_size, dilation[1]),
605
+ )
606
+ ),
607
+ weight_norm(
608
+ nn.Conv1d(
609
+ dim,
610
+ dim,
611
+ kernel_size,
612
+ 1,
613
+ dilation=dilation[2],
614
+ padding=self.get_padding(kernel_size, dilation[2]),
615
+ )
616
+ ),
617
+ ]
618
+ )
619
+
620
+ self.convs2 = nn.ModuleList(
621
+ [
622
+ weight_norm(
623
+ nn.Conv1d(
624
+ dim,
625
+ dim,
626
+ kernel_size,
627
+ 1,
628
+ dilation=1,
629
+ padding=self.get_padding(kernel_size, 1),
630
+ )
631
+ ),
632
+ weight_norm(
633
+ nn.Conv1d(
634
+ dim,
635
+ dim,
636
+ kernel_size,
637
+ 1,
638
+ dilation=1,
639
+ padding=self.get_padding(kernel_size, 1),
640
+ )
641
+ ),
642
+ weight_norm(
643
+ nn.Conv1d(
644
+ dim,
645
+ dim,
646
+ kernel_size,
647
+ 1,
648
+ dilation=1,
649
+ padding=self.get_padding(kernel_size, 1),
650
+ )
651
+ ),
652
+ ]
653
+ )
654
+
655
+ self.gamma = nn.ParameterList(
656
+ [
657
+ (
658
+ nn.Parameter(
659
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
660
+ )
661
+ if layer_scale_init_value is not None
662
+ else None
663
+ ),
664
+ (
665
+ nn.Parameter(
666
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
667
+ )
668
+ if layer_scale_init_value is not None
669
+ else None
670
+ ),
671
+ (
672
+ nn.Parameter(
673
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
674
+ )
675
+ if layer_scale_init_value is not None
676
+ else None
677
+ ),
678
+ ]
679
+ )
680
+
681
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
682
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
683
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
684
+ xt = c1(xt)
685
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
686
+ xt = c2(xt)
687
+ if gamma is not None:
688
+ xt = gamma * xt
689
+ x = xt + x
690
+ return x
691
+
692
+ def remove_weight_norm(self):
693
+ for l in self.convs1:
694
+ remove_weight_norm(l)
695
+ for l in self.convs2:
696
+ remove_weight_norm(l)
697
+
698
+ @staticmethod
699
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
700
+ return int((kernel_size * dilation - dilation) / 2)
701
+
702
+
703
+ class Backbone(nn.Module):
704
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
705
+
706
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
707
+ """
708
+ Args:
709
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
710
+ C denotes output features, and L is the sequence length.
711
+
712
+ Returns:
713
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
714
+ and H denotes the model dimension.
715
+ """
716
+ raise NotImplementedError("Subclasses must implement the forward method.")
717
+
718
+
719
+ class VocosBackbone(Backbone):
720
+ """
721
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
722
+
723
+ Args:
724
+ input_channels (int): Number of input features channels.
725
+ dim (int): Hidden dimension of the model.
726
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
727
+ num_layers (int): Number of ConvNeXtBlock layers.
728
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
729
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
730
+ None means non-conditional model. Defaults to None.
731
+ """
732
+
733
+ def __init__(
734
+ self,
735
+ input_channels: int,
736
+ dim: int,
737
+ intermediate_dim: int,
738
+ num_layers: int,
739
+ layer_scale_init_value: Optional[float] = None,
740
+ adanorm_num_embeddings: Optional[int] = None,
741
+ ):
742
+ super().__init__()
743
+ self.input_channels = input_channels
744
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
745
+ self.adanorm = adanorm_num_embeddings is not None
746
+ if adanorm_num_embeddings:
747
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
748
+ else:
749
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
750
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
751
+ self.convnext = nn.ModuleList(
752
+ [
753
+ ConvNeXtBlock(
754
+ dim=dim,
755
+ intermediate_dim=intermediate_dim,
756
+ layer_scale_init_value=layer_scale_init_value,
757
+ adanorm_num_embeddings=adanorm_num_embeddings,
758
+ )
759
+ for _ in range(num_layers)
760
+ ]
761
+ )
762
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
763
+ self.apply(self._init_weights)
764
+
765
+ def _init_weights(self, m):
766
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
767
+ nn.init.trunc_normal_(m.weight, std=0.02)
768
+ nn.init.constant_(m.bias, 0)
769
+
770
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
771
+ bandwidth_id = kwargs.get("bandwidth_id", None)
772
+ x = self.embed(x)
773
+ if self.adanorm:
774
+ assert bandwidth_id is not None
775
+ x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
776
+ else:
777
+ x = self.norm(x.transpose(1, 2))
778
+ x = x.transpose(1, 2)
779
+ for conv_block in self.convnext:
780
+ x = conv_block(x, cond_embedding_id=bandwidth_id)
781
+ x = self.final_layer_norm(x.transpose(1, 2))
782
+ return x
783
+
784
+
785
+ class VocosResNetBackbone(Backbone):
786
+ """
787
+ Vocos backbone module built with ResBlocks.
788
+
789
+ Args:
790
+ input_channels (int): Number of input features channels.
791
+ dim (int): Hidden dimension of the model.
792
+ num_blocks (int): Number of ResBlock1 blocks.
793
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
794
+ """
795
+
796
+ def __init__(
797
+ self,
798
+ input_channels,
799
+ dim,
800
+ num_blocks,
801
+ layer_scale_init_value=None,
802
+ ):
803
+ super().__init__()
804
+ self.input_channels = input_channels
805
+ self.embed = weight_norm(
806
+ nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
807
+ )
808
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
809
+ self.resnet = nn.Sequential(
810
+ *[
811
+ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
812
+ for _ in range(num_blocks)
813
+ ]
814
+ )
815
+
816
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
817
+ x = self.embed(x)
818
+ x = self.resnet(x)
819
+ x = x.transpose(1, 2)
820
+ return x
821
+
822
+
823
+ class Vocos(nn.Module):
824
+ def __init__(
825
+ self,
826
+ input_channels: int = 256,
827
+ dim: int = 384,
828
+ intermediate_dim: int = 1152,
829
+ num_layers: int = 8,
830
+ adanorm_num_embeddings: int = 4,
831
+ n_fft: int = 800,
832
+ hop_size: int = 200,
833
+ padding: str = "same",
834
+ ):
835
+ super().__init__()
836
+
837
+ self.backbone = VocosBackbone(
838
+ input_channels=input_channels,
839
+ dim=dim,
840
+ intermediate_dim=intermediate_dim,
841
+ num_layers=num_layers,
842
+ adanorm_num_embeddings=adanorm_num_embeddings,
843
+ )
844
+ self.head = ISTFTHead(dim, n_fft, hop_size, padding)
845
+
846
+ def forward(self, x):
847
+ x = self.backbone(x)
848
+ x = self.head(x)
849
+
850
+ return x[:, None, :]
models/codec/ns3_codec/README.md ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## FACodec: Speech Codec with Attribute Factorization used for NaturalSpeech 3
2
+
3
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/pdf/2403.03100.pdf)
4
+ [![demo](https://img.shields.io/badge/FACodec-Demo-red)](https://speechresearch.github.io/naturalspeech3/)
5
+ [![model](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-pink)](https://huggingface.co/amphion/naturalspeech3_facodec)
6
+ [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Spaces-yellow)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec)
7
+
8
+ ## Overview
9
+
10
+ FACodec is a core component of the advanced text-to-speech (TTS) model NaturalSpeech 3. FACodec converts complex speech waveform into disentangled subspaces representing speech attributes of content, prosody, timbre, and acoustic details and reconstruct high-quality speech waveform from these attributes. FACodec decomposes complex speech into subspaces representing different attributes, thus simplifying the modeling of speech representation.
11
+
12
+ Research can use FACodec to develop different modes of TTS models, such as non-autoregressive based discrete diffusion (NaturalSpeech 3) or autoregressive models (like VALL-E).
13
+
14
+ <br>
15
+ <div align="center">
16
+ <img src="../../../imgs/ns3/ns3_overview.png" width="65%">
17
+ </div>
18
+ <br>
19
+
20
+ <br>
21
+ <div align="center">
22
+ <img src="../../../imgs/ns3/ns3_facodec.png" width="100%">
23
+ </div>
24
+ <br>
25
+
26
+ ## Useage
27
+
28
+ Download the pre-trained FACodec model from HuggingFace: [Pretrained FACodec checkpoint](https://huggingface.co/amphion/naturalspeech3_facodec)
29
+
30
+ Install Amphion
31
+ ```bash
32
+ git clone https://github.com/open-mmlab/Amphion.git
33
+ ```
34
+
35
+ Few lines of code to use the pre-trained FACodec model
36
+ ```python
37
+ from Amphion.models.codec.ns3_codec import FACodecEncoder, FACodecDecoder
38
+ from huggingface_hub import hf_hub_download
39
+
40
+ fa_encoder = FACodecEncoder(
41
+ ngf=32,
42
+ up_ratios=[2, 4, 5, 5],
43
+ out_channels=256,
44
+ )
45
+
46
+ fa_decoder = FACodecDecoder(
47
+ in_channels=256,
48
+ upsample_initial_channel=1024,
49
+ ngf=32,
50
+ up_ratios=[5, 5, 4, 2],
51
+ vq_num_q_c=2,
52
+ vq_num_q_p=1,
53
+ vq_num_q_r=3,
54
+ vq_dim=256,
55
+ codebook_dim=8,
56
+ codebook_size_prosody=10,
57
+ codebook_size_content=10,
58
+ codebook_size_residual=10,
59
+ use_gr_x_timbre=True,
60
+ use_gr_residual_f0=True,
61
+ use_gr_residual_phone=True,
62
+ )
63
+
64
+ encoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder.bin")
65
+ decoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder.bin")
66
+
67
+ fa_encoder.load_state_dict(torch.load(encoder_ckpt))
68
+ fa_decoder.load_state_dict(torch.load(decoder_ckpt))
69
+
70
+ fa_encoder.eval()
71
+ fa_decoder.eval()
72
+
73
+ ```
74
+
75
+ Inference
76
+ ```python
77
+ test_wav_path = "test.wav"
78
+ test_wav = librosa.load(test_wav_path, sr=16000)[0]
79
+ test_wav = torch.from_numpy(test_wav).float()
80
+ test_wav = test_wav.unsqueeze(0).unsqueeze(0)
81
+
82
+ with torch.no_grad():
83
+
84
+ # encode
85
+ enc_out = fa_encoder(test_wav)
86
+ print(enc_out.shape)
87
+
88
+ # quantize
89
+ vq_post_emb, vq_id, _, quantized, spk_embs = fa_decoder(enc_out, eval_vq=False, vq=True)
90
+
91
+ # latent after quantization
92
+ print(vq_post_emb.shape)
93
+
94
+ # codes
95
+ print("vq id shape:", vq_id.shape)
96
+
97
+ # get prosody code
98
+ prosody_code = vq_id[:1]
99
+ print("prosody code shape:", prosody_code.shape)
100
+
101
+ # get content code
102
+ cotent_code = vq_id[1:3]
103
+ print("content code shape:", cotent_code.shape)
104
+
105
+ # get residual code (acoustic detail codes)
106
+ residual_code = vq_id[3:]
107
+ print("residual code shape:", residual_code.shape)
108
+
109
+ # speaker embedding
110
+ print("speaker embedding shape:", spk_embs.shape)
111
+
112
+ # decode (recommand)
113
+ recon_wav = fa_decoder.inference(vq_post_emb, spk_embs)
114
+ print(recon_wav.shape)
115
+ sf.write("recon.wav", recon_wav[0][0].cpu().numpy(), 16000)
116
+ ```
117
+
118
+ FACodec can achieve zero-shot voice conversion with FACodecEncoderV2/FACodecDecoderV2 or FACodecRedecoder
119
+ ```python
120
+ from Amphion.models.codec.ns3_codec import FACodecEncoderV2, FACodecDecoderV2
121
+
122
+ # Same parameters as FACodecEncoder/FACodecDecoder
123
+ fa_encoder_v2 = FACodecEncoderV2(...)
124
+ fa_decoder_v2 = FACodecDecoderV2(...)
125
+
126
+ encoder_v2_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder_v2.bin")
127
+ decoder_v2_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder_v2.bin")
128
+
129
+ fa_encoder_v2.load_state_dict(torch.load(encoder_v2_ckpt))
130
+ fa_decoder_v2.load_state_dict(torch.load(decoder_v2_ckpt))
131
+
132
+ with torch.no_grad():
133
+ enc_out_a = fa_encoder_v2(wav_a)
134
+ prosody_a = fa_encoder_v2.get_prosody_feature(wav_a)
135
+ enc_out_b = fa_encoder_v2(wav_b)
136
+ prosody_b = fa_encoder_v2.get_prosody_feature(wav_b)
137
+
138
+ vq_post_emb_a, vq_id_a, _, quantized, spk_embs_a = fa_decoder_v2(
139
+ enc_out_a, prosody_a, eval_vq=False, vq=True
140
+ )
141
+ vq_post_emb_b, vq_id_b, _, quantized, spk_embs_b = fa_decoder_v2(
142
+ enc_out_b, prosody_b, eval_vq=False, vq=True
143
+ )
144
+
145
+ vq_post_emb_a_to_b = fa_decoder_v2.vq2emb(vq_id_a, use_residual=False)
146
+ recon_wav_a_to_b = fa_decoder_v2.inference(vq_post_emb_a_to_b, spk_embs_b)
147
+ ```
148
+
149
+ or
150
+
151
+ ```python
152
+ from Amphion.models.codec.ns3_codec import FACodecRedecoder
153
+
154
+ fa_redecoder = FACodecRedecoder()
155
+
156
+ redecoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_redecoder.bin")
157
+
158
+ fa_redecoder.load_state_dict(torch.load(redecoder_ckpt))
159
+
160
+ with torch.no_grad():
161
+ enc_out_a = fa_encoder(wav_a)
162
+ enc_out_b = fa_encoder(wav_b)
163
+
164
+ vq_post_emb_a, vq_id_a, _, quantized_a, spk_embs_a = fa_decoder(enc_out_a, eval_vq=False, vq=True)
165
+ vq_post_emb_b, vq_id_b, _, quantized_b, spk_embs_b = fa_decoder(enc_out_b, eval_vq=False, vq=True)
166
+
167
+ # convert speaker
168
+ vq_post_emb_a_to_b = fa_redecoder.vq2emb(vq_id_a, spk_embs_b, use_residual=False)
169
+ recon_wav_a_to_b = fa_redecoder.inference(vq_post_emb_a_to_b, spk_embs_b)
170
+
171
+ sf.write("recon_a_to_b.wav", recon_wav_a_to_b[0][0].cpu().numpy(), 16000)
172
+ ```
173
+
174
+ ## Q&A
175
+
176
+ Q1: What audio sample rate does FACodec support? What is the hop size? How many codes will be generated for each frame?
177
+
178
+ A1: FACodec supports 16KHz speech audio. The hop size is 200 samples, and (16000/200) * 6 (total number of codebooks) codes will be generated for each frame.
179
+
180
+ Q2: Is it possible to train an autoregressive TTS model like VALL-E using FACodec?
181
+
182
+ A2: Yes. In fact, the authors of NaturalSpeech 3 have already employ explore the autoregressive generative model for discrete token generation with FACodec. They use an autoregressive language model to generate prosody codes, followed by a non-autoregressive model to generate the remaining content and acoustic details codes.
183
+
184
+ Q3: Is it possible to train a latent diffusion TTS model like NaturalSpeech2 using FACodec?
185
+
186
+ A3: Yes. You can use the latent getted after quanzaition as the modelling target for the latent diffusion model.
187
+
188
+ Q4: Can FACodec compress and reconstruct audio from other domains? Such as sound effects, music, etc.
189
+
190
+ A4: Since FACodec is designed for speech, it may not be suitable for other audio domains. However, it is possible to use the FACodec model to compress and reconstruct audio from other domains, but the quality may not be as good as the original audio.
191
+
192
+ Q5: Can FACodec be used for content feature for some other tasks like voice conversion?
193
+
194
+ A5: I think the answer is yes. Researchers can use the content code of FACodec as the content feature for voice conversion. We hope to see more research in this direction.
195
+
196
+ ## Citations
197
+
198
+ If you use our FACodec model, please cite the following paper:
199
+
200
+ ```bibtex
201
+ @article{ju2024naturalspeech,
202
+ title={NaturalSpeech 3: Zero-Shot Speech Synthesis with Factorized Codec and Diffusion Models},
203
+ author={Ju, Zeqian and Wang, Yuancheng and Shen, Kai and Tan, Xu and Xin, Detai and Yang, Dongchao and Liu, Yanqing and Leng, Yichong and Song, Kaitao and Tang, Siliang and others},
204
+ journal={arXiv preprint arXiv:2403.03100},
205
+ year={2024}
206
+ }
207
+
208
+ @article{zhang2023amphion,
209
+ title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
210
+ author={Xueyao Zhang and Liumeng Xue and Yicheng Gu and Yuancheng Wang and Haorui He and Chaoren Wang and Xi Chen and Zihao Fang and Haopeng Chen and Junan Zhang and Tze Ying Tang and Lexiao Zou and Mingxuan Wang and Jun Han and Kai Chen and Haizhou Li and Zhizheng Wu},
211
+ journal={arXiv},
212
+ year={2024},
213
+ volume={abs/2312.09911}
214
+ }
215
+ ```
216
+
models/codec/ns3_codec/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .facodec import *
models/codec/ns3_codec/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ from .filter import *
4
+ from .resample import *
5
+ from .act import *
models/codec/ns3_codec/alias_free_torch/act.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from .resample import UpSample1d, DownSample1d
5
+
6
+
7
+ class Activation1d(nn.Module):
8
+ def __init__(
9
+ self,
10
+ activation,
11
+ up_ratio: int = 2,
12
+ down_ratio: int = 2,
13
+ up_kernel_size: int = 12,
14
+ down_kernel_size: int = 12,
15
+ ):
16
+ super().__init__()
17
+ self.up_ratio = up_ratio
18
+ self.down_ratio = down_ratio
19
+ self.act = activation
20
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
21
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
22
+
23
+ # x: [B,C,T]
24
+ def forward(self, x):
25
+ x = self.upsample(x)
26
+ x = self.act(x)
27
+ x = self.downsample(x)
28
+
29
+ return x
models/codec/ns3_codec/alias_free_torch/filter.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+
8
+ if "sinc" in dir(torch):
9
+ sinc = torch.sinc
10
+ else:
11
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
12
+ # https://adefossez.github.io/julius/julius/core.html
13
+ def sinc(x: torch.Tensor):
14
+ """
15
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
16
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
17
+ """
18
+ return torch.where(
19
+ x == 0,
20
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
21
+ torch.sin(math.pi * x) / math.pi / x,
22
+ )
23
+
24
+
25
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26
+ # https://adefossez.github.io/julius/julius/lowpass.html
27
+ def kaiser_sinc_filter1d(
28
+ cutoff, half_width, kernel_size
29
+ ): # return filter [1,1,kernel_size]
30
+ even = kernel_size % 2 == 0
31
+ half_size = kernel_size // 2
32
+
33
+ # For kaiser window
34
+ delta_f = 4 * half_width
35
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
36
+ if A > 50.0:
37
+ beta = 0.1102 * (A - 8.7)
38
+ elif A >= 21.0:
39
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
40
+ else:
41
+ beta = 0.0
42
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
43
+
44
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
45
+ if even:
46
+ time = torch.arange(-half_size, half_size) + 0.5
47
+ else:
48
+ time = torch.arange(kernel_size) - half_size
49
+ if cutoff == 0:
50
+ filter_ = torch.zeros_like(time)
51
+ else:
52
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
53
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
54
+ # of the constant component in the input signal.
55
+ filter_ /= filter_.sum()
56
+ filter = filter_.view(1, 1, kernel_size)
57
+
58
+ return filter
59
+
60
+
61
+ class LowPassFilter1d(nn.Module):
62
+ def __init__(
63
+ self,
64
+ cutoff=0.5,
65
+ half_width=0.6,
66
+ stride: int = 1,
67
+ padding: bool = True,
68
+ padding_mode: str = "replicate",
69
+ kernel_size: int = 12,
70
+ ):
71
+ # kernel_size should be even number for stylegan3 setup,
72
+ # in this implementation, odd number is also possible.
73
+ super().__init__()
74
+ if cutoff < -0.0:
75
+ raise ValueError("Minimum cutoff must be larger than zero.")
76
+ if cutoff > 0.5:
77
+ raise ValueError("A cutoff above 0.5 does not make sense.")
78
+ self.kernel_size = kernel_size
79
+ self.even = kernel_size % 2 == 0
80
+ self.pad_left = kernel_size // 2 - int(self.even)
81
+ self.pad_right = kernel_size // 2
82
+ self.stride = stride
83
+ self.padding = padding
84
+ self.padding_mode = padding_mode
85
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
86
+ self.register_buffer("filter", filter)
87
+
88
+ # input [B, C, T]
89
+ def forward(self, x):
90
+ _, C, _ = x.shape
91
+
92
+ if self.padding:
93
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
94
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
95
+
96
+ return out
models/codec/ns3_codec/alias_free_torch/resample.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from .filter import LowPassFilter1d
6
+ from .filter import kaiser_sinc_filter1d
7
+
8
+
9
+ class UpSample1d(nn.Module):
10
+ def __init__(self, ratio=2, kernel_size=None):
11
+ super().__init__()
12
+ self.ratio = ratio
13
+ self.kernel_size = (
14
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
+ )
16
+ self.stride = ratio
17
+ self.pad = self.kernel_size // ratio - 1
18
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
19
+ self.pad_right = (
20
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
21
+ )
22
+ filter = kaiser_sinc_filter1d(
23
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
24
+ )
25
+ self.register_buffer("filter", filter)
26
+
27
+ # x: [B, C, T]
28
+ def forward(self, x):
29
+ _, C, _ = x.shape
30
+
31
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
32
+ x = self.ratio * F.conv_transpose1d(
33
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
34
+ )
35
+ x = x[..., self.pad_left : -self.pad_right]
36
+
37
+ return x
38
+
39
+
40
+ class DownSample1d(nn.Module):
41
+ def __init__(self, ratio=2, kernel_size=None):
42
+ super().__init__()
43
+ self.ratio = ratio
44
+ self.kernel_size = (
45
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
46
+ )
47
+ self.lowpass = LowPassFilter1d(
48
+ cutoff=0.5 / ratio,
49
+ half_width=0.6 / ratio,
50
+ stride=ratio,
51
+ kernel_size=self.kernel_size,
52
+ )
53
+
54
+ def forward(self, x):
55
+ xx = self.lowpass(x)
56
+
57
+ return xx