Spaces:
Runtime error
Runtime error
Hecheng0625
commited on
Commit
•
c968fc3
1
Parent(s):
8c92a11
Upload 409 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- README.md +169 -14
- models/__init__.py +0 -0
- models/base/__init__.py +7 -0
- models/base/base_dataset.py +464 -0
- models/base/base_inference.py +220 -0
- models/base/base_sampler.py +157 -0
- models/base/base_trainer.py +348 -0
- models/base/new_dataset.py +50 -0
- models/base/new_inference.py +253 -0
- models/base/new_trainer.py +727 -0
- models/codec/__init__.py +0 -0
- models/codec/amphion_codec/codec.py +427 -0
- models/codec/amphion_codec/quantize/__init__.py +11 -0
- models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- models/codec/amphion_codec/vocos.py +881 -0
- models/codec/codec_dataset.py +264 -0
- models/codec/codec_inference.py +515 -0
- models/codec/codec_sampler.py +126 -0
- models/codec/codec_trainer.py +166 -0
- models/codec/facodec/__init__.py +0 -0
- models/codec/facodec/alias_free_torch/__init__.py +5 -0
- models/codec/facodec/alias_free_torch/act.py +29 -0
- models/codec/facodec/alias_free_torch/filter.py +96 -0
- models/codec/facodec/alias_free_torch/resample.py +57 -0
- models/codec/facodec/facodec_dataset.py +98 -0
- models/codec/facodec/facodec_inference.py +137 -0
- models/codec/facodec/facodec_trainer.py +776 -0
- models/codec/facodec/modules/JDC/__init__.py +1 -0
- models/codec/facodec/modules/JDC/bst.t7 +3 -0
- models/codec/facodec/modules/JDC/model.py +219 -0
- models/codec/facodec/modules/attentions.py +437 -0
- models/codec/facodec/modules/commons.py +331 -0
- models/codec/facodec/modules/gradient_reversal.py +35 -0
- models/codec/facodec/modules/layers.py +460 -0
- models/codec/facodec/modules/quantize.py +741 -0
- models/codec/facodec/modules/style_encoder.py +110 -0
- models/codec/facodec/modules/wavenet.py +224 -0
- models/codec/facodec/optimizer.py +104 -0
- models/codec/kmeans/repcodec_model.py +210 -0
- models/codec/kmeans/vocos.py +850 -0
- models/codec/ns3_codec/README.md +216 -0
- models/codec/ns3_codec/__init__.py +6 -0
- models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
.gitattributes
CHANGED
@@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
imgs/vocoder/gan/MSSBCQTD.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
imgs/vocoder/gan/MSSBCQTD.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
models/codec/facodec/modules/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text
|
38 |
+
models/tts/maskgct/g2p/sources/chinese_lexicon.txt filter=lfs diff=lfs merge=lfs -text
|
39 |
+
models/tts/maskgct/wav/prompt.wav filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,14 +1,169 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Amphion: An Open-Source Audio, Music, and Speech Generation Toolkit
|
2 |
+
|
3 |
+
<div>
|
4 |
+
<a href="https://arxiv.org/abs/2312.09911"><img src="https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg"></a>
|
5 |
+
<a href="https://huggingface.co/amphion"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Amphion-pink"></a>
|
6 |
+
<a href="https://openxlab.org.cn/usercenter/Amphion"><img src="https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg"></a>
|
7 |
+
<a href="https://discord.com/invite/ZxxREr3Y"><img src="https://img.shields.io/badge/Discord-Join%20chat-blue.svg"></a>
|
8 |
+
<a href="egs/tts/README.md"><img src="https://img.shields.io/badge/README-TTS-blue"></a>
|
9 |
+
<a href="egs/svc/README.md"><img src="https://img.shields.io/badge/README-SVC-blue"></a>
|
10 |
+
<a href="egs/tta/README.md"><img src="https://img.shields.io/badge/README-TTA-blue"></a>
|
11 |
+
<a href="egs/vocoder/README.md"><img src="https://img.shields.io/badge/README-Vocoder-purple"></a>
|
12 |
+
<a href="egs/metrics/README.md"><img src="https://img.shields.io/badge/README-Evaluation-yellow"></a>
|
13 |
+
<a href="LICENSE"><img src="https://img.shields.io/badge/LICENSE-MIT-red"></a>
|
14 |
+
</div>
|
15 |
+
<br>
|
16 |
+
|
17 |
+
**Amphion (/æmˈfaɪən/) is a toolkit for Audio, Music, and Speech Generation.** Its purpose is to support reproducible research and help junior researchers and engineers get started in the field of audio, music, and speech generation research and development. Amphion offers a unique feature: **visualizations** of classic models or architectures. We believe that these visualizations are beneficial for junior researchers and engineers who wish to gain a better understanding of the model.
|
18 |
+
|
19 |
+
**The North-Star objective of Amphion is to offer a platform for studying the conversion of any inputs into audio.** Amphion is designed to support individual generation tasks, including but not limited to,
|
20 |
+
|
21 |
+
- **TTS**: Text to Speech (⛳ supported)
|
22 |
+
- **SVS**: Singing Voice Synthesis (👨💻 developing)
|
23 |
+
- **VC**: Voice Conversion (👨💻 developing)
|
24 |
+
- **SVC**: Singing Voice Conversion (⛳ supported)
|
25 |
+
- **TTA**: Text to Audio (⛳ supported)
|
26 |
+
- **TTM**: Text to Music (👨💻 developing)
|
27 |
+
- more…
|
28 |
+
|
29 |
+
In addition to the specific generation tasks, Amphion includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent metrics in generation tasks. Moreover, Amphion is dedicated to advancing audio generation in real-world applications, such as building **large-scale datasets** for speech synthesis.
|
30 |
+
|
31 |
+
## 🚀 News
|
32 |
+
- **2024/10/19**: We release **MaskGCT**, a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision. MaskGCT is trained on Emilia dataset and achieves SOTA zero-shot TTS perfermance. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2409.00750) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/maskgct) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/maskgct) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](models/tts/maskgct/README.md)
|
33 |
+
- **2024/09/01**: [Amphion](https://arxiv.org/abs/2312.09911), [Emilia](https://arxiv.org/abs/2407.05361) and [DSFF-SVC](https://arxiv.org/abs/2310.11160) got accepted by IEEE SLT 2024! 🤗
|
34 |
+
- **2024/08/28**: Welcome to join Amphion's [Discord channel](https://discord.gg/drhW7ajqAG) to stay connected and engage with our community!
|
35 |
+
- **2024/08/20**: [SingVisio](https://arxiv.org/abs/2402.12660) got accepted by Computers & Graphics, [available here](https://www.sciencedirect.com/science/article/pii/S0097849324001936)! 🎉
|
36 |
+
- **2024/08/27**: *The Emilia dataset is now publicly available!* Discover the most extensive and diverse speech generation dataset with 101k hours of in-the-wild speech data now at [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia-Dataset) or [![OpenDataLab](https://img.shields.io/badge/OpenDataLab-Dataset-blue)](https://opendatalab.com/Amphion/Emilia)! 👑👑👑
|
37 |
+
- **2024/07/01**: Amphion now releases **Emilia**, the first open-source multilingual in-the-wild dataset for speech generation with over 101k hours of speech data, and the **Emilia-Pipe**, the first open-source preprocessing pipeline designed to transform in-the-wild speech data into high-quality training data with annotations for speech generation! [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2407.05361) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia) [![demo](https://img.shields.io/badge/WebPage-Demo-red)](https://emilia-dataset.github.io/Emilia-Demo-Page/) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](preprocessors/Emilia/README.md)
|
38 |
+
- **2024/06/17**: Amphion has a new release for its **VALL-E** model! It uses Llama as its underlying architecture and has better model performance, faster training speed, and more readable codes compared to our first version. [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/tts/VALLE_V2/README.md)
|
39 |
+
- **2024/03/12**: Amphion now support **NaturalSpeech3 FACodec** and release pretrained checkpoints. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2403.03100) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/naturalspeech3_facodec) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](models/codec/ns3_codec/README.md)
|
40 |
+
- **2024/02/22**: The first Amphion visualization tool, **SingVisio**, release. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2402.12660) [![openxlab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Amphion/SingVisio) [![Video](https://img.shields.io/badge/Video-Demo-orange)](https://drive.google.com/file/d/15097SGhQh-SwUNbdWDYNyWEP--YGLba5/view) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/visualization/SingVisio/README.md)
|
41 |
+
- **2023/12/18**: Amphion v0.1 release. [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2312.09911) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Amphion-pink)](https://huggingface.co/amphion) [![youtube](https://img.shields.io/badge/YouTube-Demo-red)](https://www.youtube.com/watch?v=1aw0HhcggvQ) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/pull/39)
|
42 |
+
- **2023/11/28**: Amphion alpha release. [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/pull/2)
|
43 |
+
|
44 |
+
## ⭐ Key Features
|
45 |
+
|
46 |
+
### TTS: Text to Speech
|
47 |
+
|
48 |
+
- Amphion achieves state-of-the-art performance compared to existing open-source repositories on text-to-speech (TTS) systems. It supports the following models or architectures:
|
49 |
+
- [FastSpeech2](https://arxiv.org/abs/2006.04558): A non-autoregressive TTS architecture that utilizes feed-forward Transformer blocks.
|
50 |
+
- [VITS](https://arxiv.org/abs/2106.06103): An end-to-end TTS architecture that utilizes conditional variational autoencoder with adversarial learning
|
51 |
+
- [VALL-E](https://arxiv.org/abs/2301.02111): A zero-shot TTS architecture that uses a neural codec language model with discrete codes.
|
52 |
+
- [NaturalSpeech2](https://arxiv.org/abs/2304.09116): An architecture for TTS that utilizes a latent diffusion model to generate natural-sounding voices.
|
53 |
+
- [Jets](Jets): An end-to-end TTS model that jointly trains FastSpeech2 and HiFi-GAN with an alignment module.
|
54 |
+
- [MaskGCT](https://arxiv.org/abs/2409.00750): a fully non-autoregressive TTS architecture that eliminates the need for explicit alignment information between text and speech supervision.
|
55 |
+
|
56 |
+
### SVC: Singing Voice Conversion
|
57 |
+
|
58 |
+
- Ampion supports multiple content-based features from various pretrained models, including [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), and [ContentVec](https://github.com/auspicious3000/contentvec). Their specific roles in SVC has been investigated in our SLT 2024 paper. [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2310.11160) [![code](https://img.shields.io/badge/README-Code-red)](egs/svc/MultipleContentsSVC)
|
59 |
+
- Amphion implements several state-of-the-art model architectures, including diffusion-, transformer-, VAE- and flow-based models. The diffusion-based architecture uses [Bidirectional dilated CNN](https://openreview.net/pdf?id=a-xFK8Ymz5J) as a backend and supports several sampling algorithms such as [DDPM](https://arxiv.org/pdf/2006.11239.pdf), [DDIM](https://arxiv.org/pdf/2010.02502.pdf), and [PNDM](https://arxiv.org/pdf/2202.09778.pdf). Additionally, it supports single-step inference based on the [Consistency Model](https://openreview.net/pdf?id=FmqFfMTNnv).
|
60 |
+
|
61 |
+
### TTA: Text to Audio
|
62 |
+
|
63 |
+
- Amphion supports the TTA with a latent diffusion model. It is designed like [AudioLDM](https://arxiv.org/abs/2301.12503), [Make-an-Audio](https://arxiv.org/abs/2301.12661), and [AUDIT](https://arxiv.org/abs/2304.00830). It is also the official implementation of the text-to-audio generation part of our NeurIPS 2023 paper. [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2304.00830) [![code](https://img.shields.io/badge/README-Code-red)](egs/tta/RECIPE.md)
|
64 |
+
|
65 |
+
### Vocoder
|
66 |
+
|
67 |
+
- Amphion supports various widely-used neural vocoders, including:
|
68 |
+
- GAN-based vocoders: [MelGAN](https://arxiv.org/abs/1910.06711), [HiFi-GAN](https://arxiv.org/abs/2010.05646), [NSF-HiFiGAN](https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts), [BigVGAN](https://arxiv.org/abs/2206.04658), [APNet](https://arxiv.org/abs/2305.07952).
|
69 |
+
- Flow-based vocoders: [WaveGlow](https://arxiv.org/abs/1811.00002).
|
70 |
+
- Diffusion-based vocoders: [Diffwave](https://arxiv.org/abs/2009.09761).
|
71 |
+
- Auto-regressive based vocoders: [WaveNet](https://arxiv.org/abs/1609.03499), [WaveRNN](https://arxiv.org/abs/1802.08435v1).
|
72 |
+
- Amphion provides the official implementation of [Multi-Scale Constant-Q Transform Discriminator](https://arxiv.org/abs/2311.14957) (our ICASSP 2024 paper). It can be used to enhance any architecture GAN-based vocoders during training, and keep the inference stage (such as memory or speed) unchanged. [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2311.14957) [![code](https://img.shields.io/badge/README-Code-red)](egs/vocoder/gan/tfr_enhanced_hifigan)
|
73 |
+
|
74 |
+
### Evaluation
|
75 |
+
|
76 |
+
Amphion provides a comprehensive objective evaluation of the generated audio. The evaluation metrics contain:
|
77 |
+
|
78 |
+
- **F0 Modeling**: F0 Pearson Coefficients, F0 Periodicity Root Mean Square Error, F0 Root Mean Square Error, Voiced/Unvoiced F1 Score, etc.
|
79 |
+
- **Energy Modeling**: Energy Root Mean Square Error, Energy Pearson Coefficients, etc.
|
80 |
+
- **Intelligibility**: Character/Word Error Rate, which can be calculated based on [Whisper](https://github.com/openai/whisper) and more.
|
81 |
+
- **Spectrogram Distortion**: Frechet Audio Distance (FAD), Mel Cepstral Distortion (MCD), Multi-Resolution STFT Distance (MSTFT), Perceptual Evaluation of Speech Quality (PESQ), Short Time Objective Intelligibility (STOI), etc.
|
82 |
+
- **Speaker Similarity**: Cosine similarity, which can be calculated based on [RawNet3](https://github.com/Jungjee/RawNet), [Resemblyzer](https://github.com/resemble-ai/Resemblyzer), [WeSpeaker](https://github.com/wenet-e2e/wespeaker), [WavLM](https://github.com/microsoft/unilm/tree/master/wavlm) and more.
|
83 |
+
|
84 |
+
### Datasets
|
85 |
+
|
86 |
+
- Amphion unifies the data preprocess of the open-source datasets including [AudioCaps](https://audiocaps.github.io/), [LibriTTS](https://www.openslr.org/60/), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/), [M4Singer](https://github.com/M4Singer/M4Singer), [Opencpop](https://wenet.org.cn/opencpop/), [OpenSinger](https://github.com/Multi-Singer/Multi-Singer.github.io), [SVCC](http://vc-challenge.org/), [VCTK](https://datashare.ed.ac.uk/handle/10283/3443), and more. The supported dataset list can be seen [here](egs/datasets/README.md) (updating).
|
87 |
+
- Amphion (exclusively) supports the [**Emilia**](preprocessors/Emilia/README.md) dataset and its preprocessing pipeline **Emilia-Pipe** for in-the-wild speech data!
|
88 |
+
|
89 |
+
### Visualization
|
90 |
+
|
91 |
+
Amphion provides visualization tools to interactively illustrate the internal processing mechanism of classic models. This provides an invaluable resource for educational purposes and for facilitating understandable research.
|
92 |
+
|
93 |
+
Currently, Amphion supports [SingVisio](egs/visualization/SingVisio/README.md), a visualization tool of the diffusion model for singing voice conversion. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2402.12660) [![openxlab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Amphion/SingVisio) [![Video](https://img.shields.io/badge/Video-Demo-orange)](https://drive.google.com/file/d/15097SGhQh-SwUNbdWDYNyWEP--YGLba5/view)
|
94 |
+
|
95 |
+
|
96 |
+
## 📀 Installation
|
97 |
+
|
98 |
+
Amphion can be installed through either Setup Installer or Docker Image.
|
99 |
+
|
100 |
+
### Setup Installer
|
101 |
+
|
102 |
+
```bash
|
103 |
+
git clone https://github.com/open-mmlab/Amphion.git
|
104 |
+
cd Amphion
|
105 |
+
|
106 |
+
# Install Python Environment
|
107 |
+
conda create --name amphion python=3.9.15
|
108 |
+
conda activate amphion
|
109 |
+
|
110 |
+
# Install Python Packages Dependencies
|
111 |
+
sh env.sh
|
112 |
+
```
|
113 |
+
|
114 |
+
### Docker Image
|
115 |
+
|
116 |
+
1. Install [Docker](https://docs.docker.com/get-docker/), [NVIDIA Driver](https://www.nvidia.com/download/index.aspx), [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), and [CUDA](https://developer.nvidia.com/cuda-downloads).
|
117 |
+
|
118 |
+
2. Run the following commands:
|
119 |
+
```bash
|
120 |
+
git clone https://github.com/open-mmlab/Amphion.git
|
121 |
+
cd Amphion
|
122 |
+
|
123 |
+
docker pull realamphion/amphion
|
124 |
+
docker run --runtime=nvidia --gpus all -it -v .:/app realamphion/amphion
|
125 |
+
```
|
126 |
+
Mount dataset by argument `-v` is necessary when using Docker. Please refer to [Mount dataset in Docker container](egs/datasets/docker.md) and [Docker Docs](https://docs.docker.com/engine/reference/commandline/container_run/#volume) for more details.
|
127 |
+
|
128 |
+
|
129 |
+
## 🐍 Usage in Python
|
130 |
+
|
131 |
+
We detail the instructions of different tasks in the following recipes:
|
132 |
+
|
133 |
+
- [Text to Speech (TTS)](egs/tts/README.md)
|
134 |
+
- [Singing Voice Conversion (SVC)](egs/svc/README.md)
|
135 |
+
- [Text to Audio (TTA)](egs/tta/README.md)
|
136 |
+
- [Vocoder](egs/vocoder/README.md)
|
137 |
+
- [Evaluation](egs/metrics/README.md)
|
138 |
+
- [Visualization](egs/visualization/README.md)
|
139 |
+
|
140 |
+
## 👨💻 Contributing
|
141 |
+
We appreciate all contributions to improve Amphion. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
|
142 |
+
|
143 |
+
## 🙏 Acknowledgement
|
144 |
+
|
145 |
+
|
146 |
+
- [ming024's FastSpeech2](https://github.com/ming024/FastSpeech2) and [jaywalnut310's VITS](https://github.com/jaywalnut310/vits) for model architecture code.
|
147 |
+
- [lifeiteng's VALL-E](https://github.com/lifeiteng/vall-e) for training pipeline and model architecture design.
|
148 |
+
- [SpeechTokenizer](https://github.com/ZhangXInFD/SpeechTokenizer) for semantic-distilled tokenizer design.
|
149 |
+
- [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), [ContentVec](https://github.com/auspicious3000/contentvec), and [RawNet3](https://github.com/Jungjee/RawNet) for pretrained models and inference code.
|
150 |
+
- [HiFi-GAN](https://github.com/jik876/hifi-gan) for GAN-based Vocoder's architecture design and training strategy.
|
151 |
+
- [Encodec](https://github.com/facebookresearch/encodec) for well-organized GAN Discriminator's architecture and basic blocks.
|
152 |
+
- [Latent Diffusion](https://github.com/CompVis/latent-diffusion) for model architecture design.
|
153 |
+
- [TensorFlowTTS](https://github.com/TensorSpeech/TensorFlowTTS) for preparing the MFA tools.
|
154 |
+
|
155 |
+
|
156 |
+
## ©️ License
|
157 |
+
|
158 |
+
Amphion is under the [MIT License](LICENSE). It is free for both research and commercial use cases.
|
159 |
+
|
160 |
+
## 📚 Citations
|
161 |
+
|
162 |
+
```bibtex
|
163 |
+
@inproceedings{amphion,
|
164 |
+
author={Zhang, Xueyao and Xue, Liumeng and Gu, Yicheng and Wang, Yuancheng and Li, Jiaqi and He, Haorui and Wang, Chaoren and Song, Ting and Chen, Xi and Fang, Zihao and Chen, Haopeng and Zhang, Junan and Tang, Tze Ying and Zou, Lexiao and Wang, Mingxuan and Han, Jun and Chen, Kai and Li, Haizhou and Wu, Zhizheng},
|
165 |
+
title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
|
166 |
+
booktitle={{IEEE} Spoken Language Technology Workshop, {SLT} 2024},
|
167 |
+
year={2024}
|
168 |
+
}
|
169 |
+
```
|
models/__init__.py
ADDED
File without changes
|
models/base/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .new_trainer import BaseTrainer
|
7 |
+
from .new_inference import BaseInference
|
models/base/base_dataset.py
ADDED
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import torch.utils.data
|
9 |
+
from torch.nn.utils.rnn import pad_sequence
|
10 |
+
import librosa
|
11 |
+
|
12 |
+
from utils.data_utils import *
|
13 |
+
from processors.acoustic_extractor import cal_normalized_mel
|
14 |
+
from text import text_to_sequence
|
15 |
+
from text.text_token_collation import phoneIDCollation
|
16 |
+
|
17 |
+
|
18 |
+
class BaseOfflineDataset(torch.utils.data.Dataset):
|
19 |
+
def __init__(self, cfg, dataset, is_valid=False):
|
20 |
+
"""
|
21 |
+
Args:
|
22 |
+
cfg: config
|
23 |
+
dataset: dataset name
|
24 |
+
is_valid: whether to use train or valid dataset
|
25 |
+
"""
|
26 |
+
|
27 |
+
assert isinstance(dataset, str)
|
28 |
+
|
29 |
+
# self.data_root = processed_data_dir
|
30 |
+
self.cfg = cfg
|
31 |
+
|
32 |
+
processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
|
33 |
+
meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
|
34 |
+
self.metafile_path = os.path.join(processed_data_dir, meta_file)
|
35 |
+
self.metadata = self.get_metadata()
|
36 |
+
|
37 |
+
"""
|
38 |
+
load spk2id and utt2spk from json file
|
39 |
+
spk2id: {spk1: 0, spk2: 1, ...}
|
40 |
+
utt2spk: {dataset_uid: spk1, ...}
|
41 |
+
"""
|
42 |
+
if cfg.preprocess.use_spkid:
|
43 |
+
spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
|
44 |
+
with open(spk2id_path, "r") as f:
|
45 |
+
self.spk2id = json.load(f)
|
46 |
+
|
47 |
+
utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
|
48 |
+
self.utt2spk = dict()
|
49 |
+
with open(utt2spk_path, "r") as f:
|
50 |
+
for line in f.readlines():
|
51 |
+
utt, spk = line.strip().split("\t")
|
52 |
+
self.utt2spk[utt] = spk
|
53 |
+
|
54 |
+
if cfg.preprocess.use_uv:
|
55 |
+
self.utt2uv_path = {}
|
56 |
+
for utt_info in self.metadata:
|
57 |
+
dataset = utt_info["Dataset"]
|
58 |
+
uid = utt_info["Uid"]
|
59 |
+
utt = "{}_{}".format(dataset, uid)
|
60 |
+
self.utt2uv_path[utt] = os.path.join(
|
61 |
+
cfg.preprocess.processed_dir,
|
62 |
+
dataset,
|
63 |
+
cfg.preprocess.uv_dir,
|
64 |
+
uid + ".npy",
|
65 |
+
)
|
66 |
+
|
67 |
+
if cfg.preprocess.use_frame_pitch:
|
68 |
+
self.utt2frame_pitch_path = {}
|
69 |
+
for utt_info in self.metadata:
|
70 |
+
dataset = utt_info["Dataset"]
|
71 |
+
uid = utt_info["Uid"]
|
72 |
+
utt = "{}_{}".format(dataset, uid)
|
73 |
+
|
74 |
+
self.utt2frame_pitch_path[utt] = os.path.join(
|
75 |
+
cfg.preprocess.processed_dir,
|
76 |
+
dataset,
|
77 |
+
cfg.preprocess.pitch_dir,
|
78 |
+
uid + ".npy",
|
79 |
+
)
|
80 |
+
|
81 |
+
if cfg.preprocess.use_frame_energy:
|
82 |
+
self.utt2frame_energy_path = {}
|
83 |
+
for utt_info in self.metadata:
|
84 |
+
dataset = utt_info["Dataset"]
|
85 |
+
uid = utt_info["Uid"]
|
86 |
+
utt = "{}_{}".format(dataset, uid)
|
87 |
+
|
88 |
+
self.utt2frame_energy_path[utt] = os.path.join(
|
89 |
+
cfg.preprocess.processed_dir,
|
90 |
+
dataset,
|
91 |
+
cfg.preprocess.energy_dir,
|
92 |
+
uid + ".npy",
|
93 |
+
)
|
94 |
+
|
95 |
+
if cfg.preprocess.use_mel:
|
96 |
+
self.utt2mel_path = {}
|
97 |
+
for utt_info in self.metadata:
|
98 |
+
dataset = utt_info["Dataset"]
|
99 |
+
uid = utt_info["Uid"]
|
100 |
+
utt = "{}_{}".format(dataset, uid)
|
101 |
+
|
102 |
+
self.utt2mel_path[utt] = os.path.join(
|
103 |
+
cfg.preprocess.processed_dir,
|
104 |
+
dataset,
|
105 |
+
cfg.preprocess.mel_dir,
|
106 |
+
uid + ".npy",
|
107 |
+
)
|
108 |
+
|
109 |
+
if cfg.preprocess.use_linear:
|
110 |
+
self.utt2linear_path = {}
|
111 |
+
for utt_info in self.metadata:
|
112 |
+
dataset = utt_info["Dataset"]
|
113 |
+
uid = utt_info["Uid"]
|
114 |
+
utt = "{}_{}".format(dataset, uid)
|
115 |
+
|
116 |
+
self.utt2linear_path[utt] = os.path.join(
|
117 |
+
cfg.preprocess.processed_dir,
|
118 |
+
dataset,
|
119 |
+
cfg.preprocess.linear_dir,
|
120 |
+
uid + ".npy",
|
121 |
+
)
|
122 |
+
|
123 |
+
if cfg.preprocess.use_audio:
|
124 |
+
self.utt2audio_path = {}
|
125 |
+
for utt_info in self.metadata:
|
126 |
+
dataset = utt_info["Dataset"]
|
127 |
+
uid = utt_info["Uid"]
|
128 |
+
utt = "{}_{}".format(dataset, uid)
|
129 |
+
|
130 |
+
self.utt2audio_path[utt] = os.path.join(
|
131 |
+
cfg.preprocess.processed_dir,
|
132 |
+
dataset,
|
133 |
+
cfg.preprocess.audio_dir,
|
134 |
+
uid + ".npy",
|
135 |
+
)
|
136 |
+
elif cfg.preprocess.use_label:
|
137 |
+
self.utt2label_path = {}
|
138 |
+
for utt_info in self.metadata:
|
139 |
+
dataset = utt_info["Dataset"]
|
140 |
+
uid = utt_info["Uid"]
|
141 |
+
utt = "{}_{}".format(dataset, uid)
|
142 |
+
|
143 |
+
self.utt2label_path[utt] = os.path.join(
|
144 |
+
cfg.preprocess.processed_dir,
|
145 |
+
dataset,
|
146 |
+
cfg.preprocess.label_dir,
|
147 |
+
uid + ".npy",
|
148 |
+
)
|
149 |
+
elif cfg.preprocess.use_one_hot:
|
150 |
+
self.utt2one_hot_path = {}
|
151 |
+
for utt_info in self.metadata:
|
152 |
+
dataset = utt_info["Dataset"]
|
153 |
+
uid = utt_info["Uid"]
|
154 |
+
utt = "{}_{}".format(dataset, uid)
|
155 |
+
|
156 |
+
self.utt2one_hot_path[utt] = os.path.join(
|
157 |
+
cfg.preprocess.processed_dir,
|
158 |
+
dataset,
|
159 |
+
cfg.preprocess.one_hot_dir,
|
160 |
+
uid + ".npy",
|
161 |
+
)
|
162 |
+
|
163 |
+
if cfg.preprocess.use_text or cfg.preprocess.use_phone:
|
164 |
+
self.utt2seq = {}
|
165 |
+
for utt_info in self.metadata:
|
166 |
+
dataset = utt_info["Dataset"]
|
167 |
+
uid = utt_info["Uid"]
|
168 |
+
utt = "{}_{}".format(dataset, uid)
|
169 |
+
|
170 |
+
if cfg.preprocess.use_text:
|
171 |
+
text = utt_info["Text"]
|
172 |
+
sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
|
173 |
+
elif cfg.preprocess.use_phone:
|
174 |
+
# load phoneme squence from phone file
|
175 |
+
phone_path = os.path.join(
|
176 |
+
processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
|
177 |
+
)
|
178 |
+
with open(phone_path, "r") as fin:
|
179 |
+
phones = fin.readlines()
|
180 |
+
assert len(phones) == 1
|
181 |
+
phones = phones[0].strip()
|
182 |
+
phones_seq = phones.split(" ")
|
183 |
+
|
184 |
+
phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
|
185 |
+
sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
|
186 |
+
|
187 |
+
self.utt2seq[utt] = sequence
|
188 |
+
|
189 |
+
def get_metadata(self):
|
190 |
+
with open(self.metafile_path, "r", encoding="utf-8") as f:
|
191 |
+
metadata = json.load(f)
|
192 |
+
|
193 |
+
return metadata
|
194 |
+
|
195 |
+
def get_dataset_name(self):
|
196 |
+
return self.metadata[0]["Dataset"]
|
197 |
+
|
198 |
+
def __getitem__(self, index):
|
199 |
+
utt_info = self.metadata[index]
|
200 |
+
|
201 |
+
dataset = utt_info["Dataset"]
|
202 |
+
uid = utt_info["Uid"]
|
203 |
+
utt = "{}_{}".format(dataset, uid)
|
204 |
+
|
205 |
+
single_feature = dict()
|
206 |
+
|
207 |
+
if self.cfg.preprocess.use_spkid:
|
208 |
+
single_feature["spk_id"] = np.array(
|
209 |
+
[self.spk2id[self.utt2spk[utt]]], dtype=np.int32
|
210 |
+
)
|
211 |
+
|
212 |
+
if self.cfg.preprocess.use_mel:
|
213 |
+
mel = np.load(self.utt2mel_path[utt])
|
214 |
+
assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
|
215 |
+
if self.cfg.preprocess.use_min_max_norm_mel:
|
216 |
+
# do mel norm
|
217 |
+
mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)
|
218 |
+
|
219 |
+
if "target_len" not in single_feature.keys():
|
220 |
+
single_feature["target_len"] = mel.shape[1]
|
221 |
+
single_feature["mel"] = mel.T # [T, n_mels]
|
222 |
+
|
223 |
+
if self.cfg.preprocess.use_linear:
|
224 |
+
linear = np.load(self.utt2linear_path[utt])
|
225 |
+
if "target_len" not in single_feature.keys():
|
226 |
+
single_feature["target_len"] = linear.shape[1]
|
227 |
+
single_feature["linear"] = linear.T # [T, n_linear]
|
228 |
+
|
229 |
+
if self.cfg.preprocess.use_frame_pitch:
|
230 |
+
frame_pitch_path = self.utt2frame_pitch_path[utt]
|
231 |
+
frame_pitch = np.load(frame_pitch_path)
|
232 |
+
if "target_len" not in single_feature.keys():
|
233 |
+
single_feature["target_len"] = len(frame_pitch)
|
234 |
+
aligned_frame_pitch = align_length(
|
235 |
+
frame_pitch, single_feature["target_len"]
|
236 |
+
)
|
237 |
+
single_feature["frame_pitch"] = aligned_frame_pitch
|
238 |
+
|
239 |
+
if self.cfg.preprocess.use_uv:
|
240 |
+
frame_uv_path = self.utt2uv_path[utt]
|
241 |
+
frame_uv = np.load(frame_uv_path)
|
242 |
+
aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
|
243 |
+
aligned_frame_uv = [
|
244 |
+
0 if frame_uv else 1 for frame_uv in aligned_frame_uv
|
245 |
+
]
|
246 |
+
aligned_frame_uv = np.array(aligned_frame_uv)
|
247 |
+
single_feature["frame_uv"] = aligned_frame_uv
|
248 |
+
|
249 |
+
if self.cfg.preprocess.use_frame_energy:
|
250 |
+
frame_energy_path = self.utt2frame_energy_path[utt]
|
251 |
+
frame_energy = np.load(frame_energy_path)
|
252 |
+
if "target_len" not in single_feature.keys():
|
253 |
+
single_feature["target_len"] = len(frame_energy)
|
254 |
+
aligned_frame_energy = align_length(
|
255 |
+
frame_energy, single_feature["target_len"]
|
256 |
+
)
|
257 |
+
single_feature["frame_energy"] = aligned_frame_energy
|
258 |
+
|
259 |
+
if self.cfg.preprocess.use_audio:
|
260 |
+
audio = np.load(self.utt2audio_path[utt])
|
261 |
+
single_feature["audio"] = audio
|
262 |
+
single_feature["audio_len"] = audio.shape[0]
|
263 |
+
|
264 |
+
if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
|
265 |
+
single_feature["phone_seq"] = np.array(self.utt2seq[utt])
|
266 |
+
single_feature["phone_len"] = len(self.utt2seq[utt])
|
267 |
+
|
268 |
+
return single_feature
|
269 |
+
|
270 |
+
def __len__(self):
|
271 |
+
return len(self.metadata)
|
272 |
+
|
273 |
+
|
274 |
+
class BaseOfflineCollator(object):
|
275 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
276 |
+
|
277 |
+
def __init__(self, cfg):
|
278 |
+
self.cfg = cfg
|
279 |
+
|
280 |
+
def __call__(self, batch):
|
281 |
+
packed_batch_features = dict()
|
282 |
+
|
283 |
+
# mel: [b, T, n_mels]
|
284 |
+
# frame_pitch, frame_energy: [1, T]
|
285 |
+
# target_len: [b]
|
286 |
+
# spk_id: [b, 1]
|
287 |
+
# mask: [b, T, 1]
|
288 |
+
|
289 |
+
for key in batch[0].keys():
|
290 |
+
if key == "target_len":
|
291 |
+
packed_batch_features["target_len"] = torch.LongTensor(
|
292 |
+
[b["target_len"] for b in batch]
|
293 |
+
)
|
294 |
+
masks = [
|
295 |
+
torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
|
296 |
+
]
|
297 |
+
packed_batch_features["mask"] = pad_sequence(
|
298 |
+
masks, batch_first=True, padding_value=0
|
299 |
+
)
|
300 |
+
elif key == "phone_len":
|
301 |
+
packed_batch_features["phone_len"] = torch.LongTensor(
|
302 |
+
[b["phone_len"] for b in batch]
|
303 |
+
)
|
304 |
+
masks = [
|
305 |
+
torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
|
306 |
+
]
|
307 |
+
packed_batch_features["phn_mask"] = pad_sequence(
|
308 |
+
masks, batch_first=True, padding_value=0
|
309 |
+
)
|
310 |
+
elif key == "audio_len":
|
311 |
+
packed_batch_features["audio_len"] = torch.LongTensor(
|
312 |
+
[b["audio_len"] for b in batch]
|
313 |
+
)
|
314 |
+
masks = [
|
315 |
+
torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
|
316 |
+
]
|
317 |
+
else:
|
318 |
+
values = [torch.from_numpy(b[key]) for b in batch]
|
319 |
+
packed_batch_features[key] = pad_sequence(
|
320 |
+
values, batch_first=True, padding_value=0
|
321 |
+
)
|
322 |
+
return packed_batch_features
|
323 |
+
|
324 |
+
|
325 |
+
class BaseOnlineDataset(torch.utils.data.Dataset):
|
326 |
+
def __init__(self, cfg, dataset, is_valid=False):
|
327 |
+
"""
|
328 |
+
Args:
|
329 |
+
cfg: config
|
330 |
+
dataset: dataset name
|
331 |
+
is_valid: whether to use train or valid dataset
|
332 |
+
"""
|
333 |
+
assert isinstance(dataset, str)
|
334 |
+
|
335 |
+
self.cfg = cfg
|
336 |
+
self.sample_rate = cfg.preprocess.sample_rate
|
337 |
+
self.hop_size = self.cfg.preprocess.hop_size
|
338 |
+
|
339 |
+
processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
|
340 |
+
meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
|
341 |
+
self.metafile_path = os.path.join(processed_data_dir, meta_file)
|
342 |
+
self.metadata = self.get_metadata()
|
343 |
+
|
344 |
+
"""
|
345 |
+
load spk2id and utt2spk from json file
|
346 |
+
spk2id: {spk1: 0, spk2: 1, ...}
|
347 |
+
utt2spk: {dataset_uid: spk1, ...}
|
348 |
+
"""
|
349 |
+
if cfg.preprocess.use_spkid:
|
350 |
+
spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
|
351 |
+
with open(spk2id_path, "r") as f:
|
352 |
+
self.spk2id = json.load(f)
|
353 |
+
|
354 |
+
utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
|
355 |
+
self.utt2spk = dict()
|
356 |
+
with open(utt2spk_path, "r") as f:
|
357 |
+
for line in f.readlines():
|
358 |
+
utt, spk = line.strip().split("\t")
|
359 |
+
self.utt2spk[utt] = spk
|
360 |
+
|
361 |
+
def get_metadata(self):
|
362 |
+
with open(self.metafile_path, "r", encoding="utf-8") as f:
|
363 |
+
metadata = json.load(f)
|
364 |
+
|
365 |
+
return metadata
|
366 |
+
|
367 |
+
def get_dataset_name(self):
|
368 |
+
return self.metadata[0]["Dataset"]
|
369 |
+
|
370 |
+
def __getitem__(self, index):
|
371 |
+
"""
|
372 |
+
single_feature:
|
373 |
+
wav: (T)
|
374 |
+
wav_len: int
|
375 |
+
target_len: int
|
376 |
+
mask: (n_frames, 1)
|
377 |
+
spk_id: (1)
|
378 |
+
"""
|
379 |
+
utt_item = self.metadata[index]
|
380 |
+
|
381 |
+
wav_path = utt_item["Path"]
|
382 |
+
wav, _ = librosa.load(wav_path, sr=self.sample_rate)
|
383 |
+
# wav: (T)
|
384 |
+
wav = torch.as_tensor(wav, dtype=torch.float32)
|
385 |
+
wav_len = len(wav)
|
386 |
+
# mask: (n_frames, 1)
|
387 |
+
frame_len = wav_len // self.hop_size
|
388 |
+
mask = torch.ones(frame_len, 1, dtype=torch.long)
|
389 |
+
|
390 |
+
single_feature = {
|
391 |
+
"wav": wav,
|
392 |
+
"wav_len": wav_len,
|
393 |
+
"target_len": frame_len,
|
394 |
+
"mask": mask,
|
395 |
+
}
|
396 |
+
|
397 |
+
if self.cfg.preprocess.use_spkid:
|
398 |
+
utt = "{}_{}".format(utt_item["Dataset"], utt_item["Uid"])
|
399 |
+
single_feature["spk_id"] = torch.tensor(
|
400 |
+
[self.spk2id[self.utt2spk[utt]]], dtype=torch.int32
|
401 |
+
)
|
402 |
+
|
403 |
+
return single_feature
|
404 |
+
|
405 |
+
def __len__(self):
|
406 |
+
return len(self.metadata)
|
407 |
+
|
408 |
+
|
409 |
+
class BaseOnlineCollator(object):
|
410 |
+
"""Zero-pads model inputs and targets based on number of frames per step (For on-the-fly features extraction, whose iterative item contains only wavs)"""
|
411 |
+
|
412 |
+
def __init__(self, cfg):
|
413 |
+
self.cfg = cfg
|
414 |
+
|
415 |
+
def __call__(self, batch):
|
416 |
+
"""
|
417 |
+
BaseOnlineDataset.__getitem__:
|
418 |
+
wav: (T,)
|
419 |
+
wav_len: int
|
420 |
+
target_len: int
|
421 |
+
mask: (n_frames, 1)
|
422 |
+
spk_id: (1)
|
423 |
+
|
424 |
+
Returns:
|
425 |
+
wav: (B, T), torch.float32
|
426 |
+
wav_len: (B), torch.long
|
427 |
+
target_len: (B), torch.long
|
428 |
+
mask: (B, n_frames, 1), torch.long
|
429 |
+
spk_id: (B, 1), torch.int32
|
430 |
+
"""
|
431 |
+
packed_batch_features = dict()
|
432 |
+
|
433 |
+
for key in batch[0].keys():
|
434 |
+
if key in ["wav_len", "target_len"]:
|
435 |
+
packed_batch_features[key] = torch.LongTensor([b[key] for b in batch])
|
436 |
+
else:
|
437 |
+
packed_batch_features[key] = pad_sequence(
|
438 |
+
[b[key] for b in batch], batch_first=True, padding_value=0
|
439 |
+
)
|
440 |
+
return packed_batch_features
|
441 |
+
|
442 |
+
|
443 |
+
class BaseTestDataset(torch.utils.data.Dataset):
|
444 |
+
def __init__(self, cfg, args):
|
445 |
+
raise NotImplementedError
|
446 |
+
|
447 |
+
def get_metadata(self):
|
448 |
+
raise NotImplementedError
|
449 |
+
|
450 |
+
def __getitem__(self, index):
|
451 |
+
raise NotImplementedError
|
452 |
+
|
453 |
+
def __len__(self):
|
454 |
+
return len(self.metadata)
|
455 |
+
|
456 |
+
|
457 |
+
class BaseTestCollator(object):
|
458 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
459 |
+
|
460 |
+
def __init__(self, cfg):
|
461 |
+
raise NotImplementedError
|
462 |
+
|
463 |
+
def __call__(self, batch):
|
464 |
+
raise NotImplementedError
|
models/base/base_inference.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import time
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from models.vocoders.vocoder_inference import synthesis
|
17 |
+
from torch.utils.data import DataLoader
|
18 |
+
from utils.util import set_all_random_seed
|
19 |
+
from utils.util import load_config
|
20 |
+
|
21 |
+
|
22 |
+
def parse_vocoder(vocoder_dir):
|
23 |
+
r"""Parse vocoder config"""
|
24 |
+
vocoder_dir = os.path.abspath(vocoder_dir)
|
25 |
+
ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
|
26 |
+
ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
|
27 |
+
ckpt_path = str(ckpt_list[0])
|
28 |
+
vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True)
|
29 |
+
vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder
|
30 |
+
return vocoder_cfg, ckpt_path
|
31 |
+
|
32 |
+
|
33 |
+
class BaseInference(object):
|
34 |
+
def __init__(self, cfg, args):
|
35 |
+
self.cfg = cfg
|
36 |
+
self.args = args
|
37 |
+
self.model_type = cfg.model_type
|
38 |
+
self.avg_rtf = list()
|
39 |
+
set_all_random_seed(10086)
|
40 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
41 |
+
|
42 |
+
if torch.cuda.is_available():
|
43 |
+
self.device = torch.device("cuda")
|
44 |
+
else:
|
45 |
+
self.device = torch.device("cpu")
|
46 |
+
torch.set_num_threads(10) # inference on 1 core cpu.
|
47 |
+
|
48 |
+
# Load acoustic model
|
49 |
+
self.model = self.create_model().to(self.device)
|
50 |
+
state_dict = self.load_state_dict()
|
51 |
+
self.load_model(state_dict)
|
52 |
+
self.model.eval()
|
53 |
+
|
54 |
+
# Load vocoder model if necessary
|
55 |
+
if self.args.checkpoint_dir_vocoder is not None:
|
56 |
+
self.get_vocoder_info()
|
57 |
+
|
58 |
+
def create_model(self):
|
59 |
+
raise NotImplementedError
|
60 |
+
|
61 |
+
def load_state_dict(self):
|
62 |
+
self.checkpoint_file = self.args.checkpoint_file
|
63 |
+
if self.checkpoint_file is None:
|
64 |
+
assert self.args.checkpoint_dir is not None
|
65 |
+
checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint")
|
66 |
+
checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
|
67 |
+
self.checkpoint_file = os.path.join(
|
68 |
+
self.args.checkpoint_dir, checkpoint_filename
|
69 |
+
)
|
70 |
+
|
71 |
+
self.checkpoint_dir = os.path.split(self.checkpoint_file)[0]
|
72 |
+
|
73 |
+
print("Restore acoustic model from {}".format(self.checkpoint_file))
|
74 |
+
raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device)
|
75 |
+
self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0]
|
76 |
+
|
77 |
+
return raw_state_dict
|
78 |
+
|
79 |
+
def load_model(self, model):
|
80 |
+
raise NotImplementedError
|
81 |
+
|
82 |
+
def get_vocoder_info(self):
|
83 |
+
self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder
|
84 |
+
self.vocoder_cfg = os.path.join(
|
85 |
+
os.path.dirname(self.checkpoint_dir_vocoder), "args.json"
|
86 |
+
)
|
87 |
+
self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True)
|
88 |
+
self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1]
|
89 |
+
self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0]
|
90 |
+
|
91 |
+
def build_test_utt_data(self):
|
92 |
+
raise NotImplementedError
|
93 |
+
|
94 |
+
def build_testdata_loader(self, args, target_speaker=None):
|
95 |
+
datasets, collate = self.build_test_dataset()
|
96 |
+
self.test_dataset = datasets(self.cfg, args, target_speaker)
|
97 |
+
self.test_collate = collate(self.cfg)
|
98 |
+
self.test_batch_size = min(
|
99 |
+
self.cfg.train.batch_size, len(self.test_dataset.metadata)
|
100 |
+
)
|
101 |
+
test_loader = DataLoader(
|
102 |
+
self.test_dataset,
|
103 |
+
collate_fn=self.test_collate,
|
104 |
+
num_workers=self.args.num_workers,
|
105 |
+
batch_size=self.test_batch_size,
|
106 |
+
shuffle=False,
|
107 |
+
)
|
108 |
+
return test_loader
|
109 |
+
|
110 |
+
def inference_each_batch(self, batch_data):
|
111 |
+
raise NotImplementedError
|
112 |
+
|
113 |
+
def inference_for_batches(self, args, target_speaker=None):
|
114 |
+
###### Construct test_batch ######
|
115 |
+
loader = self.build_testdata_loader(args, target_speaker)
|
116 |
+
|
117 |
+
n_batch = len(loader)
|
118 |
+
now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
|
119 |
+
print(
|
120 |
+
"Model eval time: {}, batch_size = {}, n_batch = {}".format(
|
121 |
+
now, self.test_batch_size, n_batch
|
122 |
+
)
|
123 |
+
)
|
124 |
+
self.model.eval()
|
125 |
+
|
126 |
+
###### Inference for each batch ######
|
127 |
+
pred_res = []
|
128 |
+
with torch.no_grad():
|
129 |
+
for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)):
|
130 |
+
# Put the data to device
|
131 |
+
for k, v in batch_data.items():
|
132 |
+
batch_data[k] = batch_data[k].to(self.device)
|
133 |
+
|
134 |
+
y_pred, stats = self.inference_each_batch(batch_data)
|
135 |
+
|
136 |
+
pred_res += y_pred
|
137 |
+
|
138 |
+
return pred_res
|
139 |
+
|
140 |
+
def inference(self, feature):
|
141 |
+
raise NotImplementedError
|
142 |
+
|
143 |
+
def synthesis_by_vocoder(self, pred):
|
144 |
+
audios_pred = synthesis(
|
145 |
+
self.vocoder_cfg,
|
146 |
+
self.checkpoint_dir_vocoder,
|
147 |
+
len(pred),
|
148 |
+
pred,
|
149 |
+
)
|
150 |
+
return audios_pred
|
151 |
+
|
152 |
+
def __call__(self, utt):
|
153 |
+
feature = self.build_test_utt_data(utt)
|
154 |
+
start_time = time.time()
|
155 |
+
with torch.no_grad():
|
156 |
+
outputs = self.inference(feature)[0]
|
157 |
+
time_used = time.time() - start_time
|
158 |
+
rtf = time_used / (
|
159 |
+
outputs.shape[1]
|
160 |
+
* self.cfg.preprocess.hop_size
|
161 |
+
/ self.cfg.preprocess.sample_rate
|
162 |
+
)
|
163 |
+
print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf))
|
164 |
+
self.avg_rtf.append(rtf)
|
165 |
+
audios = outputs.cpu().squeeze().numpy().reshape(-1, 1)
|
166 |
+
return audios
|
167 |
+
|
168 |
+
|
169 |
+
def base_parser():
|
170 |
+
parser = argparse.ArgumentParser()
|
171 |
+
parser.add_argument(
|
172 |
+
"--config", default="config.json", help="json files for configurations."
|
173 |
+
)
|
174 |
+
parser.add_argument("--use_ddp_inference", default=False)
|
175 |
+
parser.add_argument("--n_workers", default=1, type=int)
|
176 |
+
parser.add_argument("--local_rank", default=-1, type=int)
|
177 |
+
parser.add_argument(
|
178 |
+
"--batch_size", default=1, type=int, help="Batch size for inference"
|
179 |
+
)
|
180 |
+
parser.add_argument(
|
181 |
+
"--num_workers",
|
182 |
+
default=1,
|
183 |
+
type=int,
|
184 |
+
help="Worker number for inference dataloader",
|
185 |
+
)
|
186 |
+
parser.add_argument(
|
187 |
+
"--checkpoint_dir",
|
188 |
+
type=str,
|
189 |
+
default=None,
|
190 |
+
help="Checkpoint dir including model file and configuration",
|
191 |
+
)
|
192 |
+
parser.add_argument(
|
193 |
+
"--checkpoint_file", help="checkpoint file", type=str, default=None
|
194 |
+
)
|
195 |
+
parser.add_argument(
|
196 |
+
"--test_list", help="test utterance list for testing", type=str, default=None
|
197 |
+
)
|
198 |
+
parser.add_argument(
|
199 |
+
"--checkpoint_dir_vocoder",
|
200 |
+
help="Vocoder's checkpoint dir including model file and configuration",
|
201 |
+
type=str,
|
202 |
+
default=None,
|
203 |
+
)
|
204 |
+
parser.add_argument(
|
205 |
+
"--output_dir",
|
206 |
+
type=str,
|
207 |
+
default=None,
|
208 |
+
help="Output dir for saving generated results",
|
209 |
+
)
|
210 |
+
return parser
|
211 |
+
|
212 |
+
|
213 |
+
if __name__ == "__main__":
|
214 |
+
parser = base_parser()
|
215 |
+
args = parser.parse_args()
|
216 |
+
cfg = load_config(args.config)
|
217 |
+
|
218 |
+
# Build inference
|
219 |
+
inference = BaseInference(cfg, args)
|
220 |
+
inference()
|
models/base/base_sampler.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
|
9 |
+
from torch.utils.data import ConcatDataset, Dataset
|
10 |
+
from torch.utils.data.sampler import (
|
11 |
+
BatchSampler,
|
12 |
+
RandomSampler,
|
13 |
+
Sampler,
|
14 |
+
SequentialSampler,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
class ScheduledSampler(Sampler):
|
19 |
+
"""A sampler that samples data from a given concat-dataset.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
|
23 |
+
batch_size (int): batch size
|
24 |
+
holistic_shuffle (bool): whether to shuffle the whole dataset or not
|
25 |
+
logger (logging.Logger): logger to print warning message
|
26 |
+
|
27 |
+
Usage:
|
28 |
+
For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
|
29 |
+
>>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]])))
|
30 |
+
[3, 4, 5, 0, 1, 2, 6, 7, 8]
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
concat_dataset,
|
36 |
+
batch_size,
|
37 |
+
holistic_shuffle,
|
38 |
+
logger=None,
|
39 |
+
loader_type="train",
|
40 |
+
):
|
41 |
+
if not isinstance(concat_dataset, ConcatDataset):
|
42 |
+
raise ValueError(
|
43 |
+
"concat_dataset must be an instance of ConcatDataset, but got {}".format(
|
44 |
+
type(concat_dataset)
|
45 |
+
)
|
46 |
+
)
|
47 |
+
if not isinstance(batch_size, int):
|
48 |
+
raise ValueError(
|
49 |
+
"batch_size must be an integer, but got {}".format(type(batch_size))
|
50 |
+
)
|
51 |
+
if not isinstance(holistic_shuffle, bool):
|
52 |
+
raise ValueError(
|
53 |
+
"holistic_shuffle must be a boolean, but got {}".format(
|
54 |
+
type(holistic_shuffle)
|
55 |
+
)
|
56 |
+
)
|
57 |
+
|
58 |
+
self.concat_dataset = concat_dataset
|
59 |
+
self.batch_size = batch_size
|
60 |
+
self.holistic_shuffle = holistic_shuffle
|
61 |
+
|
62 |
+
affected_dataset_name = []
|
63 |
+
affected_dataset_len = []
|
64 |
+
for dataset in concat_dataset.datasets:
|
65 |
+
dataset_len = len(dataset)
|
66 |
+
dataset_name = dataset.get_dataset_name()
|
67 |
+
if dataset_len < batch_size:
|
68 |
+
affected_dataset_name.append(dataset_name)
|
69 |
+
affected_dataset_len.append(dataset_len)
|
70 |
+
|
71 |
+
self.type = loader_type
|
72 |
+
for dataset_name, dataset_len in zip(
|
73 |
+
affected_dataset_name, affected_dataset_len
|
74 |
+
):
|
75 |
+
if not loader_type == "valid":
|
76 |
+
logger.warning(
|
77 |
+
"The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
|
78 |
+
loader_type, dataset_name, dataset_len, batch_size
|
79 |
+
)
|
80 |
+
)
|
81 |
+
|
82 |
+
def __len__(self):
|
83 |
+
# the number of batches with drop last
|
84 |
+
num_of_batches = sum(
|
85 |
+
[
|
86 |
+
math.floor(len(dataset) / self.batch_size)
|
87 |
+
for dataset in self.concat_dataset.datasets
|
88 |
+
]
|
89 |
+
)
|
90 |
+
# if samples are not enough for one batch, we don't drop last
|
91 |
+
if self.type == "valid" and num_of_batches < 1:
|
92 |
+
return len(self.concat_dataset)
|
93 |
+
return num_of_batches * self.batch_size
|
94 |
+
|
95 |
+
def __iter__(self):
|
96 |
+
iters = []
|
97 |
+
for dataset in self.concat_dataset.datasets:
|
98 |
+
iters.append(
|
99 |
+
SequentialSampler(dataset).__iter__()
|
100 |
+
if not self.holistic_shuffle
|
101 |
+
else RandomSampler(dataset).__iter__()
|
102 |
+
)
|
103 |
+
# e.g. [0, 200, 400]
|
104 |
+
init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
|
105 |
+
output_batches = []
|
106 |
+
for dataset_idx in range(len(self.concat_dataset.datasets)):
|
107 |
+
cur_batch = []
|
108 |
+
for idx in iters[dataset_idx]:
|
109 |
+
cur_batch.append(idx + init_indices[dataset_idx])
|
110 |
+
if len(cur_batch) == self.batch_size:
|
111 |
+
output_batches.append(cur_batch)
|
112 |
+
cur_batch = []
|
113 |
+
# if loader_type is valid, we don't need to drop last
|
114 |
+
if self.type == "valid" and len(cur_batch) > 0:
|
115 |
+
output_batches.append(cur_batch)
|
116 |
+
|
117 |
+
# force drop last in training
|
118 |
+
random.shuffle(output_batches)
|
119 |
+
output_indices = [item for sublist in output_batches for item in sublist]
|
120 |
+
return iter(output_indices)
|
121 |
+
|
122 |
+
|
123 |
+
def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type):
|
124 |
+
sampler = ScheduledSampler(
|
125 |
+
concat_dataset,
|
126 |
+
cfg.train.batch_size,
|
127 |
+
cfg.train.sampler.holistic_shuffle,
|
128 |
+
logger,
|
129 |
+
loader_type,
|
130 |
+
)
|
131 |
+
batch_sampler = BatchSampler(
|
132 |
+
sampler,
|
133 |
+
cfg.train.batch_size,
|
134 |
+
cfg.train.sampler.drop_last if not loader_type == "valid" else False,
|
135 |
+
)
|
136 |
+
return sampler, batch_sampler
|
137 |
+
|
138 |
+
|
139 |
+
class VariableSampler(BatchSampler):
|
140 |
+
def __init__(self, sampler, drop_last: bool, use_random_sampler=False):
|
141 |
+
self.data_list = sampler
|
142 |
+
if use_random_sampler:
|
143 |
+
self.sampler = RandomSampler(sampler)
|
144 |
+
else:
|
145 |
+
self.sampler = SequentialSampler(sampler)
|
146 |
+
|
147 |
+
super().__init__(self.sampler, 1, drop_last)
|
148 |
+
|
149 |
+
def __iter__(self):
|
150 |
+
for batch_ids in self.data_list:
|
151 |
+
yield batch_ids
|
152 |
+
|
153 |
+
def __len__(self):
|
154 |
+
if self.drop_last:
|
155 |
+
return len(self.sampler) // self.batch_size
|
156 |
+
else:
|
157 |
+
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
|
models/base/base_trainer.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import collections
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import time
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.distributed as dist
|
14 |
+
from torch.nn.parallel import DistributedDataParallel
|
15 |
+
from torch.utils.data import ConcatDataset, DataLoader
|
16 |
+
from torch.utils.tensorboard import SummaryWriter
|
17 |
+
|
18 |
+
from models.base.base_sampler import BatchSampler
|
19 |
+
from utils.util import (
|
20 |
+
Logger,
|
21 |
+
remove_older_ckpt,
|
22 |
+
save_config,
|
23 |
+
set_all_random_seed,
|
24 |
+
ValueWindow,
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
class BaseTrainer(object):
|
29 |
+
def __init__(self, args, cfg):
|
30 |
+
self.args = args
|
31 |
+
self.log_dir = args.log_dir
|
32 |
+
self.cfg = cfg
|
33 |
+
|
34 |
+
self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
|
35 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
36 |
+
if not cfg.train.ddp or args.local_rank == 0:
|
37 |
+
self.sw = SummaryWriter(os.path.join(args.log_dir, "events"))
|
38 |
+
self.logger = self.build_logger()
|
39 |
+
self.time_window = ValueWindow(50)
|
40 |
+
|
41 |
+
self.step = 0
|
42 |
+
self.epoch = -1
|
43 |
+
self.max_epochs = self.cfg.train.epochs
|
44 |
+
self.max_steps = self.cfg.train.max_steps
|
45 |
+
|
46 |
+
# set random seed & init distributed training
|
47 |
+
set_all_random_seed(self.cfg.train.random_seed)
|
48 |
+
if cfg.train.ddp:
|
49 |
+
dist.init_process_group(backend="nccl")
|
50 |
+
|
51 |
+
if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]:
|
52 |
+
self.singers = self.build_singers_lut()
|
53 |
+
|
54 |
+
# setup data_loader
|
55 |
+
self.data_loader = self.build_data_loader()
|
56 |
+
|
57 |
+
# setup model & enable distributed training
|
58 |
+
self.model = self.build_model()
|
59 |
+
print(self.model)
|
60 |
+
|
61 |
+
if isinstance(self.model, dict):
|
62 |
+
for key, value in self.model.items():
|
63 |
+
value.cuda(self.args.local_rank)
|
64 |
+
if key == "PQMF":
|
65 |
+
continue
|
66 |
+
if cfg.train.ddp:
|
67 |
+
self.model[key] = DistributedDataParallel(
|
68 |
+
value, device_ids=[self.args.local_rank]
|
69 |
+
)
|
70 |
+
else:
|
71 |
+
self.model.cuda(self.args.local_rank)
|
72 |
+
if cfg.train.ddp:
|
73 |
+
self.model = DistributedDataParallel(
|
74 |
+
self.model, device_ids=[self.args.local_rank]
|
75 |
+
)
|
76 |
+
|
77 |
+
# create criterion
|
78 |
+
self.criterion = self.build_criterion()
|
79 |
+
if isinstance(self.criterion, dict):
|
80 |
+
for key, value in self.criterion.items():
|
81 |
+
self.criterion[key].cuda(args.local_rank)
|
82 |
+
else:
|
83 |
+
self.criterion.cuda(self.args.local_rank)
|
84 |
+
|
85 |
+
# optimizer
|
86 |
+
self.optimizer = self.build_optimizer()
|
87 |
+
self.scheduler = self.build_scheduler()
|
88 |
+
|
89 |
+
# save config file
|
90 |
+
self.config_save_path = os.path.join(self.checkpoint_dir, "args.json")
|
91 |
+
|
92 |
+
def build_logger(self):
|
93 |
+
log_file = os.path.join(self.checkpoint_dir, "train.log")
|
94 |
+
logger = Logger(log_file, level=self.args.log_level).logger
|
95 |
+
|
96 |
+
return logger
|
97 |
+
|
98 |
+
def build_dataset(self):
|
99 |
+
raise NotImplementedError
|
100 |
+
|
101 |
+
def build_data_loader(self):
|
102 |
+
Dataset, Collator = self.build_dataset()
|
103 |
+
# build dataset instance for each dataset and combine them by ConcatDataset
|
104 |
+
datasets_list = []
|
105 |
+
for dataset in self.cfg.dataset:
|
106 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=False)
|
107 |
+
datasets_list.append(subdataset)
|
108 |
+
train_dataset = ConcatDataset(datasets_list)
|
109 |
+
|
110 |
+
train_collate = Collator(self.cfg)
|
111 |
+
# TODO: multi-GPU training
|
112 |
+
if self.cfg.train.ddp:
|
113 |
+
raise NotImplementedError("DDP is not supported yet.")
|
114 |
+
|
115 |
+
# sampler will provide indices to batch_sampler, which will perform batching and yield batch indices
|
116 |
+
batch_sampler = BatchSampler(
|
117 |
+
cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list
|
118 |
+
)
|
119 |
+
|
120 |
+
# use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
|
121 |
+
train_loader = DataLoader(
|
122 |
+
train_dataset,
|
123 |
+
collate_fn=train_collate,
|
124 |
+
num_workers=self.args.num_workers,
|
125 |
+
batch_sampler=batch_sampler,
|
126 |
+
pin_memory=False,
|
127 |
+
)
|
128 |
+
if not self.cfg.train.ddp or self.args.local_rank == 0:
|
129 |
+
datasets_list = []
|
130 |
+
for dataset in self.cfg.dataset:
|
131 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=True)
|
132 |
+
datasets_list.append(subdataset)
|
133 |
+
valid_dataset = ConcatDataset(datasets_list)
|
134 |
+
valid_collate = Collator(self.cfg)
|
135 |
+
batch_sampler = BatchSampler(
|
136 |
+
cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list
|
137 |
+
)
|
138 |
+
valid_loader = DataLoader(
|
139 |
+
valid_dataset,
|
140 |
+
collate_fn=valid_collate,
|
141 |
+
num_workers=1,
|
142 |
+
batch_sampler=batch_sampler,
|
143 |
+
)
|
144 |
+
else:
|
145 |
+
raise NotImplementedError("DDP is not supported yet.")
|
146 |
+
# valid_loader = None
|
147 |
+
data_loader = {"train": train_loader, "valid": valid_loader}
|
148 |
+
return data_loader
|
149 |
+
|
150 |
+
def build_singers_lut(self):
|
151 |
+
# combine singers
|
152 |
+
if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)):
|
153 |
+
singers = collections.OrderedDict()
|
154 |
+
else:
|
155 |
+
with open(
|
156 |
+
os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r"
|
157 |
+
) as singer_file:
|
158 |
+
singers = json.load(singer_file)
|
159 |
+
singer_count = len(singers)
|
160 |
+
for dataset in self.cfg.dataset:
|
161 |
+
singer_lut_path = os.path.join(
|
162 |
+
self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
|
163 |
+
)
|
164 |
+
with open(singer_lut_path, "r") as singer_lut_path:
|
165 |
+
singer_lut = json.load(singer_lut_path)
|
166 |
+
for singer in singer_lut.keys():
|
167 |
+
if singer not in singers:
|
168 |
+
singers[singer] = singer_count
|
169 |
+
singer_count += 1
|
170 |
+
with open(
|
171 |
+
os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w"
|
172 |
+
) as singer_file:
|
173 |
+
json.dump(singers, singer_file, indent=4, ensure_ascii=False)
|
174 |
+
print(
|
175 |
+
"singers have been dumped to {}".format(
|
176 |
+
os.path.join(self.log_dir, self.cfg.preprocess.spk2id)
|
177 |
+
)
|
178 |
+
)
|
179 |
+
return singers
|
180 |
+
|
181 |
+
def build_model(self):
|
182 |
+
raise NotImplementedError()
|
183 |
+
|
184 |
+
def build_optimizer(self):
|
185 |
+
raise NotImplementedError
|
186 |
+
|
187 |
+
def build_scheduler(self):
|
188 |
+
raise NotImplementedError()
|
189 |
+
|
190 |
+
def build_criterion(self):
|
191 |
+
raise NotImplementedError
|
192 |
+
|
193 |
+
def get_state_dict(self):
|
194 |
+
raise NotImplementedError
|
195 |
+
|
196 |
+
def save_config_file(self):
|
197 |
+
save_config(self.config_save_path, self.cfg)
|
198 |
+
|
199 |
+
# TODO, save without module.
|
200 |
+
def save_checkpoint(self, state_dict, saved_model_path):
|
201 |
+
torch.save(state_dict, saved_model_path)
|
202 |
+
|
203 |
+
def load_checkpoint(self):
|
204 |
+
checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint")
|
205 |
+
assert os.path.exists(checkpoint_path)
|
206 |
+
checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
|
207 |
+
model_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
|
208 |
+
assert os.path.exists(model_path)
|
209 |
+
if not self.cfg.train.ddp or self.args.local_rank == 0:
|
210 |
+
self.logger.info(f"Re(store) from {model_path}")
|
211 |
+
checkpoint = torch.load(model_path, map_location="cpu")
|
212 |
+
return checkpoint
|
213 |
+
|
214 |
+
def load_model(self, checkpoint):
|
215 |
+
raise NotImplementedError
|
216 |
+
|
217 |
+
def restore(self):
|
218 |
+
checkpoint = self.load_checkpoint()
|
219 |
+
self.load_model(checkpoint)
|
220 |
+
|
221 |
+
def train_step(self, data):
|
222 |
+
raise NotImplementedError(
|
223 |
+
f"Need to implement function {sys._getframe().f_code.co_name} in "
|
224 |
+
f"your sub-class of {self.__class__.__name__}. "
|
225 |
+
)
|
226 |
+
|
227 |
+
@torch.no_grad()
|
228 |
+
def eval_step(self):
|
229 |
+
raise NotImplementedError(
|
230 |
+
f"Need to implement function {sys._getframe().f_code.co_name} in "
|
231 |
+
f"your sub-class of {self.__class__.__name__}. "
|
232 |
+
)
|
233 |
+
|
234 |
+
def write_summary(self, losses, stats):
|
235 |
+
raise NotImplementedError(
|
236 |
+
f"Need to implement function {sys._getframe().f_code.co_name} in "
|
237 |
+
f"your sub-class of {self.__class__.__name__}. "
|
238 |
+
)
|
239 |
+
|
240 |
+
def write_valid_summary(self, losses, stats):
|
241 |
+
raise NotImplementedError(
|
242 |
+
f"Need to implement function {sys._getframe().f_code.co_name} in "
|
243 |
+
f"your sub-class of {self.__class__.__name__}. "
|
244 |
+
)
|
245 |
+
|
246 |
+
def echo_log(self, losses, mode="Training"):
|
247 |
+
message = [
|
248 |
+
"{} - Epoch {} Step {}: [{:.3f} s/step]".format(
|
249 |
+
mode, self.epoch + 1, self.step, self.time_window.average
|
250 |
+
)
|
251 |
+
]
|
252 |
+
|
253 |
+
for key in sorted(losses.keys()):
|
254 |
+
if isinstance(losses[key], dict):
|
255 |
+
for k, v in losses[key].items():
|
256 |
+
message.append(
|
257 |
+
str(k).split("/")[-1] + "=" + str(round(float(v), 5))
|
258 |
+
)
|
259 |
+
else:
|
260 |
+
message.append(
|
261 |
+
str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
|
262 |
+
)
|
263 |
+
self.logger.info(", ".join(message))
|
264 |
+
|
265 |
+
def eval_epoch(self):
|
266 |
+
self.logger.info("Validation...")
|
267 |
+
valid_losses = {}
|
268 |
+
for i, batch_data in enumerate(self.data_loader["valid"]):
|
269 |
+
for k, v in batch_data.items():
|
270 |
+
if isinstance(v, torch.Tensor):
|
271 |
+
batch_data[k] = v.cuda()
|
272 |
+
valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i)
|
273 |
+
for key in valid_loss:
|
274 |
+
if key not in valid_losses:
|
275 |
+
valid_losses[key] = 0
|
276 |
+
valid_losses[key] += valid_loss[key]
|
277 |
+
|
278 |
+
# Add mel and audio to the Tensorboard
|
279 |
+
# Average loss
|
280 |
+
for key in valid_losses:
|
281 |
+
valid_losses[key] /= i + 1
|
282 |
+
self.echo_log(valid_losses, "Valid")
|
283 |
+
return valid_losses, valid_stats
|
284 |
+
|
285 |
+
def train_epoch(self):
|
286 |
+
for i, batch_data in enumerate(self.data_loader["train"]):
|
287 |
+
start_time = time.time()
|
288 |
+
# Put the data to cuda device
|
289 |
+
for k, v in batch_data.items():
|
290 |
+
if isinstance(v, torch.Tensor):
|
291 |
+
batch_data[k] = v.cuda(self.args.local_rank)
|
292 |
+
|
293 |
+
# Training step
|
294 |
+
train_losses, train_stats, total_loss = self.train_step(batch_data)
|
295 |
+
self.time_window.append(time.time() - start_time)
|
296 |
+
|
297 |
+
if self.args.local_rank == 0 or not self.cfg.train.ddp:
|
298 |
+
if self.step % self.args.stdout_interval == 0:
|
299 |
+
self.echo_log(train_losses, "Training")
|
300 |
+
|
301 |
+
if self.step % self.cfg.train.save_summary_steps == 0:
|
302 |
+
self.logger.info(f"Save summary as step {self.step}")
|
303 |
+
self.write_summary(train_losses, train_stats)
|
304 |
+
|
305 |
+
if (
|
306 |
+
self.step % self.cfg.train.save_checkpoints_steps == 0
|
307 |
+
and self.step != 0
|
308 |
+
):
|
309 |
+
saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format(
|
310 |
+
self.step, total_loss
|
311 |
+
)
|
312 |
+
saved_model_path = os.path.join(
|
313 |
+
self.checkpoint_dir, saved_model_name
|
314 |
+
)
|
315 |
+
saved_state_dict = self.get_state_dict()
|
316 |
+
self.save_checkpoint(saved_state_dict, saved_model_path)
|
317 |
+
self.save_config_file()
|
318 |
+
# keep max n models
|
319 |
+
remove_older_ckpt(
|
320 |
+
saved_model_name,
|
321 |
+
self.checkpoint_dir,
|
322 |
+
max_to_keep=self.cfg.train.keep_checkpoint_max,
|
323 |
+
)
|
324 |
+
|
325 |
+
if self.step != 0 and self.step % self.cfg.train.valid_interval == 0:
|
326 |
+
if isinstance(self.model, dict):
|
327 |
+
for key in self.model.keys():
|
328 |
+
self.model[key].eval()
|
329 |
+
else:
|
330 |
+
self.model.eval()
|
331 |
+
# Evaluate one epoch and get average loss
|
332 |
+
valid_losses, valid_stats = self.eval_epoch()
|
333 |
+
if isinstance(self.model, dict):
|
334 |
+
for key in self.model.keys():
|
335 |
+
self.model[key].train()
|
336 |
+
else:
|
337 |
+
self.model.train()
|
338 |
+
# Write validation losses to summary.
|
339 |
+
self.write_valid_summary(valid_losses, valid_stats)
|
340 |
+
self.step += 1
|
341 |
+
|
342 |
+
def train(self):
|
343 |
+
for epoch in range(max(0, self.epoch), self.max_epochs):
|
344 |
+
self.train_epoch()
|
345 |
+
self.epoch += 1
|
346 |
+
if self.step > self.max_steps:
|
347 |
+
self.logger.info("Training finished!")
|
348 |
+
break
|
models/base/new_dataset.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
from abc import abstractmethod
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import json5
|
12 |
+
import torch
|
13 |
+
import yaml
|
14 |
+
|
15 |
+
|
16 |
+
# TODO: for training and validating
|
17 |
+
class BaseDataset(torch.utils.data.Dataset):
|
18 |
+
r"""Base dataset for training and validating."""
|
19 |
+
|
20 |
+
def __init__(self, args, cfg, is_valid=False):
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
class BaseTestDataset(torch.utils.data.Dataset):
|
25 |
+
r"""Test dataset for inference."""
|
26 |
+
|
27 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
28 |
+
assert infer_type in ["from_dataset", "from_file"]
|
29 |
+
|
30 |
+
self.args = args
|
31 |
+
self.cfg = cfg
|
32 |
+
self.infer_type = infer_type
|
33 |
+
|
34 |
+
@abstractmethod
|
35 |
+
def __getitem__(self, index):
|
36 |
+
pass
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
return len(self.metadata)
|
40 |
+
|
41 |
+
def get_metadata(self):
|
42 |
+
path = Path(self.args.source)
|
43 |
+
if path.suffix == ".json" or path.suffix == ".jsonc":
|
44 |
+
metadata = json5.load(open(self.args.source, "r"))
|
45 |
+
elif path.suffix == ".yaml" or path.suffix == ".yml":
|
46 |
+
metadata = yaml.full_load(open(self.args.source, "r"))
|
47 |
+
else:
|
48 |
+
raise ValueError(f"Unsupported file type: {path.suffix}")
|
49 |
+
|
50 |
+
return metadata
|
models/base/new_inference.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import re
|
9 |
+
import time
|
10 |
+
from abc import abstractmethod
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import accelerate
|
14 |
+
import json5
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from accelerate.logging import get_logger
|
18 |
+
from torch.utils.data import DataLoader
|
19 |
+
|
20 |
+
from models.vocoders.vocoder_inference import synthesis
|
21 |
+
from utils.io import save_audio
|
22 |
+
from utils.util import load_config
|
23 |
+
from utils.audio_slicer import is_silence
|
24 |
+
|
25 |
+
EPS = 1.0e-12
|
26 |
+
|
27 |
+
|
28 |
+
class BaseInference(object):
|
29 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
start = time.monotonic_ns()
|
33 |
+
self.args = args
|
34 |
+
self.cfg = cfg
|
35 |
+
|
36 |
+
assert infer_type in ["from_dataset", "from_file"]
|
37 |
+
self.infer_type = infer_type
|
38 |
+
|
39 |
+
# init with accelerate
|
40 |
+
self.accelerator = accelerate.Accelerator()
|
41 |
+
self.accelerator.wait_for_everyone()
|
42 |
+
|
43 |
+
# Use accelerate logger for distributed inference
|
44 |
+
with self.accelerator.main_process_first():
|
45 |
+
self.logger = get_logger("inference", log_level=args.log_level)
|
46 |
+
|
47 |
+
# Log some info
|
48 |
+
self.logger.info("=" * 56)
|
49 |
+
self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
|
50 |
+
self.logger.info("=" * 56)
|
51 |
+
self.logger.info("\n")
|
52 |
+
self.logger.debug(f"Using {args.log_level.upper()} logging level.")
|
53 |
+
|
54 |
+
self.acoustics_dir = args.acoustics_dir
|
55 |
+
self.logger.debug(f"Acoustic dir: {args.acoustics_dir}")
|
56 |
+
self.vocoder_dir = args.vocoder_dir
|
57 |
+
self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
|
58 |
+
# should be in svc inferencer
|
59 |
+
# self.target_singer = args.target_singer
|
60 |
+
# self.logger.info(f"Target singers: {args.target_singer}")
|
61 |
+
# self.trans_key = args.trans_key
|
62 |
+
# self.logger.info(f"Trans key: {args.trans_key}")
|
63 |
+
|
64 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
65 |
+
|
66 |
+
# set random seed
|
67 |
+
with self.accelerator.main_process_first():
|
68 |
+
start = time.monotonic_ns()
|
69 |
+
self._set_random_seed(self.cfg.train.random_seed)
|
70 |
+
end = time.monotonic_ns()
|
71 |
+
self.logger.debug(
|
72 |
+
f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
|
73 |
+
)
|
74 |
+
self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
|
75 |
+
|
76 |
+
# setup data_loader
|
77 |
+
with self.accelerator.main_process_first():
|
78 |
+
self.logger.info("Building dataset...")
|
79 |
+
start = time.monotonic_ns()
|
80 |
+
self.test_dataloader = self._build_dataloader()
|
81 |
+
end = time.monotonic_ns()
|
82 |
+
self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
|
83 |
+
|
84 |
+
# setup model
|
85 |
+
with self.accelerator.main_process_first():
|
86 |
+
self.logger.info("Building model...")
|
87 |
+
start = time.monotonic_ns()
|
88 |
+
self.model = self._build_model()
|
89 |
+
end = time.monotonic_ns()
|
90 |
+
# self.logger.debug(self.model)
|
91 |
+
self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
|
92 |
+
|
93 |
+
# init with accelerate
|
94 |
+
self.logger.info("Initializing accelerate...")
|
95 |
+
start = time.monotonic_ns()
|
96 |
+
self.accelerator = accelerate.Accelerator()
|
97 |
+
self.model = self.accelerator.prepare(self.model)
|
98 |
+
end = time.monotonic_ns()
|
99 |
+
self.accelerator.wait_for_everyone()
|
100 |
+
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
|
101 |
+
|
102 |
+
with self.accelerator.main_process_first():
|
103 |
+
self.logger.info("Loading checkpoint...")
|
104 |
+
start = time.monotonic_ns()
|
105 |
+
# TODO: Also, suppose only use latest one yet
|
106 |
+
self.__load_model(os.path.join(args.acoustics_dir, "checkpoint"))
|
107 |
+
end = time.monotonic_ns()
|
108 |
+
self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
|
109 |
+
|
110 |
+
self.model.eval()
|
111 |
+
self.accelerator.wait_for_everyone()
|
112 |
+
|
113 |
+
### Abstract methods ###
|
114 |
+
@abstractmethod
|
115 |
+
def _build_test_dataset(self):
|
116 |
+
pass
|
117 |
+
|
118 |
+
@abstractmethod
|
119 |
+
def _build_model(self):
|
120 |
+
pass
|
121 |
+
|
122 |
+
@abstractmethod
|
123 |
+
@torch.inference_mode()
|
124 |
+
def _inference_each_batch(self, batch_data):
|
125 |
+
pass
|
126 |
+
|
127 |
+
### Abstract methods end ###
|
128 |
+
|
129 |
+
@torch.inference_mode()
|
130 |
+
def inference(self):
|
131 |
+
for i, batch in enumerate(self.test_dataloader):
|
132 |
+
y_pred = self._inference_each_batch(batch).cpu()
|
133 |
+
|
134 |
+
# Judge whether the min-max normliazation is used
|
135 |
+
if self.cfg.preprocess.use_min_max_norm_mel:
|
136 |
+
mel_min, mel_max = self.test_dataset.target_mel_extrema
|
137 |
+
y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min
|
138 |
+
|
139 |
+
y_ls = y_pred.chunk(self.test_batch_size)
|
140 |
+
tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
|
141 |
+
j = 0
|
142 |
+
for it, l in zip(y_ls, tgt_ls):
|
143 |
+
l = l.item()
|
144 |
+
it = it.squeeze(0)[:l]
|
145 |
+
uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
|
146 |
+
torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
|
147 |
+
j += 1
|
148 |
+
|
149 |
+
vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
|
150 |
+
|
151 |
+
res = synthesis(
|
152 |
+
cfg=vocoder_cfg,
|
153 |
+
vocoder_weight_file=vocoder_ckpt,
|
154 |
+
n_samples=None,
|
155 |
+
pred=[
|
156 |
+
torch.load(
|
157 |
+
os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"]))
|
158 |
+
).numpy(force=True)
|
159 |
+
for i in self.test_dataset.metadata
|
160 |
+
],
|
161 |
+
)
|
162 |
+
|
163 |
+
output_audio_files = []
|
164 |
+
for it, wav in zip(self.test_dataset.metadata, res):
|
165 |
+
uid = it["Uid"]
|
166 |
+
file = os.path.join(self.args.output_dir, f"{uid}.wav")
|
167 |
+
output_audio_files.append(file)
|
168 |
+
|
169 |
+
wav = wav.numpy(force=True)
|
170 |
+
save_audio(
|
171 |
+
file,
|
172 |
+
wav,
|
173 |
+
self.cfg.preprocess.sample_rate,
|
174 |
+
add_silence=False,
|
175 |
+
turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
|
176 |
+
)
|
177 |
+
os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))
|
178 |
+
|
179 |
+
return sorted(output_audio_files)
|
180 |
+
|
181 |
+
# TODO: LEGACY CODE
|
182 |
+
def _build_dataloader(self):
|
183 |
+
datasets, collate = self._build_test_dataset()
|
184 |
+
self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
|
185 |
+
self.test_collate = collate(self.cfg)
|
186 |
+
self.test_batch_size = min(
|
187 |
+
self.cfg.train.batch_size, len(self.test_dataset.metadata)
|
188 |
+
)
|
189 |
+
test_dataloader = DataLoader(
|
190 |
+
self.test_dataset,
|
191 |
+
collate_fn=self.test_collate,
|
192 |
+
num_workers=1,
|
193 |
+
batch_size=self.test_batch_size,
|
194 |
+
shuffle=False,
|
195 |
+
)
|
196 |
+
return test_dataloader
|
197 |
+
|
198 |
+
def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None):
|
199 |
+
r"""Load model from checkpoint. If checkpoint_path is None, it will
|
200 |
+
load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
|
201 |
+
None, it will load the checkpoint specified by checkpoint_path. **Only use this
|
202 |
+
method after** ``accelerator.prepare()``.
|
203 |
+
"""
|
204 |
+
if checkpoint_path is None:
|
205 |
+
ls = []
|
206 |
+
for i in Path(checkpoint_dir).iterdir():
|
207 |
+
if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)):
|
208 |
+
ls.append(i)
|
209 |
+
ls.sort(
|
210 |
+
key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True
|
211 |
+
)
|
212 |
+
checkpoint_path = ls[0]
|
213 |
+
else:
|
214 |
+
checkpoint_path = Path(checkpoint_path)
|
215 |
+
self.accelerator.load_state(str(checkpoint_path))
|
216 |
+
# set epoch and step
|
217 |
+
self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1])
|
218 |
+
self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1])
|
219 |
+
return str(checkpoint_path)
|
220 |
+
|
221 |
+
@staticmethod
|
222 |
+
def _set_random_seed(seed):
|
223 |
+
r"""Set random seed for all possible random modules."""
|
224 |
+
random.seed(seed)
|
225 |
+
np.random.seed(seed)
|
226 |
+
torch.random.manual_seed(seed)
|
227 |
+
|
228 |
+
@staticmethod
|
229 |
+
def _parse_vocoder(vocoder_dir):
|
230 |
+
r"""Parse vocoder config"""
|
231 |
+
vocoder_dir = os.path.abspath(vocoder_dir)
|
232 |
+
ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
|
233 |
+
ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
|
234 |
+
ckpt_path = str(ckpt_list[0])
|
235 |
+
vocoder_cfg = load_config(
|
236 |
+
os.path.join(vocoder_dir, "args.json"), lowercase=True
|
237 |
+
)
|
238 |
+
return vocoder_cfg, ckpt_path
|
239 |
+
|
240 |
+
@staticmethod
|
241 |
+
def __count_parameters(model):
|
242 |
+
return sum(p.numel() for p in model.parameters())
|
243 |
+
|
244 |
+
def __dump_cfg(self, path):
|
245 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
246 |
+
json5.dump(
|
247 |
+
self.cfg,
|
248 |
+
open(path, "w"),
|
249 |
+
indent=4,
|
250 |
+
sort_keys=True,
|
251 |
+
ensure_ascii=False,
|
252 |
+
quote_keys=True,
|
253 |
+
)
|
models/base/new_trainer.py
ADDED
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import random
|
9 |
+
import shutil
|
10 |
+
import time
|
11 |
+
from abc import abstractmethod
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
import accelerate
|
15 |
+
import json5
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
from accelerate.logging import get_logger
|
19 |
+
from accelerate.utils import ProjectConfiguration
|
20 |
+
from torch.utils.data import ConcatDataset, DataLoader
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
from models.base.base_sampler import build_samplers
|
24 |
+
from optimizer.optimizers import NoamLR
|
25 |
+
|
26 |
+
|
27 |
+
class BaseTrainer(object):
|
28 |
+
r"""The base trainer for all tasks. Any trainer should inherit from this class."""
|
29 |
+
|
30 |
+
def __init__(self, args=None, cfg=None):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
self.args = args
|
34 |
+
self.cfg = cfg
|
35 |
+
|
36 |
+
cfg.exp_name = args.exp_name
|
37 |
+
|
38 |
+
# init with accelerate
|
39 |
+
self._init_accelerator()
|
40 |
+
self.accelerator.wait_for_everyone()
|
41 |
+
|
42 |
+
# Use accelerate logger for distributed training
|
43 |
+
with self.accelerator.main_process_first():
|
44 |
+
self.logger = get_logger(args.exp_name, log_level=args.log_level)
|
45 |
+
|
46 |
+
# Log some info
|
47 |
+
self.logger.info("=" * 56)
|
48 |
+
self.logger.info("||\t\t" + "New training process started." + "\t\t||")
|
49 |
+
self.logger.info("=" * 56)
|
50 |
+
self.logger.info("\n")
|
51 |
+
self.logger.debug(f"Using {args.log_level.upper()} logging level.")
|
52 |
+
self.logger.info(f"Experiment name: {args.exp_name}")
|
53 |
+
self.logger.info(f"Experiment directory: {self.exp_dir}")
|
54 |
+
self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
|
55 |
+
if self.accelerator.is_main_process:
|
56 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
57 |
+
self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
|
58 |
+
|
59 |
+
# init counts
|
60 |
+
self.batch_count: int = 0
|
61 |
+
self.step: int = 0
|
62 |
+
self.epoch: int = 0
|
63 |
+
self.max_epoch = (
|
64 |
+
self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
|
65 |
+
)
|
66 |
+
self.logger.info(
|
67 |
+
"Max epoch: {}".format(
|
68 |
+
self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
|
69 |
+
)
|
70 |
+
)
|
71 |
+
|
72 |
+
# Check values
|
73 |
+
if self.accelerator.is_main_process:
|
74 |
+
self.__check_basic_configs()
|
75 |
+
# Set runtime configs
|
76 |
+
self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
|
77 |
+
self.checkpoints_path = [
|
78 |
+
[] for _ in range(len(self.save_checkpoint_stride))
|
79 |
+
]
|
80 |
+
self.keep_last = [
|
81 |
+
i if i > 0 else float("inf") for i in self.cfg.train.keep_last
|
82 |
+
]
|
83 |
+
self.run_eval = self.cfg.train.run_eval
|
84 |
+
|
85 |
+
# set random seed
|
86 |
+
with self.accelerator.main_process_first():
|
87 |
+
start = time.monotonic_ns()
|
88 |
+
self._set_random_seed(self.cfg.train.random_seed)
|
89 |
+
end = time.monotonic_ns()
|
90 |
+
self.logger.debug(
|
91 |
+
f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
|
92 |
+
)
|
93 |
+
self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
|
94 |
+
|
95 |
+
# setup data_loader
|
96 |
+
with self.accelerator.main_process_first():
|
97 |
+
self.logger.info("Building dataset...")
|
98 |
+
start = time.monotonic_ns()
|
99 |
+
self.train_dataloader, self.valid_dataloader = self._build_dataloader()
|
100 |
+
end = time.monotonic_ns()
|
101 |
+
self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
|
102 |
+
|
103 |
+
# setup model
|
104 |
+
with self.accelerator.main_process_first():
|
105 |
+
self.logger.info("Building model...")
|
106 |
+
start = time.monotonic_ns()
|
107 |
+
self.model = self._build_model()
|
108 |
+
end = time.monotonic_ns()
|
109 |
+
self.logger.debug(self.model)
|
110 |
+
self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
|
111 |
+
self.logger.info(
|
112 |
+
f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
|
113 |
+
)
|
114 |
+
# optimizer & scheduler
|
115 |
+
with self.accelerator.main_process_first():
|
116 |
+
self.logger.info("Building optimizer and scheduler...")
|
117 |
+
start = time.monotonic_ns()
|
118 |
+
self.optimizer = self._build_optimizer()
|
119 |
+
self.scheduler = self._build_scheduler()
|
120 |
+
end = time.monotonic_ns()
|
121 |
+
self.logger.info(
|
122 |
+
f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
|
123 |
+
)
|
124 |
+
|
125 |
+
# accelerate prepare
|
126 |
+
self.logger.info("Initializing accelerate...")
|
127 |
+
start = time.monotonic_ns()
|
128 |
+
self._accelerator_prepare()
|
129 |
+
end = time.monotonic_ns()
|
130 |
+
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
|
131 |
+
|
132 |
+
# create criterion
|
133 |
+
with self.accelerator.main_process_first():
|
134 |
+
self.logger.info("Building criterion...")
|
135 |
+
start = time.monotonic_ns()
|
136 |
+
self.criterion = self._build_criterion()
|
137 |
+
end = time.monotonic_ns()
|
138 |
+
self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
|
139 |
+
|
140 |
+
# Resume or Finetune
|
141 |
+
with self.accelerator.main_process_first():
|
142 |
+
if args.resume:
|
143 |
+
if args.resume_from_ckpt_path == "":
|
144 |
+
## Automatically resume according to the current exprimental name
|
145 |
+
self.logger.info(
|
146 |
+
"Automatically resuming from latest checkpoint in {}...".format(
|
147 |
+
self.checkpoint_dir
|
148 |
+
)
|
149 |
+
)
|
150 |
+
start = time.monotonic_ns()
|
151 |
+
ckpt_path = self._load_model(
|
152 |
+
checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type
|
153 |
+
)
|
154 |
+
end = time.monotonic_ns()
|
155 |
+
self.logger.info(
|
156 |
+
f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
|
157 |
+
)
|
158 |
+
self.checkpoints_path = json.load(
|
159 |
+
open(os.path.join(ckpt_path, "ckpts.json"), "r")
|
160 |
+
)
|
161 |
+
else:
|
162 |
+
## Resume from the given checkpoint path
|
163 |
+
if not os.path.exists(args.resume_from_ckpt_path):
|
164 |
+
raise ValueError(
|
165 |
+
"[Error] The resumed checkpoint path {} don't exist.".format(
|
166 |
+
args.resume_from_ckpt_path
|
167 |
+
)
|
168 |
+
)
|
169 |
+
self.logger.info(
|
170 |
+
"Resuming from {}...".format(args.resume_from_ckpt_path)
|
171 |
+
)
|
172 |
+
start = time.monotonic_ns()
|
173 |
+
ckpt_path = self._load_model(
|
174 |
+
checkpoint_path=args.resume_from_ckpt_path,
|
175 |
+
resume_type=args.resume_type,
|
176 |
+
)
|
177 |
+
end = time.monotonic_ns()
|
178 |
+
self.logger.info(
|
179 |
+
f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
|
180 |
+
)
|
181 |
+
|
182 |
+
# save config file path
|
183 |
+
self.config_save_path = os.path.join(self.exp_dir, "args.json")
|
184 |
+
|
185 |
+
def _accelerator_prepare(self):
|
186 |
+
(
|
187 |
+
self.train_dataloader,
|
188 |
+
self.valid_dataloader,
|
189 |
+
self.model,
|
190 |
+
self.optimizer,
|
191 |
+
self.scheduler,
|
192 |
+
) = self.accelerator.prepare(
|
193 |
+
self.train_dataloader,
|
194 |
+
self.valid_dataloader,
|
195 |
+
self.model,
|
196 |
+
self.optimizer,
|
197 |
+
self.scheduler,
|
198 |
+
)
|
199 |
+
|
200 |
+
### Following are abstract methods that should be implemented in child classes ###
|
201 |
+
@abstractmethod
|
202 |
+
def _build_dataset(self):
|
203 |
+
r"""Build dataset for model training/validating/evaluating."""
|
204 |
+
pass
|
205 |
+
|
206 |
+
@staticmethod
|
207 |
+
@abstractmethod
|
208 |
+
def _build_criterion():
|
209 |
+
r"""Build criterion function for model loss calculation."""
|
210 |
+
pass
|
211 |
+
|
212 |
+
@abstractmethod
|
213 |
+
def _build_model(self):
|
214 |
+
r"""Build model for training/validating/evaluating."""
|
215 |
+
pass
|
216 |
+
|
217 |
+
@abstractmethod
|
218 |
+
def _forward_step(self, batch):
|
219 |
+
r"""One forward step of the neural network. This abstract method is trying to
|
220 |
+
unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation.
|
221 |
+
However, for special case that using different forward step pattern for
|
222 |
+
training and validating, you could just override this method with ``pass`` and
|
223 |
+
implement ``_train_step`` and ``_valid_step`` separately.
|
224 |
+
"""
|
225 |
+
pass
|
226 |
+
|
227 |
+
@abstractmethod
|
228 |
+
def _save_auxiliary_states(self):
|
229 |
+
r"""To save some auxiliary states when saving model's ckpt"""
|
230 |
+
pass
|
231 |
+
|
232 |
+
### Abstract methods end ###
|
233 |
+
|
234 |
+
### THIS IS MAIN ENTRY ###
|
235 |
+
def train_loop(self):
|
236 |
+
r"""Training loop. The public entry of training process."""
|
237 |
+
# Wait everyone to prepare before we move on
|
238 |
+
self.accelerator.wait_for_everyone()
|
239 |
+
# dump config file
|
240 |
+
if self.accelerator.is_main_process:
|
241 |
+
self.__dump_cfg(self.config_save_path)
|
242 |
+
self.model.train()
|
243 |
+
self.optimizer.zero_grad()
|
244 |
+
# Wait to ensure good to go
|
245 |
+
self.accelerator.wait_for_everyone()
|
246 |
+
while self.epoch < self.max_epoch:
|
247 |
+
self.logger.info("\n")
|
248 |
+
self.logger.info("-" * 32)
|
249 |
+
self.logger.info("Epoch {}: ".format(self.epoch))
|
250 |
+
|
251 |
+
### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
|
252 |
+
### It's inconvenient for the model with multiple losses
|
253 |
+
# Do training & validating epoch
|
254 |
+
train_loss = self._train_epoch()
|
255 |
+
self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
|
256 |
+
valid_loss = self._valid_epoch()
|
257 |
+
self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
|
258 |
+
self.accelerator.log(
|
259 |
+
{"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
|
260 |
+
step=self.epoch,
|
261 |
+
)
|
262 |
+
|
263 |
+
self.accelerator.wait_for_everyone()
|
264 |
+
# TODO: what is scheduler?
|
265 |
+
self.scheduler.step(valid_loss) # FIXME: use epoch track correct?
|
266 |
+
|
267 |
+
# Check if hit save_checkpoint_stride and run_eval
|
268 |
+
run_eval = False
|
269 |
+
if self.accelerator.is_main_process:
|
270 |
+
save_checkpoint = False
|
271 |
+
hit_dix = []
|
272 |
+
for i, num in enumerate(self.save_checkpoint_stride):
|
273 |
+
if self.epoch % num == 0:
|
274 |
+
save_checkpoint = True
|
275 |
+
hit_dix.append(i)
|
276 |
+
run_eval |= self.run_eval[i]
|
277 |
+
|
278 |
+
self.accelerator.wait_for_everyone()
|
279 |
+
if self.accelerator.is_main_process and save_checkpoint:
|
280 |
+
path = os.path.join(
|
281 |
+
self.checkpoint_dir,
|
282 |
+
"epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
283 |
+
self.epoch, self.step, train_loss
|
284 |
+
),
|
285 |
+
)
|
286 |
+
self.tmp_checkpoint_save_path = path
|
287 |
+
self.accelerator.save_state(path)
|
288 |
+
print(f"save checkpoint in {path}")
|
289 |
+
json.dump(
|
290 |
+
self.checkpoints_path,
|
291 |
+
open(os.path.join(path, "ckpts.json"), "w"),
|
292 |
+
ensure_ascii=False,
|
293 |
+
indent=4,
|
294 |
+
)
|
295 |
+
self._save_auxiliary_states()
|
296 |
+
|
297 |
+
# Remove old checkpoints
|
298 |
+
to_remove = []
|
299 |
+
for idx in hit_dix:
|
300 |
+
self.checkpoints_path[idx].append(path)
|
301 |
+
while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
|
302 |
+
to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
|
303 |
+
|
304 |
+
# Search conflicts
|
305 |
+
total = set()
|
306 |
+
for i in self.checkpoints_path:
|
307 |
+
total |= set(i)
|
308 |
+
do_remove = set()
|
309 |
+
for idx, path in to_remove[::-1]:
|
310 |
+
if path in total:
|
311 |
+
self.checkpoints_path[idx].insert(0, path)
|
312 |
+
else:
|
313 |
+
do_remove.add(path)
|
314 |
+
|
315 |
+
# Remove old checkpoints
|
316 |
+
for path in do_remove:
|
317 |
+
shutil.rmtree(path, ignore_errors=True)
|
318 |
+
self.logger.debug(f"Remove old checkpoint: {path}")
|
319 |
+
|
320 |
+
self.accelerator.wait_for_everyone()
|
321 |
+
if run_eval:
|
322 |
+
# TODO: run evaluation
|
323 |
+
pass
|
324 |
+
|
325 |
+
# Update info for each epoch
|
326 |
+
self.epoch += 1
|
327 |
+
|
328 |
+
# Finish training and save final checkpoint
|
329 |
+
self.accelerator.wait_for_everyone()
|
330 |
+
if self.accelerator.is_main_process:
|
331 |
+
self.accelerator.save_state(
|
332 |
+
os.path.join(
|
333 |
+
self.checkpoint_dir,
|
334 |
+
"final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
335 |
+
self.epoch, self.step, valid_loss
|
336 |
+
),
|
337 |
+
)
|
338 |
+
)
|
339 |
+
self._save_auxiliary_states()
|
340 |
+
|
341 |
+
self.accelerator.end_training()
|
342 |
+
|
343 |
+
### Following are methods that can be used directly in child classes ###
|
344 |
+
def _train_epoch(self):
|
345 |
+
r"""Training epoch. Should return average loss of a batch (sample) over
|
346 |
+
one epoch. See ``train_loop`` for usage.
|
347 |
+
"""
|
348 |
+
self.model.train()
|
349 |
+
epoch_sum_loss: float = 0.0
|
350 |
+
epoch_step: int = 0
|
351 |
+
for batch in tqdm(
|
352 |
+
self.train_dataloader,
|
353 |
+
desc=f"Training Epoch {self.epoch}",
|
354 |
+
unit="batch",
|
355 |
+
colour="GREEN",
|
356 |
+
leave=False,
|
357 |
+
dynamic_ncols=True,
|
358 |
+
smoothing=0.04,
|
359 |
+
disable=not self.accelerator.is_main_process,
|
360 |
+
):
|
361 |
+
# Do training step and BP
|
362 |
+
with self.accelerator.accumulate(self.model):
|
363 |
+
loss = self._train_step(batch)
|
364 |
+
self.accelerator.backward(loss)
|
365 |
+
self.optimizer.step()
|
366 |
+
self.optimizer.zero_grad()
|
367 |
+
self.batch_count += 1
|
368 |
+
|
369 |
+
# Update info for each step
|
370 |
+
# TODO: step means BP counts or batch counts?
|
371 |
+
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
|
372 |
+
epoch_sum_loss += loss
|
373 |
+
self.accelerator.log(
|
374 |
+
{
|
375 |
+
"Step/Train Loss": loss,
|
376 |
+
"Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
|
377 |
+
},
|
378 |
+
step=self.step,
|
379 |
+
)
|
380 |
+
self.step += 1
|
381 |
+
epoch_step += 1
|
382 |
+
|
383 |
+
self.accelerator.wait_for_everyone()
|
384 |
+
return (
|
385 |
+
epoch_sum_loss
|
386 |
+
/ len(self.train_dataloader)
|
387 |
+
* self.cfg.train.gradient_accumulation_step
|
388 |
+
)
|
389 |
+
|
390 |
+
@torch.inference_mode()
|
391 |
+
def _valid_epoch(self):
|
392 |
+
r"""Testing epoch. Should return average loss of a batch (sample) over
|
393 |
+
one epoch. See ``train_loop`` for usage.
|
394 |
+
"""
|
395 |
+
self.model.eval()
|
396 |
+
epoch_sum_loss = 0.0
|
397 |
+
for batch in tqdm(
|
398 |
+
self.valid_dataloader,
|
399 |
+
desc=f"Validating Epoch {self.epoch}",
|
400 |
+
unit="batch",
|
401 |
+
colour="GREEN",
|
402 |
+
leave=False,
|
403 |
+
dynamic_ncols=True,
|
404 |
+
smoothing=0.04,
|
405 |
+
disable=not self.accelerator.is_main_process,
|
406 |
+
):
|
407 |
+
batch_loss = self._valid_step(batch)
|
408 |
+
epoch_sum_loss += batch_loss.item()
|
409 |
+
|
410 |
+
self.accelerator.wait_for_everyone()
|
411 |
+
return epoch_sum_loss / len(self.valid_dataloader)
|
412 |
+
|
413 |
+
def _train_step(self, batch):
|
414 |
+
r"""Training forward step. Should return average loss of a sample over
|
415 |
+
one batch. Provoke ``_forward_step`` is recommended except for special case.
|
416 |
+
See ``_train_epoch`` for usage.
|
417 |
+
"""
|
418 |
+
return self._forward_step(batch)
|
419 |
+
|
420 |
+
@torch.inference_mode()
|
421 |
+
def _valid_step(self, batch):
|
422 |
+
r"""Testing forward step. Should return average loss of a sample over
|
423 |
+
one batch. Provoke ``_forward_step`` is recommended except for special case.
|
424 |
+
See ``_test_epoch`` for usage.
|
425 |
+
"""
|
426 |
+
return self._forward_step(batch)
|
427 |
+
|
428 |
+
def _load_model(
|
429 |
+
self,
|
430 |
+
checkpoint_dir: str = None,
|
431 |
+
checkpoint_path: str = None,
|
432 |
+
resume_type: str = "",
|
433 |
+
):
|
434 |
+
r"""Load model from checkpoint. If checkpoint_path is None, it will
|
435 |
+
load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
|
436 |
+
None, it will load the checkpoint specified by checkpoint_path. **Only use this
|
437 |
+
method after** ``accelerator.prepare()``.
|
438 |
+
"""
|
439 |
+
if checkpoint_path is None:
|
440 |
+
ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
|
441 |
+
ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
|
442 |
+
checkpoint_path = ls[0]
|
443 |
+
self.logger.info("Resume from {}...".format(checkpoint_path))
|
444 |
+
|
445 |
+
if resume_type in ["resume", ""]:
|
446 |
+
# Load all the things, including model weights, optimizer, scheduler, and random states.
|
447 |
+
self.accelerator.load_state(input_dir=checkpoint_path)
|
448 |
+
|
449 |
+
# set epoch and step
|
450 |
+
self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
|
451 |
+
self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
|
452 |
+
|
453 |
+
elif resume_type == "finetune":
|
454 |
+
# Load only the model weights
|
455 |
+
accelerate.load_checkpoint_and_dispatch(
|
456 |
+
self.accelerator.unwrap_model(self.model),
|
457 |
+
os.path.join(checkpoint_path, "pytorch_model.bin"),
|
458 |
+
)
|
459 |
+
self.logger.info("Load model weights for finetune...")
|
460 |
+
|
461 |
+
else:
|
462 |
+
raise ValueError("Resume_type must be `resume` or `finetune`.")
|
463 |
+
|
464 |
+
return checkpoint_path
|
465 |
+
|
466 |
+
def _build_dataloader(self):
|
467 |
+
Dataset, Collator = self._build_dataset()
|
468 |
+
|
469 |
+
# build dataset instance for each dataset and combine them by ConcatDataset
|
470 |
+
datasets_list = []
|
471 |
+
for dataset in self.cfg.dataset:
|
472 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=False)
|
473 |
+
datasets_list.append(subdataset)
|
474 |
+
train_dataset = ConcatDataset(datasets_list)
|
475 |
+
train_collate = Collator(self.cfg)
|
476 |
+
_, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
|
477 |
+
self.logger.debug(f"train batch_sampler: {list(batch_sampler)}")
|
478 |
+
self.logger.debug(f"length: {train_dataset.cumulative_sizes}")
|
479 |
+
# TODO: use config instead of (sampler, shuffle, drop_last, batch_size)
|
480 |
+
train_loader = DataLoader(
|
481 |
+
train_dataset,
|
482 |
+
# shuffle=True,
|
483 |
+
collate_fn=train_collate,
|
484 |
+
batch_sampler=batch_sampler,
|
485 |
+
num_workers=self.cfg.train.dataloader.num_worker,
|
486 |
+
pin_memory=self.cfg.train.dataloader.pin_memory,
|
487 |
+
)
|
488 |
+
|
489 |
+
# Build valid dataloader
|
490 |
+
datasets_list = []
|
491 |
+
for dataset in self.cfg.dataset:
|
492 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=True)
|
493 |
+
datasets_list.append(subdataset)
|
494 |
+
valid_dataset = ConcatDataset(datasets_list)
|
495 |
+
valid_collate = Collator(self.cfg)
|
496 |
+
_, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
|
497 |
+
self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}")
|
498 |
+
self.logger.debug(f"length: {valid_dataset.cumulative_sizes}")
|
499 |
+
valid_loader = DataLoader(
|
500 |
+
valid_dataset,
|
501 |
+
collate_fn=valid_collate,
|
502 |
+
batch_sampler=batch_sampler,
|
503 |
+
num_workers=self.cfg.train.dataloader.num_worker,
|
504 |
+
pin_memory=self.cfg.train.dataloader.pin_memory,
|
505 |
+
)
|
506 |
+
return train_loader, valid_loader
|
507 |
+
|
508 |
+
@staticmethod
|
509 |
+
def _set_random_seed(seed):
|
510 |
+
r"""Set random seed for all possible random modules."""
|
511 |
+
random.seed(seed)
|
512 |
+
np.random.seed(seed)
|
513 |
+
torch.random.manual_seed(seed)
|
514 |
+
|
515 |
+
def _check_nan(self, loss, y_pred, y_gt):
|
516 |
+
if torch.any(torch.isnan(loss)):
|
517 |
+
self.logger.error("Fatal Error: Training is down since loss has Nan!")
|
518 |
+
self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
|
519 |
+
|
520 |
+
### y_pred ###
|
521 |
+
if torch.any(torch.isnan(y_pred)):
|
522 |
+
self.logger.error(
|
523 |
+
f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
|
524 |
+
)
|
525 |
+
self.logger.error(f"y_pred: {y_pred}", in_order=True)
|
526 |
+
else:
|
527 |
+
self.logger.debug(
|
528 |
+
f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
|
529 |
+
)
|
530 |
+
self.logger.debug(f"y_pred: {y_pred}", in_order=True)
|
531 |
+
|
532 |
+
### y_gt ###
|
533 |
+
if torch.any(torch.isnan(y_gt)):
|
534 |
+
self.logger.error(
|
535 |
+
f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True
|
536 |
+
)
|
537 |
+
self.logger.error(f"y_gt: {y_gt}", in_order=True)
|
538 |
+
else:
|
539 |
+
self.logger.debug(
|
540 |
+
f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True
|
541 |
+
)
|
542 |
+
self.logger.debug(f"y_gt: {y_gt}", in_order=True)
|
543 |
+
|
544 |
+
self.accelerator.end_training()
|
545 |
+
raise RuntimeError("Loss has Nan! See log for more info.")
|
546 |
+
|
547 |
+
### Protected methods end ###
|
548 |
+
|
549 |
+
## Following are private methods ##
|
550 |
+
def _build_optimizer(self):
|
551 |
+
r"""Build optimizer for model."""
|
552 |
+
# Make case-insensitive matching
|
553 |
+
if self.cfg.train.optimizer.lower() == "adadelta":
|
554 |
+
optimizer = torch.optim.Adadelta(
|
555 |
+
self.model.parameters(), **self.cfg.train.adadelta
|
556 |
+
)
|
557 |
+
self.logger.info("Using Adadelta optimizer.")
|
558 |
+
elif self.cfg.train.optimizer.lower() == "adagrad":
|
559 |
+
optimizer = torch.optim.Adagrad(
|
560 |
+
self.model.parameters(), **self.cfg.train.adagrad
|
561 |
+
)
|
562 |
+
self.logger.info("Using Adagrad optimizer.")
|
563 |
+
elif self.cfg.train.optimizer.lower() == "adam":
|
564 |
+
optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
|
565 |
+
self.logger.info("Using Adam optimizer.")
|
566 |
+
elif self.cfg.train.optimizer.lower() == "adamw":
|
567 |
+
optimizer = torch.optim.AdamW(
|
568 |
+
self.model.parameters(), **self.cfg.train.adamw
|
569 |
+
)
|
570 |
+
elif self.cfg.train.optimizer.lower() == "sparseadam":
|
571 |
+
optimizer = torch.optim.SparseAdam(
|
572 |
+
self.model.parameters(), **self.cfg.train.sparseadam
|
573 |
+
)
|
574 |
+
elif self.cfg.train.optimizer.lower() == "adamax":
|
575 |
+
optimizer = torch.optim.Adamax(
|
576 |
+
self.model.parameters(), **self.cfg.train.adamax
|
577 |
+
)
|
578 |
+
elif self.cfg.train.optimizer.lower() == "asgd":
|
579 |
+
optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd)
|
580 |
+
elif self.cfg.train.optimizer.lower() == "lbfgs":
|
581 |
+
optimizer = torch.optim.LBFGS(
|
582 |
+
self.model.parameters(), **self.cfg.train.lbfgs
|
583 |
+
)
|
584 |
+
elif self.cfg.train.optimizer.lower() == "nadam":
|
585 |
+
optimizer = torch.optim.NAdam(
|
586 |
+
self.model.parameters(), **self.cfg.train.nadam
|
587 |
+
)
|
588 |
+
elif self.cfg.train.optimizer.lower() == "radam":
|
589 |
+
optimizer = torch.optim.RAdam(
|
590 |
+
self.model.parameters(), **self.cfg.train.radam
|
591 |
+
)
|
592 |
+
elif self.cfg.train.optimizer.lower() == "rmsprop":
|
593 |
+
optimizer = torch.optim.RMSprop(
|
594 |
+
self.model.parameters(), **self.cfg.train.rmsprop
|
595 |
+
)
|
596 |
+
elif self.cfg.train.optimizer.lower() == "rprop":
|
597 |
+
optimizer = torch.optim.Rprop(
|
598 |
+
self.model.parameters(), **self.cfg.train.rprop
|
599 |
+
)
|
600 |
+
elif self.cfg.train.optimizer.lower() == "sgd":
|
601 |
+
optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd)
|
602 |
+
else:
|
603 |
+
raise NotImplementedError(
|
604 |
+
f"Optimizer {self.cfg.train.optimizer} not supported yet!"
|
605 |
+
)
|
606 |
+
return optimizer
|
607 |
+
|
608 |
+
def _build_scheduler(self):
|
609 |
+
r"""Build scheduler for optimizer."""
|
610 |
+
# Make case-insensitive matching
|
611 |
+
if self.cfg.train.scheduler.lower() == "lambdalr":
|
612 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(
|
613 |
+
self.optimizer, **self.cfg.train.lambdalr
|
614 |
+
)
|
615 |
+
elif self.cfg.train.scheduler.lower() == "multiplicativelr":
|
616 |
+
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
|
617 |
+
self.optimizer, **self.cfg.train.multiplicativelr
|
618 |
+
)
|
619 |
+
elif self.cfg.train.scheduler.lower() == "steplr":
|
620 |
+
scheduler = torch.optim.lr_scheduler.StepLR(
|
621 |
+
self.optimizer, **self.cfg.train.steplr
|
622 |
+
)
|
623 |
+
elif self.cfg.train.scheduler.lower() == "multisteplr":
|
624 |
+
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
625 |
+
self.optimizer, **self.cfg.train.multisteplr
|
626 |
+
)
|
627 |
+
elif self.cfg.train.scheduler.lower() == "constantlr":
|
628 |
+
scheduler = torch.optim.lr_scheduler.ConstantLR(
|
629 |
+
self.optimizer, **self.cfg.train.constantlr
|
630 |
+
)
|
631 |
+
elif self.cfg.train.scheduler.lower() == "linearlr":
|
632 |
+
scheduler = torch.optim.lr_scheduler.LinearLR(
|
633 |
+
self.optimizer, **self.cfg.train.linearlr
|
634 |
+
)
|
635 |
+
elif self.cfg.train.scheduler.lower() == "exponentiallr":
|
636 |
+
scheduler = torch.optim.lr_scheduler.ExponentialLR(
|
637 |
+
self.optimizer, **self.cfg.train.exponentiallr
|
638 |
+
)
|
639 |
+
elif self.cfg.train.scheduler.lower() == "polynomiallr":
|
640 |
+
scheduler = torch.optim.lr_scheduler.PolynomialLR(
|
641 |
+
self.optimizer, **self.cfg.train.polynomiallr
|
642 |
+
)
|
643 |
+
elif self.cfg.train.scheduler.lower() == "cosineannealinglr":
|
644 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
645 |
+
self.optimizer, **self.cfg.train.cosineannealinglr
|
646 |
+
)
|
647 |
+
elif self.cfg.train.scheduler.lower() == "sequentiallr":
|
648 |
+
scheduler = torch.optim.lr_scheduler.SequentialLR(
|
649 |
+
self.optimizer, **self.cfg.train.sequentiallr
|
650 |
+
)
|
651 |
+
elif self.cfg.train.scheduler.lower() == "reducelronplateau":
|
652 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
653 |
+
self.optimizer, **self.cfg.train.reducelronplateau
|
654 |
+
)
|
655 |
+
elif self.cfg.train.scheduler.lower() == "cycliclr":
|
656 |
+
scheduler = torch.optim.lr_scheduler.CyclicLR(
|
657 |
+
self.optimizer, **self.cfg.train.cycliclr
|
658 |
+
)
|
659 |
+
elif self.cfg.train.scheduler.lower() == "onecyclelr":
|
660 |
+
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
661 |
+
self.optimizer, **self.cfg.train.onecyclelr
|
662 |
+
)
|
663 |
+
elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts":
|
664 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
665 |
+
self.optimizer, **self.cfg.train.cosineannearingwarmrestarts
|
666 |
+
)
|
667 |
+
elif self.cfg.train.scheduler.lower() == "noamlr":
|
668 |
+
scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
|
669 |
+
else:
|
670 |
+
raise NotImplementedError(
|
671 |
+
f"Scheduler {self.cfg.train.scheduler} not supported yet!"
|
672 |
+
)
|
673 |
+
return scheduler
|
674 |
+
|
675 |
+
def _init_accelerator(self):
|
676 |
+
self.exp_dir = os.path.join(
|
677 |
+
os.path.abspath(self.cfg.log_dir), self.args.exp_name
|
678 |
+
)
|
679 |
+
project_config = ProjectConfiguration(
|
680 |
+
project_dir=self.exp_dir,
|
681 |
+
logging_dir=os.path.join(self.exp_dir, "log"),
|
682 |
+
)
|
683 |
+
self.accelerator = accelerate.Accelerator(
|
684 |
+
gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
|
685 |
+
log_with=self.cfg.train.tracker,
|
686 |
+
project_config=project_config,
|
687 |
+
)
|
688 |
+
if self.accelerator.is_main_process:
|
689 |
+
os.makedirs(project_config.project_dir, exist_ok=True)
|
690 |
+
os.makedirs(project_config.logging_dir, exist_ok=True)
|
691 |
+
with self.accelerator.main_process_first():
|
692 |
+
self.accelerator.init_trackers(self.args.exp_name)
|
693 |
+
|
694 |
+
def __check_basic_configs(self):
|
695 |
+
if self.cfg.train.gradient_accumulation_step <= 0:
|
696 |
+
self.logger.fatal("Invalid gradient_accumulation_step value!")
|
697 |
+
self.logger.error(
|
698 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
699 |
+
)
|
700 |
+
self.accelerator.end_training()
|
701 |
+
raise ValueError(
|
702 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
703 |
+
)
|
704 |
+
# TODO: check other values
|
705 |
+
|
706 |
+
@staticmethod
|
707 |
+
def __count_parameters(model):
|
708 |
+
model_param = 0.0
|
709 |
+
if isinstance(model, dict):
|
710 |
+
for key, value in model.items():
|
711 |
+
model_param += sum(p.numel() for p in model[key].parameters())
|
712 |
+
else:
|
713 |
+
model_param = sum(p.numel() for p in model.parameters())
|
714 |
+
return model_param
|
715 |
+
|
716 |
+
def __dump_cfg(self, path):
|
717 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
718 |
+
json5.dump(
|
719 |
+
self.cfg,
|
720 |
+
open(path, "w"),
|
721 |
+
indent=4,
|
722 |
+
sort_keys=True,
|
723 |
+
ensure_ascii=False,
|
724 |
+
quote_keys=True,
|
725 |
+
)
|
726 |
+
|
727 |
+
### Private methods end ###
|
models/codec/__init__.py
ADDED
File without changes
|
models/codec/amphion_codec/codec.py
ADDED
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from einops import rearrange
|
12 |
+
from torch.nn.utils import weight_norm
|
13 |
+
|
14 |
+
from models.codec.amphion_codec.quantize import (
|
15 |
+
ResidualVQ,
|
16 |
+
VectorQuantize,
|
17 |
+
FactorizedVectorQuantize,
|
18 |
+
LookupFreeQuantize,
|
19 |
+
)
|
20 |
+
|
21 |
+
from models.codec.amphion_codec.vocos import Vocos
|
22 |
+
|
23 |
+
|
24 |
+
def WNConv1d(*args, **kwargs):
|
25 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
26 |
+
|
27 |
+
|
28 |
+
def WNConvTranspose1d(*args, **kwargs):
|
29 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
30 |
+
|
31 |
+
|
32 |
+
# Scripting this brings model speed up 1.4x
|
33 |
+
@torch.jit.script
|
34 |
+
def snake(x, alpha):
|
35 |
+
shape = x.shape
|
36 |
+
x = x.reshape(shape[0], shape[1], -1)
|
37 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
38 |
+
x = x.reshape(shape)
|
39 |
+
return x
|
40 |
+
|
41 |
+
|
42 |
+
class Snake1d(nn.Module):
|
43 |
+
def __init__(self, channels):
|
44 |
+
super().__init__()
|
45 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
return snake(x, self.alpha)
|
49 |
+
|
50 |
+
|
51 |
+
def init_weights(m):
|
52 |
+
if isinstance(m, nn.Conv1d):
|
53 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
54 |
+
nn.init.constant_(m.bias, 0)
|
55 |
+
if isinstance(m, nn.Linear):
|
56 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
57 |
+
nn.init.constant_(m.bias, 0)
|
58 |
+
|
59 |
+
|
60 |
+
class ResidualUnit(nn.Module):
|
61 |
+
def __init__(self, dim: int = 16, dilation: int = 1):
|
62 |
+
super().__init__()
|
63 |
+
pad = ((7 - 1) * dilation) // 2
|
64 |
+
self.block = nn.Sequential(
|
65 |
+
Snake1d(dim),
|
66 |
+
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
67 |
+
Snake1d(dim),
|
68 |
+
WNConv1d(dim, dim, kernel_size=1),
|
69 |
+
)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
y = self.block(x)
|
73 |
+
pad = (x.shape[-1] - y.shape[-1]) // 2
|
74 |
+
if pad > 0:
|
75 |
+
x = x[..., pad:-pad]
|
76 |
+
return x + y
|
77 |
+
|
78 |
+
|
79 |
+
class EncoderBlock(nn.Module):
|
80 |
+
def __init__(self, dim: int = 16, stride: int = 1):
|
81 |
+
super().__init__()
|
82 |
+
self.block = nn.Sequential(
|
83 |
+
ResidualUnit(dim // 2, dilation=1),
|
84 |
+
ResidualUnit(dim // 2, dilation=3),
|
85 |
+
ResidualUnit(dim // 2, dilation=9),
|
86 |
+
Snake1d(dim // 2),
|
87 |
+
WNConv1d(
|
88 |
+
dim // 2,
|
89 |
+
dim,
|
90 |
+
kernel_size=2 * stride,
|
91 |
+
stride=stride,
|
92 |
+
padding=math.ceil(stride / 2),
|
93 |
+
),
|
94 |
+
)
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
return self.block(x)
|
98 |
+
|
99 |
+
|
100 |
+
class CodecEncoder(nn.Module):
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
d_model: int = 64,
|
104 |
+
up_ratios: list = [4, 5, 5, 6],
|
105 |
+
out_channels: int = 256,
|
106 |
+
use_tanh: bool = False,
|
107 |
+
cfg=None,
|
108 |
+
):
|
109 |
+
super().__init__()
|
110 |
+
|
111 |
+
d_model = cfg.d_model if cfg is not None else d_model
|
112 |
+
up_ratios = cfg.up_ratios if cfg is not None else up_ratios
|
113 |
+
out_channels = cfg.out_channels if cfg is not None else out_channels
|
114 |
+
use_tanh = cfg.use_tanh if cfg is not None else use_tanh
|
115 |
+
|
116 |
+
# Create first convolution
|
117 |
+
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
|
118 |
+
|
119 |
+
# Create EncoderBlocks that double channels as they downsample by `stride`
|
120 |
+
for stride in up_ratios:
|
121 |
+
d_model *= 2
|
122 |
+
self.block += [EncoderBlock(d_model, stride=stride)]
|
123 |
+
|
124 |
+
# Create last convolution
|
125 |
+
self.block += [
|
126 |
+
Snake1d(d_model),
|
127 |
+
WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
|
128 |
+
]
|
129 |
+
|
130 |
+
if use_tanh:
|
131 |
+
self.block += [nn.Tanh()]
|
132 |
+
|
133 |
+
# Wrap black into nn.Sequential
|
134 |
+
self.block = nn.Sequential(*self.block)
|
135 |
+
self.enc_dim = d_model
|
136 |
+
|
137 |
+
self.reset_parameters()
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
return self.block(x)
|
141 |
+
|
142 |
+
def reset_parameters(self):
|
143 |
+
self.apply(init_weights)
|
144 |
+
|
145 |
+
|
146 |
+
class DecoderBlock(nn.Module):
|
147 |
+
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
|
148 |
+
super().__init__()
|
149 |
+
self.block = nn.Sequential(
|
150 |
+
Snake1d(input_dim),
|
151 |
+
WNConvTranspose1d(
|
152 |
+
input_dim,
|
153 |
+
output_dim,
|
154 |
+
kernel_size=2 * stride,
|
155 |
+
stride=stride,
|
156 |
+
padding=stride // 2 + stride % 2,
|
157 |
+
output_padding=stride % 2,
|
158 |
+
),
|
159 |
+
ResidualUnit(output_dim, dilation=1),
|
160 |
+
ResidualUnit(output_dim, dilation=3),
|
161 |
+
ResidualUnit(output_dim, dilation=9),
|
162 |
+
)
|
163 |
+
|
164 |
+
def forward(self, x):
|
165 |
+
return self.block(x)
|
166 |
+
|
167 |
+
|
168 |
+
class CodecDecoder(nn.Module):
|
169 |
+
def __init__(
|
170 |
+
self,
|
171 |
+
in_channels: int = 256,
|
172 |
+
upsample_initial_channel: int = 1536,
|
173 |
+
up_ratios: list = [5, 5, 4, 2],
|
174 |
+
num_quantizers: int = 8,
|
175 |
+
codebook_size: int = 1024,
|
176 |
+
codebook_dim: int = 256,
|
177 |
+
quantizer_type: str = "vq",
|
178 |
+
quantizer_dropout: float = 0.5,
|
179 |
+
commitment: float = 0.25,
|
180 |
+
codebook_loss_weight: float = 1.0,
|
181 |
+
use_l2_normlize: bool = False,
|
182 |
+
codebook_type: str = "euclidean",
|
183 |
+
kmeans_init: bool = False,
|
184 |
+
kmeans_iters: int = 10,
|
185 |
+
decay: float = 0.8,
|
186 |
+
eps: float = 1e-5,
|
187 |
+
threshold_ema_dead_code: int = 2,
|
188 |
+
weight_init: bool = False,
|
189 |
+
use_vocos: bool = False,
|
190 |
+
vocos_dim: int = 384,
|
191 |
+
vocos_intermediate_dim: int = 1152,
|
192 |
+
vocos_num_layers: int = 8,
|
193 |
+
n_fft: int = 800,
|
194 |
+
hop_size: int = 200,
|
195 |
+
padding: str = "same",
|
196 |
+
cfg=None,
|
197 |
+
):
|
198 |
+
super().__init__()
|
199 |
+
|
200 |
+
in_channels = (
|
201 |
+
cfg.in_channels
|
202 |
+
if cfg is not None and hasattr(cfg, "in_channels")
|
203 |
+
else in_channels
|
204 |
+
)
|
205 |
+
upsample_initial_channel = (
|
206 |
+
cfg.upsample_initial_channel
|
207 |
+
if cfg is not None and hasattr(cfg, "upsample_initial_channel")
|
208 |
+
else upsample_initial_channel
|
209 |
+
)
|
210 |
+
up_ratios = (
|
211 |
+
cfg.up_ratios
|
212 |
+
if cfg is not None and hasattr(cfg, "up_ratios")
|
213 |
+
else up_ratios
|
214 |
+
)
|
215 |
+
num_quantizers = (
|
216 |
+
cfg.num_quantizers
|
217 |
+
if cfg is not None and hasattr(cfg, "num_quantizers")
|
218 |
+
else num_quantizers
|
219 |
+
)
|
220 |
+
codebook_size = (
|
221 |
+
cfg.codebook_size
|
222 |
+
if cfg is not None and hasattr(cfg, "codebook_size")
|
223 |
+
else codebook_size
|
224 |
+
)
|
225 |
+
codebook_dim = (
|
226 |
+
cfg.codebook_dim
|
227 |
+
if cfg is not None and hasattr(cfg, "codebook_dim")
|
228 |
+
else codebook_dim
|
229 |
+
)
|
230 |
+
quantizer_type = (
|
231 |
+
cfg.quantizer_type
|
232 |
+
if cfg is not None and hasattr(cfg, "quantizer_type")
|
233 |
+
else quantizer_type
|
234 |
+
)
|
235 |
+
quantizer_dropout = (
|
236 |
+
cfg.quantizer_dropout
|
237 |
+
if cfg is not None and hasattr(cfg, "quantizer_dropout")
|
238 |
+
else quantizer_dropout
|
239 |
+
)
|
240 |
+
commitment = (
|
241 |
+
cfg.commitment
|
242 |
+
if cfg is not None and hasattr(cfg, "commitment")
|
243 |
+
else commitment
|
244 |
+
)
|
245 |
+
codebook_loss_weight = (
|
246 |
+
cfg.codebook_loss_weight
|
247 |
+
if cfg is not None and hasattr(cfg, "codebook_loss_weight")
|
248 |
+
else codebook_loss_weight
|
249 |
+
)
|
250 |
+
use_l2_normlize = (
|
251 |
+
cfg.use_l2_normlize
|
252 |
+
if cfg is not None and hasattr(cfg, "use_l2_normlize")
|
253 |
+
else use_l2_normlize
|
254 |
+
)
|
255 |
+
codebook_type = (
|
256 |
+
cfg.codebook_type
|
257 |
+
if cfg is not None and hasattr(cfg, "codebook_type")
|
258 |
+
else codebook_type
|
259 |
+
)
|
260 |
+
kmeans_init = (
|
261 |
+
cfg.kmeans_init
|
262 |
+
if cfg is not None and hasattr(cfg, "kmeans_init")
|
263 |
+
else kmeans_init
|
264 |
+
)
|
265 |
+
kmeans_iters = (
|
266 |
+
cfg.kmeans_iters
|
267 |
+
if cfg is not None and hasattr(cfg, "kmeans_iters")
|
268 |
+
else kmeans_iters
|
269 |
+
)
|
270 |
+
decay = cfg.decay if cfg is not None and hasattr(cfg, "decay") else decay
|
271 |
+
eps = cfg.eps if cfg is not None and hasattr(cfg, "eps") else eps
|
272 |
+
threshold_ema_dead_code = (
|
273 |
+
cfg.threshold_ema_dead_code
|
274 |
+
if cfg is not None and hasattr(cfg, "threshold_ema_dead_code")
|
275 |
+
else threshold_ema_dead_code
|
276 |
+
)
|
277 |
+
weight_init = (
|
278 |
+
cfg.weight_init
|
279 |
+
if cfg is not None and hasattr(cfg, "weight_init")
|
280 |
+
else weight_init
|
281 |
+
)
|
282 |
+
use_vocos = (
|
283 |
+
cfg.use_vocos
|
284 |
+
if cfg is not None and hasattr(cfg, "use_vocos")
|
285 |
+
else use_vocos
|
286 |
+
)
|
287 |
+
vocos_dim = (
|
288 |
+
cfg.vocos_dim
|
289 |
+
if cfg is not None and hasattr(cfg, "vocos_dim")
|
290 |
+
else vocos_dim
|
291 |
+
)
|
292 |
+
vocos_intermediate_dim = (
|
293 |
+
cfg.vocos_intermediate_dim
|
294 |
+
if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
|
295 |
+
else vocos_intermediate_dim
|
296 |
+
)
|
297 |
+
vocos_num_layers = (
|
298 |
+
cfg.vocos_num_layers
|
299 |
+
if cfg is not None and hasattr(cfg, "vocos_num_layers")
|
300 |
+
else vocos_num_layers
|
301 |
+
)
|
302 |
+
n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
|
303 |
+
hop_size = (
|
304 |
+
cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
|
305 |
+
)
|
306 |
+
padding = (
|
307 |
+
cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
|
308 |
+
)
|
309 |
+
|
310 |
+
if quantizer_type == "vq":
|
311 |
+
self.quantizer = ResidualVQ(
|
312 |
+
input_dim=in_channels,
|
313 |
+
num_quantizers=num_quantizers,
|
314 |
+
codebook_size=codebook_size,
|
315 |
+
codebook_dim=codebook_dim,
|
316 |
+
quantizer_type=quantizer_type,
|
317 |
+
quantizer_dropout=quantizer_dropout,
|
318 |
+
commitment=commitment,
|
319 |
+
codebook_loss_weight=codebook_loss_weight,
|
320 |
+
use_l2_normlize=use_l2_normlize,
|
321 |
+
codebook_type=codebook_type,
|
322 |
+
kmeans_init=kmeans_init,
|
323 |
+
kmeans_iters=kmeans_iters,
|
324 |
+
decay=decay,
|
325 |
+
eps=eps,
|
326 |
+
threshold_ema_dead_code=threshold_ema_dead_code,
|
327 |
+
weight_init=weight_init,
|
328 |
+
)
|
329 |
+
elif quantizer_type == "fvq":
|
330 |
+
self.quantizer = ResidualVQ(
|
331 |
+
input_dim=in_channels,
|
332 |
+
num_quantizers=num_quantizers,
|
333 |
+
codebook_size=codebook_size,
|
334 |
+
codebook_dim=codebook_dim,
|
335 |
+
quantizer_type=quantizer_type,
|
336 |
+
quantizer_dropout=quantizer_dropout,
|
337 |
+
commitment=commitment,
|
338 |
+
codebook_loss_weight=codebook_loss_weight,
|
339 |
+
use_l2_normlize=use_l2_normlize,
|
340 |
+
)
|
341 |
+
elif quantizer_type == "lfq":
|
342 |
+
self.quantizer = ResidualVQ(
|
343 |
+
input_dim=in_channels,
|
344 |
+
num_quantizers=num_quantizers,
|
345 |
+
codebook_size=codebook_size,
|
346 |
+
codebook_dim=codebook_dim,
|
347 |
+
quantizer_type=quantizer_type,
|
348 |
+
)
|
349 |
+
else:
|
350 |
+
raise ValueError(f"Unknown quantizer type {quantizer_type}")
|
351 |
+
|
352 |
+
if not use_vocos:
|
353 |
+
# Add first conv layer
|
354 |
+
channels = upsample_initial_channel
|
355 |
+
layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
|
356 |
+
|
357 |
+
# Add upsampling + MRF blocks
|
358 |
+
for i, stride in enumerate(up_ratios):
|
359 |
+
input_dim = channels // 2**i
|
360 |
+
output_dim = channels // 2 ** (i + 1)
|
361 |
+
layers += [DecoderBlock(input_dim, output_dim, stride)]
|
362 |
+
|
363 |
+
# Add final conv layer
|
364 |
+
layers += [
|
365 |
+
Snake1d(output_dim),
|
366 |
+
WNConv1d(output_dim, 1, kernel_size=7, padding=3),
|
367 |
+
nn.Tanh(),
|
368 |
+
]
|
369 |
+
|
370 |
+
self.model = nn.Sequential(*layers)
|
371 |
+
|
372 |
+
if use_vocos:
|
373 |
+
self.model = Vocos(
|
374 |
+
input_channels=in_channels,
|
375 |
+
dim=vocos_dim,
|
376 |
+
intermediate_dim=vocos_intermediate_dim,
|
377 |
+
num_layers=vocos_num_layers,
|
378 |
+
adanorm_num_embeddings=None,
|
379 |
+
n_fft=n_fft,
|
380 |
+
hop_size=hop_size,
|
381 |
+
padding=padding,
|
382 |
+
)
|
383 |
+
|
384 |
+
self.reset_parameters()
|
385 |
+
|
386 |
+
def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None):
|
387 |
+
"""
|
388 |
+
if vq is True, x = encoder output, then return quantized output;
|
389 |
+
else, x = quantized output, then return decoder output
|
390 |
+
"""
|
391 |
+
if vq is True:
|
392 |
+
if eval_vq:
|
393 |
+
self.quantizer.eval()
|
394 |
+
(
|
395 |
+
quantized_out,
|
396 |
+
all_indices,
|
397 |
+
all_commit_losses,
|
398 |
+
all_codebook_losses,
|
399 |
+
all_quantized,
|
400 |
+
) = self.quantizer(x, n_quantizers=n_quantizers)
|
401 |
+
return (
|
402 |
+
quantized_out,
|
403 |
+
all_indices,
|
404 |
+
all_commit_losses,
|
405 |
+
all_codebook_losses,
|
406 |
+
all_quantized,
|
407 |
+
)
|
408 |
+
|
409 |
+
return self.model(x)
|
410 |
+
|
411 |
+
def quantize(self, x, n_quantizers=None):
|
412 |
+
self.quantizer.eval()
|
413 |
+
quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers)
|
414 |
+
return quantized_out, vq
|
415 |
+
|
416 |
+
# TODO: check consistency of vq2emb and quantize
|
417 |
+
def vq2emb(self, vq, n_quantizers=None):
|
418 |
+
return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers)
|
419 |
+
|
420 |
+
def decode(self, x):
|
421 |
+
return self.model(x)
|
422 |
+
|
423 |
+
def latent2dist(self, x, n_quantizers=None):
|
424 |
+
return self.quantizer.latent2dist(x, n_quantizers=n_quantizers)
|
425 |
+
|
426 |
+
def reset_parameters(self):
|
427 |
+
self.apply(init_weights)
|
models/codec/amphion_codec/quantize/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
|
7 |
+
FactorizedVectorQuantize,
|
8 |
+
)
|
9 |
+
from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
|
10 |
+
from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
|
11 |
+
from models.codec.amphion_codec.quantize.residual_vq import ResidualVQ
|
models/codec/amphion_codec/quantize/factorized_vector_quantize.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from einops import rearrange
|
11 |
+
from torch.nn.utils import weight_norm
|
12 |
+
|
13 |
+
|
14 |
+
def WNConv1d(*args, **kwargs):
|
15 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
16 |
+
|
17 |
+
|
18 |
+
def WNConvTranspose1d(*args, **kwargs):
|
19 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
20 |
+
|
21 |
+
|
22 |
+
class FactorizedVectorQuantize(nn.Module):
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
input_dim,
|
26 |
+
codebook_size,
|
27 |
+
codebook_dim,
|
28 |
+
commitment=0.005,
|
29 |
+
codebook_loss_weight=1.0,
|
30 |
+
use_l2_normlize=True,
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
self.input_dim = input_dim
|
34 |
+
self.codebook_size = codebook_size
|
35 |
+
self.codebook_dim = codebook_dim
|
36 |
+
self.commitment = commitment
|
37 |
+
self.codebook_loss_weight = codebook_loss_weight
|
38 |
+
self.use_l2_normlize = use_l2_normlize
|
39 |
+
|
40 |
+
if self.input_dim != self.codebook_dim:
|
41 |
+
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
|
42 |
+
self.out_project = WNConv1d(
|
43 |
+
self.codebook_dim, self.input_dim, kernel_size=1
|
44 |
+
)
|
45 |
+
|
46 |
+
else:
|
47 |
+
self.in_project = nn.Identity()
|
48 |
+
self.out_project = nn.Identity()
|
49 |
+
|
50 |
+
self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
|
51 |
+
|
52 |
+
def forward(self, z):
|
53 |
+
"""
|
54 |
+
Parameters
|
55 |
+
----------
|
56 |
+
z: torch.Tensor[B x D x T]
|
57 |
+
|
58 |
+
Returns
|
59 |
+
-------
|
60 |
+
z_q: torch.Tensor[B x D x T]
|
61 |
+
Quantized continuous representation of input
|
62 |
+
commit_loss: Tensor[B]
|
63 |
+
Commitment loss to train encoder to predict vectors closer to codebook entries
|
64 |
+
codebook_loss: Tensor[B]
|
65 |
+
Codebook loss to update the codebook
|
66 |
+
indices: torch.Tensor[B x T]
|
67 |
+
Codebook indices (quantized discrete representation of input)
|
68 |
+
z_e: torch.Tensor[B x D x T]
|
69 |
+
Projected latents (continuous representation of input before quantization)
|
70 |
+
"""
|
71 |
+
|
72 |
+
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
|
73 |
+
z_e = self.in_project(z)
|
74 |
+
z_q, indices = self.decode_latents(z_e)
|
75 |
+
|
76 |
+
# Compute commitment loss and codebook loss
|
77 |
+
if self.training:
|
78 |
+
commit_loss = (
|
79 |
+
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
80 |
+
* self.commitment
|
81 |
+
)
|
82 |
+
codebook_loss = (
|
83 |
+
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
84 |
+
* self.codebook_loss_weight
|
85 |
+
)
|
86 |
+
else:
|
87 |
+
commit_loss = torch.zeros(z.shape[0], device=z.device)
|
88 |
+
codebook_loss = torch.zeros(z.shape[0], device=z.device)
|
89 |
+
|
90 |
+
z_q = z_e + (z_q - z_e).detach()
|
91 |
+
|
92 |
+
z_q = self.out_project(z_q)
|
93 |
+
|
94 |
+
return z_q, commit_loss, codebook_loss, indices, z_e
|
95 |
+
|
96 |
+
def embed_code(self, embed_id):
|
97 |
+
return F.embedding(embed_id, self.codebook.weight)
|
98 |
+
|
99 |
+
def decode_code(self, embed_id):
|
100 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
101 |
+
|
102 |
+
def decode_latents(self, latents):
|
103 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
104 |
+
codebook = self.codebook.weight
|
105 |
+
|
106 |
+
# L2 normalize encodings and codebook
|
107 |
+
if self.use_l2_normlize:
|
108 |
+
encodings = F.normalize(encodings)
|
109 |
+
codebook = F.normalize(codebook)
|
110 |
+
|
111 |
+
# Compute euclidean distance between encodings and codebook,
|
112 |
+
# if use_l2_normlize is True, the distance is equal to cosine distance
|
113 |
+
dist = (
|
114 |
+
encodings.pow(2).sum(1, keepdim=True)
|
115 |
+
- 2 * encodings @ codebook.t()
|
116 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
117 |
+
)
|
118 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
119 |
+
z_q = self.decode_code(indices)
|
120 |
+
|
121 |
+
return z_q, indices
|
122 |
+
|
123 |
+
def vq2emb(self, vq, out_proj=True):
|
124 |
+
emb = self.decode_code(vq)
|
125 |
+
if out_proj:
|
126 |
+
emb = self.out_project(emb)
|
127 |
+
return emb
|
128 |
+
|
129 |
+
def latent2dist(self, latents):
|
130 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
131 |
+
codebook = self.codebook.weight
|
132 |
+
|
133 |
+
# L2 normalize encodings and codebook
|
134 |
+
if self.use_l2_normlize:
|
135 |
+
encodings = F.normalize(encodings)
|
136 |
+
codebook = F.normalize(codebook)
|
137 |
+
|
138 |
+
# Compute euclidean distance between encodings and codebook,
|
139 |
+
# if use_l2_normlize is True, the distance is equal to cosine distance
|
140 |
+
dist = (
|
141 |
+
encodings.pow(2).sum(1, keepdim=True)
|
142 |
+
- 2 * encodings @ codebook.t()
|
143 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
144 |
+
) # (b*t, k)
|
145 |
+
|
146 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
147 |
+
dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
|
148 |
+
z_q = self.decode_code(indices)
|
149 |
+
|
150 |
+
return -dist, indices, z_q
|
models/codec/amphion_codec/quantize/lookup_free_quantize.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from einops import rearrange
|
11 |
+
from torch.nn.utils import weight_norm
|
12 |
+
|
13 |
+
|
14 |
+
def WNConv1d(*args, **kwargs):
|
15 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
16 |
+
|
17 |
+
|
18 |
+
def WNConvTranspose1d(*args, **kwargs):
|
19 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
20 |
+
|
21 |
+
|
22 |
+
class LookupFreeQuantize(nn.Module):
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
input_dim,
|
26 |
+
codebook_size,
|
27 |
+
codebook_dim,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.input_dim = input_dim
|
31 |
+
self.codebook_size = codebook_size
|
32 |
+
self.codebook_dim = codebook_dim
|
33 |
+
|
34 |
+
assert 2**codebook_dim == codebook_size
|
35 |
+
|
36 |
+
if self.input_dim != self.codebook_dim:
|
37 |
+
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
|
38 |
+
self.out_project = WNConv1d(
|
39 |
+
self.codebook_dim, self.input_dim, kernel_size=1
|
40 |
+
)
|
41 |
+
|
42 |
+
else:
|
43 |
+
self.in_project = nn.Identity()
|
44 |
+
self.out_project = nn.Identity()
|
45 |
+
|
46 |
+
def forward(self, z):
|
47 |
+
z_e = self.in_project(z)
|
48 |
+
z_e = F.sigmoid(z_e)
|
49 |
+
|
50 |
+
z_q = z_e + (torch.round(z_e) - z_e).detach()
|
51 |
+
|
52 |
+
z_q = self.out_project(z_q)
|
53 |
+
|
54 |
+
commit_loss = torch.zeros(z.shape[0], device=z.device)
|
55 |
+
codebook_loss = torch.zeros(z.shape[0], device=z.device)
|
56 |
+
|
57 |
+
bits = (
|
58 |
+
2
|
59 |
+
** torch.arange(self.codebook_dim, device=z.device)
|
60 |
+
.unsqueeze(0)
|
61 |
+
.unsqueeze(-1)
|
62 |
+
.long()
|
63 |
+
) # (1, d, 1)
|
64 |
+
indices = (torch.round(z_e.clone().detach()).long() * bits).sum(1).long()
|
65 |
+
|
66 |
+
return z_q, commit_loss, codebook_loss, indices, z_e
|
67 |
+
|
68 |
+
def vq2emb(self, vq, out_proj=True):
|
69 |
+
emb = torch.zeros(
|
70 |
+
vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device
|
71 |
+
) # (B, d, T)
|
72 |
+
for i in range(self.codebook_dim):
|
73 |
+
emb[:, i, :] = (vq % 2).float()
|
74 |
+
vq = vq // 2
|
75 |
+
if out_proj:
|
76 |
+
emb = self.out_project(emb)
|
77 |
+
return emb
|
models/codec/amphion_codec/quantize/residual_vq.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from typing import Union
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from einops import rearrange
|
13 |
+
from torch.nn.utils import weight_norm
|
14 |
+
|
15 |
+
from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
|
16 |
+
FactorizedVectorQuantize,
|
17 |
+
)
|
18 |
+
from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
|
19 |
+
from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
|
20 |
+
|
21 |
+
|
22 |
+
class ResidualVQ(nn.Module):
|
23 |
+
"""
|
24 |
+
Introduced in SoundStream: An end2end neural audio codec
|
25 |
+
https://arxiv.org/abs/2107.03312
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
input_dim: int = 256,
|
31 |
+
num_quantizers: int = 8,
|
32 |
+
codebook_size: int = 1024,
|
33 |
+
codebook_dim: int = 256,
|
34 |
+
quantizer_type: str = "vq", # "vq" or "fvq" or "lfq"
|
35 |
+
quantizer_dropout: float = 0.5,
|
36 |
+
**kwargs,
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
self.input_dim = input_dim
|
41 |
+
self.num_quantizers = num_quantizers
|
42 |
+
self.codebook_size = codebook_size
|
43 |
+
self.codebook_dim = codebook_dim
|
44 |
+
self.quantizer_type = quantizer_type
|
45 |
+
self.quantizer_dropout = quantizer_dropout
|
46 |
+
|
47 |
+
if quantizer_type == "vq":
|
48 |
+
VQ = VectorQuantize
|
49 |
+
elif quantizer_type == "fvq":
|
50 |
+
VQ = FactorizedVectorQuantize
|
51 |
+
elif quantizer_type == "lfq":
|
52 |
+
VQ = LookupFreeQuantize
|
53 |
+
else:
|
54 |
+
raise ValueError(f"Unknown quantizer type {quantizer_type}")
|
55 |
+
|
56 |
+
self.quantizers = nn.ModuleList(
|
57 |
+
[
|
58 |
+
VQ(
|
59 |
+
input_dim=input_dim,
|
60 |
+
codebook_size=codebook_size,
|
61 |
+
codebook_dim=codebook_dim,
|
62 |
+
**kwargs,
|
63 |
+
)
|
64 |
+
for _ in range(num_quantizers)
|
65 |
+
]
|
66 |
+
)
|
67 |
+
|
68 |
+
def forward(self, z, n_quantizers: int = None):
|
69 |
+
"""
|
70 |
+
Parameters
|
71 |
+
----------
|
72 |
+
z : Tensor[B x D x T]
|
73 |
+
n_quantizers : int, optional
|
74 |
+
No. of quantizers to use
|
75 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
76 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
77 |
+
when in training mode, and a random number of quantizers is used.
|
78 |
+
Returns
|
79 |
+
-------
|
80 |
+
"quantized_out" : Tensor[B x D x T]
|
81 |
+
Quantized continuous representation of input
|
82 |
+
"all_indices" : Tensor[N x B x T]
|
83 |
+
Codebook indices for each codebook
|
84 |
+
(quantized discrete representation of input)
|
85 |
+
"all_commit_losses" : Tensor[N]
|
86 |
+
"all_codebook_losses" : Tensor[N]
|
87 |
+
"all_quantized" : Tensor[N x B x D x T]
|
88 |
+
"""
|
89 |
+
|
90 |
+
quantized_out = 0.0
|
91 |
+
residual = z
|
92 |
+
|
93 |
+
all_commit_losses = []
|
94 |
+
all_codebook_losses = []
|
95 |
+
all_indices = []
|
96 |
+
all_quantized = []
|
97 |
+
|
98 |
+
if n_quantizers is None:
|
99 |
+
n_quantizers = self.num_quantizers
|
100 |
+
|
101 |
+
if self.training:
|
102 |
+
n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
|
103 |
+
dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
|
104 |
+
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
105 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
106 |
+
n_quantizers = n_quantizers.to(z.device)
|
107 |
+
|
108 |
+
for i, quantizer in enumerate(self.quantizers):
|
109 |
+
if self.training is False and i >= n_quantizers:
|
110 |
+
break
|
111 |
+
|
112 |
+
z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
113 |
+
residual
|
114 |
+
)
|
115 |
+
|
116 |
+
# Create mask to apply quantizer dropout
|
117 |
+
mask = (
|
118 |
+
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
119 |
+
)
|
120 |
+
quantized_out = quantized_out + z_q_i * mask[:, None, None]
|
121 |
+
residual = residual - z_q_i
|
122 |
+
|
123 |
+
commit_loss_i = (commit_loss_i * mask).mean()
|
124 |
+
codebook_loss_i = (codebook_loss_i * mask).mean()
|
125 |
+
|
126 |
+
all_commit_losses.append(commit_loss_i)
|
127 |
+
all_codebook_losses.append(codebook_loss_i)
|
128 |
+
all_indices.append(indices_i)
|
129 |
+
all_quantized.append(z_q_i)
|
130 |
+
|
131 |
+
all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
|
132 |
+
torch.stack,
|
133 |
+
(all_commit_losses, all_codebook_losses, all_indices, all_quantized),
|
134 |
+
)
|
135 |
+
|
136 |
+
return (
|
137 |
+
quantized_out,
|
138 |
+
all_indices,
|
139 |
+
all_commit_losses,
|
140 |
+
all_codebook_losses,
|
141 |
+
all_quantized,
|
142 |
+
)
|
143 |
+
|
144 |
+
def vq2emb(self, vq, n_quantizers=None):
|
145 |
+
quantized_out = 0.0
|
146 |
+
if n_quantizers is None:
|
147 |
+
n_quantizers = self.num_quantizers
|
148 |
+
for idx, quantizer in enumerate(self.quantizers):
|
149 |
+
if idx >= n_quantizers:
|
150 |
+
break
|
151 |
+
quantized_out += quantizer.vq2emb(vq[idx])
|
152 |
+
return quantized_out
|
153 |
+
|
154 |
+
def latent2dist(self, z, n_quantizers=None):
|
155 |
+
quantized_out = 0.0
|
156 |
+
residual = z
|
157 |
+
|
158 |
+
all_dists = []
|
159 |
+
all_indices = []
|
160 |
+
|
161 |
+
if n_quantizers is None:
|
162 |
+
n_quantizers = self.num_quantizers
|
163 |
+
|
164 |
+
for i, quantizer in enumerate(self.quantizers):
|
165 |
+
if self.training is False and i >= n_quantizers:
|
166 |
+
break
|
167 |
+
dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
|
168 |
+
all_dists.append(dist_i)
|
169 |
+
all_indices.append(indices_i)
|
170 |
+
|
171 |
+
quantized_out = quantized_out + z_q_i
|
172 |
+
residual = residual - z_q_i
|
173 |
+
|
174 |
+
all_dists = torch.stack(all_dists)
|
175 |
+
all_indices = torch.stack(all_indices)
|
176 |
+
|
177 |
+
return all_dists, all_indices
|
models/codec/amphion_codec/quantize/vector_quantize.py
ADDED
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
from torch.nn.utils import weight_norm
|
12 |
+
|
13 |
+
|
14 |
+
def WNConv1d(*args, **kwargs):
|
15 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
16 |
+
|
17 |
+
|
18 |
+
def WNConvTranspose1d(*args, **kwargs):
|
19 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
20 |
+
|
21 |
+
|
22 |
+
def l2norm(t):
|
23 |
+
return F.normalize(t, p=2, dim=-1)
|
24 |
+
|
25 |
+
|
26 |
+
def ema_inplace(moving_avg, new, decay):
|
27 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
28 |
+
|
29 |
+
|
30 |
+
def laplace_smoothing(x, n_categories, eps=1e-5):
|
31 |
+
return (x + eps) / (x.sum() + n_categories * eps)
|
32 |
+
|
33 |
+
|
34 |
+
def sample_vectors(samples, num):
|
35 |
+
num_samples, device = samples.shape[0], samples.device
|
36 |
+
|
37 |
+
if num_samples >= num:
|
38 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
39 |
+
else:
|
40 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
41 |
+
|
42 |
+
return samples[indices]
|
43 |
+
|
44 |
+
|
45 |
+
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
|
46 |
+
dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
|
47 |
+
|
48 |
+
means = sample_vectors(samples, num_clusters)
|
49 |
+
|
50 |
+
for _ in range(num_iters):
|
51 |
+
if use_cosine_sim:
|
52 |
+
dists = samples @ means.t()
|
53 |
+
else:
|
54 |
+
diffs = rearrange(samples, "n d -> n () d") - rearrange(
|
55 |
+
means, "c d -> () c d"
|
56 |
+
)
|
57 |
+
dists = -(diffs**2).sum(dim=-1)
|
58 |
+
|
59 |
+
buckets = dists.max(dim=-1).indices
|
60 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
61 |
+
zero_mask = bins == 0
|
62 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
63 |
+
|
64 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
65 |
+
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
66 |
+
new_means = new_means / bins_min_clamped[..., None]
|
67 |
+
|
68 |
+
if use_cosine_sim:
|
69 |
+
new_means = l2norm(new_means)
|
70 |
+
|
71 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
72 |
+
|
73 |
+
return means, bins
|
74 |
+
|
75 |
+
|
76 |
+
class EuclideanCodebook(nn.Module):
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
dim,
|
80 |
+
codebook_size,
|
81 |
+
kmeans_init=False,
|
82 |
+
kmeans_iters=10,
|
83 |
+
decay=0.8,
|
84 |
+
eps=1e-5,
|
85 |
+
threshold_ema_dead_code=2,
|
86 |
+
weight_init=False,
|
87 |
+
):
|
88 |
+
super().__init__()
|
89 |
+
|
90 |
+
self.decay = decay
|
91 |
+
init_fn = torch.randn if not weight_init else torch.zeros
|
92 |
+
embed = init_fn(codebook_size, dim)
|
93 |
+
|
94 |
+
if weight_init:
|
95 |
+
nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)
|
96 |
+
|
97 |
+
self.codebook_size = codebook_size
|
98 |
+
self.kmeans_iters = kmeans_iters
|
99 |
+
self.eps = eps
|
100 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
101 |
+
|
102 |
+
self.register_buffer(
|
103 |
+
"initted", torch.Tensor([not kmeans_init])
|
104 |
+
) # if kmeans_init is True, then initted is False; otherwise, initted is True
|
105 |
+
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
106 |
+
self.register_buffer("embed", embed)
|
107 |
+
self.register_buffer("embed_avg", embed.clone())
|
108 |
+
|
109 |
+
def init_embed_(self, data):
|
110 |
+
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
111 |
+
self.embed.data.copy_(embed)
|
112 |
+
self.embed_avg.data.copy_(embed)
|
113 |
+
self.cluster_size.data.copy_(cluster_size)
|
114 |
+
self.initted.data.copy_(torch.Tensor([True]))
|
115 |
+
|
116 |
+
def replace(self, samples, mask):
|
117 |
+
modified_codebook = torch.where(
|
118 |
+
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
119 |
+
)
|
120 |
+
self.embed.data.copy_(modified_codebook)
|
121 |
+
|
122 |
+
def expire_codes_(self, batch_samples):
|
123 |
+
if self.threshold_ema_dead_code == 0:
|
124 |
+
return
|
125 |
+
|
126 |
+
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
127 |
+
if not torch.any(expired_codes):
|
128 |
+
return
|
129 |
+
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
130 |
+
self.replace(batch_samples, mask=expired_codes)
|
131 |
+
|
132 |
+
def forward(self, x):
|
133 |
+
shape, dtype = x.shape, x.dtype
|
134 |
+
flatten = rearrange(x, "... d -> (...) d")
|
135 |
+
embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
|
136 |
+
|
137 |
+
if not self.initted:
|
138 |
+
self.init_embed_(flatten)
|
139 |
+
|
140 |
+
dist = -(
|
141 |
+
flatten.pow(2).sum(1, keepdim=True)
|
142 |
+
- 2 * flatten @ embed
|
143 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
144 |
+
)
|
145 |
+
|
146 |
+
embed_ind = dist.max(dim=-1).indices
|
147 |
+
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
148 |
+
embed_ind = embed_ind.view(*shape[:-1])
|
149 |
+
quantize = F.embedding(embed_ind, self.embed)
|
150 |
+
|
151 |
+
if self.training:
|
152 |
+
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
153 |
+
embed_sum = (
|
154 |
+
flatten.t() @ embed_onehot
|
155 |
+
) # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
|
156 |
+
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
157 |
+
cluster_size = (
|
158 |
+
laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
|
159 |
+
* self.cluster_size.sum()
|
160 |
+
)
|
161 |
+
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
162 |
+
self.embed.data.copy_(embed_normalized)
|
163 |
+
self.expire_codes_(x)
|
164 |
+
|
165 |
+
return quantize, embed_ind
|
166 |
+
|
167 |
+
def vq2emb(self, vq):
|
168 |
+
quantize = F.embedding(vq, self.embed)
|
169 |
+
return quantize
|
170 |
+
|
171 |
+
def latent2dist(self, x):
|
172 |
+
shape, dtype = x.shape, x.dtype
|
173 |
+
flatten = rearrange(x, "... d -> (...) d")
|
174 |
+
embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
|
175 |
+
|
176 |
+
if not self.initted:
|
177 |
+
self.init_embed_(flatten)
|
178 |
+
|
179 |
+
dist = -(
|
180 |
+
flatten.pow(2).sum(1, keepdim=True)
|
181 |
+
- 2 * flatten @ embed
|
182 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
183 |
+
)
|
184 |
+
|
185 |
+
embed_ind = dist.max(dim=-1).indices
|
186 |
+
embed_ind = embed_ind.view(*shape[:-1])
|
187 |
+
quantize = F.embedding(embed_ind, self.embed)
|
188 |
+
|
189 |
+
dist = dist.view(*shape[:-1], -1)
|
190 |
+
|
191 |
+
return dist, embed_ind, quantize
|
192 |
+
|
193 |
+
|
194 |
+
class SimpleCodebook(nn.Module):
|
195 |
+
def __init__(
|
196 |
+
self,
|
197 |
+
dim,
|
198 |
+
codebook_size,
|
199 |
+
use_l2_normlize=False,
|
200 |
+
):
|
201 |
+
super().__init__()
|
202 |
+
|
203 |
+
self.dim = dim
|
204 |
+
self.codebook_size = codebook_size
|
205 |
+
self.use_l2_normlize = use_l2_normlize
|
206 |
+
|
207 |
+
self.embed = nn.Embedding(self.codebook_size, self.dim)
|
208 |
+
|
209 |
+
def forward(self, x):
|
210 |
+
shape, dtype = x.shape, x.dtype
|
211 |
+
flatten = rearrange(x, "... d -> (...) d")
|
212 |
+
embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
|
213 |
+
|
214 |
+
if self.use_l2_normlize:
|
215 |
+
flatten = F.normalize(flatten)
|
216 |
+
embed = F.normalize(embed)
|
217 |
+
|
218 |
+
dist = -(
|
219 |
+
flatten.pow(2).sum(1, keepdim=True)
|
220 |
+
- 2 * flatten @ embed
|
221 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
222 |
+
)
|
223 |
+
|
224 |
+
embed_ind = dist.max(dim=-1).indices
|
225 |
+
embed_ind = embed_ind.view(*shape[:-1])
|
226 |
+
quantize = F.embedding(embed_ind, self.embed)
|
227 |
+
|
228 |
+
return quantize, embed_ind
|
229 |
+
|
230 |
+
def vq2emb(self, vq):
|
231 |
+
quantize = F.embedding(vq, self.embed.weight)
|
232 |
+
return quantize
|
233 |
+
|
234 |
+
def latent2dist(self, x):
|
235 |
+
shape, dtype = x.shape, x.dtype
|
236 |
+
flatten = rearrange(x, "... d -> (...) d")
|
237 |
+
embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
|
238 |
+
|
239 |
+
if self.use_l2_normlize:
|
240 |
+
flatten = F.normalize(flatten)
|
241 |
+
embed = F.normalize(embed)
|
242 |
+
|
243 |
+
dist = -(
|
244 |
+
flatten.pow(2).sum(1, keepdim=True)
|
245 |
+
- 2 * flatten @ embed
|
246 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
247 |
+
)
|
248 |
+
|
249 |
+
embed_ind = dist.max(dim=-1).indices
|
250 |
+
embed_ind = embed_ind.view(*shape[:-1])
|
251 |
+
quantize = F.embedding(embed_ind, self.embed)
|
252 |
+
|
253 |
+
dist = dist.view(*shape[:-1], -1)
|
254 |
+
|
255 |
+
return dist, embed_ind, quantize
|
256 |
+
|
257 |
+
|
258 |
+
class VectorQuantize(nn.Module):
|
259 |
+
"""Vector quantization and factorized vecotor quantization implementation
|
260 |
+
Args:
|
261 |
+
input_dim (int): Dimension of input.
|
262 |
+
codebook_size (int): Codebook size.
|
263 |
+
codebook_dim (int): Codebook dimension. We suggest use codebook_dim = input_dim
|
264 |
+
if use codebook_type == "euclidean", otherwise, if you want to use
|
265 |
+
factorized vector quantization, use codebook_dim as small number (e.g. 8 or 32).
|
266 |
+
commitment (float): Weight for commitment loss.
|
267 |
+
use_l2_normlize (bool): Whether to use l2 normlized codes for factorized vecotor quantization,
|
268 |
+
we suggest use it as True if you want to use factorized vector quantization
|
269 |
+
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
270 |
+
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
271 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
272 |
+
epsilon (float): Epsilon value for numerical stability.
|
273 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
274 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
275 |
+
randomly selected vector from the current batch.
|
276 |
+
"""
|
277 |
+
|
278 |
+
def __init__(
|
279 |
+
self,
|
280 |
+
input_dim,
|
281 |
+
codebook_size,
|
282 |
+
codebook_dim,
|
283 |
+
commitment=0.005,
|
284 |
+
codebook_loss_weight=1.0,
|
285 |
+
use_l2_normlize=False,
|
286 |
+
codebook_type="euclidean", # "euclidean" or "simple"
|
287 |
+
kmeans_init=False,
|
288 |
+
kmeans_iters=10,
|
289 |
+
decay=0.8,
|
290 |
+
eps=1e-5,
|
291 |
+
threshold_ema_dead_code=2,
|
292 |
+
weight_init=False,
|
293 |
+
):
|
294 |
+
super().__init__()
|
295 |
+
self.input_dim = input_dim
|
296 |
+
self.codebook_size = codebook_size
|
297 |
+
self.codebook_dim = codebook_dim
|
298 |
+
self.commitment = commitment
|
299 |
+
self.codebook_loss_weight = codebook_loss_weight
|
300 |
+
self.use_l2_normlize = use_l2_normlize
|
301 |
+
self.codebook_type = codebook_type
|
302 |
+
self.kmeans_init = kmeans_init
|
303 |
+
self.kmeans_iters = kmeans_iters
|
304 |
+
self.decay = decay
|
305 |
+
self.eps = eps
|
306 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
307 |
+
self.weight_init = weight_init
|
308 |
+
|
309 |
+
if self.input_dim != self.codebook_dim:
|
310 |
+
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
|
311 |
+
self.out_project = WNConv1d(
|
312 |
+
self.codebook_dim, self.input_dim, kernel_size=1
|
313 |
+
)
|
314 |
+
|
315 |
+
else:
|
316 |
+
self.in_project = nn.Identity()
|
317 |
+
self.out_project = nn.Identity()
|
318 |
+
|
319 |
+
if self.codebook_type == "euclidean":
|
320 |
+
self.codebook = EuclideanCodebook(
|
321 |
+
self.codebook_dim,
|
322 |
+
codebook_size=self.codebook_size,
|
323 |
+
kmeans_init=self.kmeans_init,
|
324 |
+
kmeans_iters=self.kmeans_iters,
|
325 |
+
decay=self.decay,
|
326 |
+
eps=self.eps,
|
327 |
+
threshold_ema_dead_code=self.threshold_ema_dead_code,
|
328 |
+
weight_init=self.weight_init,
|
329 |
+
)
|
330 |
+
elif self.codebook_type == "simple":
|
331 |
+
self.codebook = SimpleCodebook(
|
332 |
+
self.codebook_dim,
|
333 |
+
codebook_size=self.codebook_size,
|
334 |
+
use_l2_normlize=self.use_l2_normlize,
|
335 |
+
)
|
336 |
+
else:
|
337 |
+
raise NotImplementedError(
|
338 |
+
f"codebook_type {self.codebook_type} is not implemented!"
|
339 |
+
)
|
340 |
+
|
341 |
+
def forward(self, z):
|
342 |
+
"""
|
343 |
+
Parameters
|
344 |
+
----------
|
345 |
+
z: torch.Tensor[B x D x T]
|
346 |
+
|
347 |
+
Returns
|
348 |
+
-------
|
349 |
+
z_q: torch.Tensor[B x D x T]
|
350 |
+
Quantized continuous representation of input
|
351 |
+
commit_loss: Tensor[B]
|
352 |
+
Commitment loss to train encoder to predict vectors closer to codebook entries
|
353 |
+
codebook_loss: Tensor[B]
|
354 |
+
Codebook loss to update the codebook
|
355 |
+
indices: torch.Tensor[B x T]
|
356 |
+
Codebook indices (quantized discrete representation of input)
|
357 |
+
z_e: torch.Tensor[B x D x T]
|
358 |
+
Projected latents (continuous representation of input before quantization)
|
359 |
+
"""
|
360 |
+
|
361 |
+
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
|
362 |
+
z_e = self.in_project(z)
|
363 |
+
z_q, indices = self.decode_latents(z_e)
|
364 |
+
|
365 |
+
# Compute commitment loss and codebook loss
|
366 |
+
if self.training:
|
367 |
+
commit_loss = (
|
368 |
+
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
369 |
+
* self.commitment
|
370 |
+
)
|
371 |
+
codebook_loss = (
|
372 |
+
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
373 |
+
* self.codebook_loss_weight
|
374 |
+
)
|
375 |
+
else:
|
376 |
+
commit_loss = torch.zeros(z.shape[0], device=z.device)
|
377 |
+
codebook_loss = torch.zeros(z.shape[0], device=z.device)
|
378 |
+
|
379 |
+
z_q = z_e + (z_q - z_e).detach()
|
380 |
+
|
381 |
+
z_q = self.out_project(z_q)
|
382 |
+
|
383 |
+
return z_q, commit_loss, codebook_loss, indices, z_e
|
384 |
+
|
385 |
+
def decode_latents(self, latents):
|
386 |
+
encodings = rearrange(latents, "b d t -> b t d")
|
387 |
+
z_q, indices = self.codebook(encodings)
|
388 |
+
z_q = z_q.transpose(1, 2)
|
389 |
+
return z_q, indices
|
390 |
+
|
391 |
+
def vq2emb(self, vq, out_proj=True):
|
392 |
+
emb = self.codebook.vq2emb(vq)
|
393 |
+
emb = emb.transpose(1, 2)
|
394 |
+
if out_proj:
|
395 |
+
emb = self.out_project(emb)
|
396 |
+
return emb
|
397 |
+
|
398 |
+
def latent2dist(self, latents):
|
399 |
+
latents = rearrange(latents, "b d t -> b t d")
|
400 |
+
dist, embed_ind, quantize = self.codebook.latent2dist(latents)
|
401 |
+
return dist, embed_ind, quantize.transpose(1, 2)
|
models/codec/amphion_codec/vocos.py
ADDED
@@ -0,0 +1,881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from typing import Optional, Tuple
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import scipy
|
10 |
+
import torch
|
11 |
+
from torch import nn, view_as_real, view_as_complex
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
14 |
+
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
|
15 |
+
import librosa
|
16 |
+
|
17 |
+
|
18 |
+
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
|
19 |
+
"""
|
20 |
+
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
x (Tensor): Input tensor.
|
24 |
+
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
Tensor: Element-wise logarithm of the input tensor with clipping applied.
|
28 |
+
"""
|
29 |
+
return torch.log(torch.clip(x, min=clip_val))
|
30 |
+
|
31 |
+
|
32 |
+
def symlog(x: torch.Tensor) -> torch.Tensor:
|
33 |
+
return torch.sign(x) * torch.log1p(x.abs())
|
34 |
+
|
35 |
+
|
36 |
+
def symexp(x: torch.Tensor) -> torch.Tensor:
|
37 |
+
return torch.sign(x) * (torch.exp(x.abs()) - 1)
|
38 |
+
|
39 |
+
|
40 |
+
class STFT(nn.Module):
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
n_fft: int,
|
44 |
+
hop_length: int,
|
45 |
+
win_length: int,
|
46 |
+
center=True,
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
self.center = center
|
50 |
+
self.n_fft = n_fft
|
51 |
+
self.hop_length = hop_length
|
52 |
+
self.win_length = win_length
|
53 |
+
window = torch.hann_window(win_length)
|
54 |
+
self.register_buffer("window", window)
|
55 |
+
|
56 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
57 |
+
# x: (B, T * hop_length)
|
58 |
+
|
59 |
+
if not self.center:
|
60 |
+
pad = self.win_length - self.hop_length
|
61 |
+
x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
|
62 |
+
|
63 |
+
stft_spec = torch.stft(
|
64 |
+
x,
|
65 |
+
self.n_fft,
|
66 |
+
hop_length=self.hop_length,
|
67 |
+
win_length=self.win_length,
|
68 |
+
window=self.window,
|
69 |
+
center=self.center,
|
70 |
+
return_complex=False,
|
71 |
+
) # (B, n_fft // 2 + 1, T, 2)
|
72 |
+
|
73 |
+
rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2)
|
74 |
+
imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2)
|
75 |
+
|
76 |
+
log_mag = torch.log(
|
77 |
+
torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
|
78 |
+
) # (B, n_fft // 2 + 1, T)
|
79 |
+
phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
|
80 |
+
|
81 |
+
return log_mag, phase
|
82 |
+
|
83 |
+
|
84 |
+
class ISTFT(nn.Module):
|
85 |
+
"""
|
86 |
+
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
|
87 |
+
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
|
88 |
+
See issue: https://github.com/pytorch/pytorch/issues/62323
|
89 |
+
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
|
90 |
+
The NOLA constraint is met as we trim padded samples anyway.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
n_fft (int): Size of Fourier transform.
|
94 |
+
hop_length (int): The distance between neighboring sliding window frames.
|
95 |
+
win_length (int): The size of window frame and STFT filter.
|
96 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(
|
100 |
+
self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
|
101 |
+
):
|
102 |
+
super().__init__()
|
103 |
+
if padding not in ["center", "same"]:
|
104 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
105 |
+
self.padding = padding
|
106 |
+
self.n_fft = n_fft
|
107 |
+
self.hop_length = hop_length
|
108 |
+
self.win_length = win_length
|
109 |
+
window = torch.hann_window(win_length)
|
110 |
+
self.register_buffer("window", window)
|
111 |
+
|
112 |
+
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
113 |
+
"""
|
114 |
+
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
|
118 |
+
N is the number of frequency bins, and T is the number of time frames.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
|
122 |
+
"""
|
123 |
+
if self.padding == "center":
|
124 |
+
# Fallback to pytorch native implementation
|
125 |
+
return torch.istft(
|
126 |
+
spec,
|
127 |
+
self.n_fft,
|
128 |
+
self.hop_length,
|
129 |
+
self.win_length,
|
130 |
+
self.window,
|
131 |
+
center=True,
|
132 |
+
)
|
133 |
+
elif self.padding == "same":
|
134 |
+
pad = (self.win_length - self.hop_length) // 2
|
135 |
+
else:
|
136 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
137 |
+
|
138 |
+
assert spec.dim() == 3, "Expected a 3D tensor as input"
|
139 |
+
B, N, T = spec.shape
|
140 |
+
|
141 |
+
# Inverse FFT
|
142 |
+
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
|
143 |
+
ifft = ifft * self.window[None, :, None]
|
144 |
+
|
145 |
+
# Overlap and Add
|
146 |
+
output_size = (T - 1) * self.hop_length + self.win_length
|
147 |
+
y = torch.nn.functional.fold(
|
148 |
+
ifft,
|
149 |
+
output_size=(1, output_size),
|
150 |
+
kernel_size=(1, self.win_length),
|
151 |
+
stride=(1, self.hop_length),
|
152 |
+
)[:, 0, 0, pad:-pad]
|
153 |
+
|
154 |
+
# Window envelope
|
155 |
+
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
|
156 |
+
window_envelope = torch.nn.functional.fold(
|
157 |
+
window_sq,
|
158 |
+
output_size=(1, output_size),
|
159 |
+
kernel_size=(1, self.win_length),
|
160 |
+
stride=(1, self.hop_length),
|
161 |
+
).squeeze()[pad:-pad]
|
162 |
+
|
163 |
+
# Normalize
|
164 |
+
assert (window_envelope > 1e-11).all()
|
165 |
+
y = y / window_envelope
|
166 |
+
|
167 |
+
return y
|
168 |
+
|
169 |
+
|
170 |
+
class MDCT(nn.Module):
|
171 |
+
"""
|
172 |
+
Modified Discrete Cosine Transform (MDCT) module.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
frame_len (int): Length of the MDCT frame.
|
176 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
177 |
+
"""
|
178 |
+
|
179 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
180 |
+
super().__init__()
|
181 |
+
if padding not in ["center", "same"]:
|
182 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
183 |
+
self.padding = padding
|
184 |
+
self.frame_len = frame_len
|
185 |
+
N = frame_len // 2
|
186 |
+
n0 = (N + 1) / 2
|
187 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
188 |
+
self.register_buffer("window", window)
|
189 |
+
|
190 |
+
pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
|
191 |
+
post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
|
192 |
+
# view_as_real: NCCL Backend does not support ComplexFloat data type
|
193 |
+
# https://github.com/pytorch/pytorch/issues/71613
|
194 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
195 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
196 |
+
|
197 |
+
def forward(self, audio: torch.Tensor) -> torch.Tensor:
|
198 |
+
"""
|
199 |
+
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
|
203 |
+
and T is the length of the audio.
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
|
207 |
+
and N is the number of frequency bins.
|
208 |
+
"""
|
209 |
+
if self.padding == "center":
|
210 |
+
audio = torch.nn.functional.pad(
|
211 |
+
audio, (self.frame_len // 2, self.frame_len // 2)
|
212 |
+
)
|
213 |
+
elif self.padding == "same":
|
214 |
+
# hop_length is 1/2 frame_len
|
215 |
+
audio = torch.nn.functional.pad(
|
216 |
+
audio, (self.frame_len // 4, self.frame_len // 4)
|
217 |
+
)
|
218 |
+
else:
|
219 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
220 |
+
|
221 |
+
x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
|
222 |
+
N = self.frame_len // 2
|
223 |
+
x = x * self.window.expand(x.shape)
|
224 |
+
X = torch.fft.fft(
|
225 |
+
x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
|
226 |
+
)[..., :N]
|
227 |
+
res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
|
228 |
+
return torch.real(res) * np.sqrt(2)
|
229 |
+
|
230 |
+
|
231 |
+
class IMDCT(nn.Module):
|
232 |
+
"""
|
233 |
+
Inverse Modified Discrete Cosine Transform (IMDCT) module.
|
234 |
+
|
235 |
+
Args:
|
236 |
+
frame_len (int): Length of the MDCT frame.
|
237 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
238 |
+
"""
|
239 |
+
|
240 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
241 |
+
super().__init__()
|
242 |
+
if padding not in ["center", "same"]:
|
243 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
244 |
+
self.padding = padding
|
245 |
+
self.frame_len = frame_len
|
246 |
+
N = frame_len // 2
|
247 |
+
n0 = (N + 1) / 2
|
248 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
249 |
+
self.register_buffer("window", window)
|
250 |
+
|
251 |
+
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
|
252 |
+
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
|
253 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
254 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
255 |
+
|
256 |
+
def forward(self, X: torch.Tensor) -> torch.Tensor:
|
257 |
+
"""
|
258 |
+
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
|
262 |
+
L is the number of frames, and N is the number of frequency bins.
|
263 |
+
|
264 |
+
Returns:
|
265 |
+
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
|
266 |
+
"""
|
267 |
+
B, L, N = X.shape
|
268 |
+
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
|
269 |
+
Y[..., :N] = X
|
270 |
+
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
|
271 |
+
y = torch.fft.ifft(
|
272 |
+
Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
|
273 |
+
)
|
274 |
+
y = (
|
275 |
+
torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
|
276 |
+
* np.sqrt(N)
|
277 |
+
* np.sqrt(2)
|
278 |
+
)
|
279 |
+
result = y * self.window.expand(y.shape)
|
280 |
+
output_size = (1, (L + 1) * N)
|
281 |
+
audio = torch.nn.functional.fold(
|
282 |
+
result.transpose(1, 2),
|
283 |
+
output_size=output_size,
|
284 |
+
kernel_size=(1, self.frame_len),
|
285 |
+
stride=(1, self.frame_len // 2),
|
286 |
+
)[:, 0, 0, :]
|
287 |
+
|
288 |
+
if self.padding == "center":
|
289 |
+
pad = self.frame_len // 2
|
290 |
+
elif self.padding == "same":
|
291 |
+
pad = self.frame_len // 4
|
292 |
+
else:
|
293 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
294 |
+
|
295 |
+
audio = audio[:, pad:-pad]
|
296 |
+
return audio
|
297 |
+
|
298 |
+
|
299 |
+
class FourierHead(nn.Module):
|
300 |
+
"""Base class for inverse fourier modules."""
|
301 |
+
|
302 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
303 |
+
"""
|
304 |
+
Args:
|
305 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
306 |
+
L is the sequence length, and H denotes the model dimension.
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
310 |
+
"""
|
311 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
312 |
+
|
313 |
+
|
314 |
+
class ISTFTHead(FourierHead):
|
315 |
+
"""
|
316 |
+
ISTFT Head module for predicting STFT complex coefficients.
|
317 |
+
|
318 |
+
Args:
|
319 |
+
dim (int): Hidden dimension of the model.
|
320 |
+
n_fft (int): Size of Fourier transform.
|
321 |
+
hop_length (int): The distance between neighboring sliding window frames, which should align with
|
322 |
+
the resolution of the input features.
|
323 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
324 |
+
"""
|
325 |
+
|
326 |
+
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
|
327 |
+
super().__init__()
|
328 |
+
out_dim = n_fft + 2
|
329 |
+
self.out = torch.nn.Linear(dim, out_dim)
|
330 |
+
self.istft = ISTFT(
|
331 |
+
n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
|
332 |
+
)
|
333 |
+
|
334 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
335 |
+
"""
|
336 |
+
Forward pass of the ISTFTHead module.
|
337 |
+
|
338 |
+
Args:
|
339 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
340 |
+
L is the sequence length, and H denotes the model dimension.
|
341 |
+
|
342 |
+
Returns:
|
343 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
344 |
+
"""
|
345 |
+
x = self.out(x).transpose(1, 2)
|
346 |
+
mag, p = x.chunk(2, dim=1)
|
347 |
+
mag = torch.exp(mag)
|
348 |
+
mag = torch.clip(
|
349 |
+
mag, max=1e2
|
350 |
+
) # safeguard to prevent excessively large magnitudes
|
351 |
+
# wrapping happens here. These two lines produce real and imaginary value
|
352 |
+
x = torch.cos(p)
|
353 |
+
y = torch.sin(p)
|
354 |
+
# recalculating phase here does not produce anything new
|
355 |
+
# only costs time
|
356 |
+
# phase = torch.atan2(y, x)
|
357 |
+
# S = mag * torch.exp(phase * 1j)
|
358 |
+
# better directly produce the complex value
|
359 |
+
S = mag * (x + 1j * y)
|
360 |
+
audio = self.istft(S)
|
361 |
+
return audio
|
362 |
+
|
363 |
+
|
364 |
+
class IMDCTSymExpHead(FourierHead):
|
365 |
+
"""
|
366 |
+
IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
|
367 |
+
|
368 |
+
Args:
|
369 |
+
dim (int): Hidden dimension of the model.
|
370 |
+
mdct_frame_len (int): Length of the MDCT frame.
|
371 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
372 |
+
sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
|
373 |
+
based on perceptual scaling. Defaults to None.
|
374 |
+
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
375 |
+
"""
|
376 |
+
|
377 |
+
def __init__(
|
378 |
+
self,
|
379 |
+
dim: int,
|
380 |
+
mdct_frame_len: int,
|
381 |
+
padding: str = "same",
|
382 |
+
sample_rate: Optional[int] = None,
|
383 |
+
clip_audio: bool = False,
|
384 |
+
):
|
385 |
+
super().__init__()
|
386 |
+
out_dim = mdct_frame_len // 2
|
387 |
+
self.out = nn.Linear(dim, out_dim)
|
388 |
+
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
389 |
+
self.clip_audio = clip_audio
|
390 |
+
|
391 |
+
if sample_rate is not None:
|
392 |
+
# optionally init the last layer following mel-scale
|
393 |
+
m_max = _hz_to_mel(sample_rate // 2)
|
394 |
+
m_pts = torch.linspace(0, m_max, out_dim)
|
395 |
+
f_pts = _mel_to_hz(m_pts)
|
396 |
+
scale = 1 - (f_pts / f_pts.max())
|
397 |
+
|
398 |
+
with torch.no_grad():
|
399 |
+
self.out.weight.mul_(scale.view(-1, 1))
|
400 |
+
|
401 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
402 |
+
"""
|
403 |
+
Forward pass of the IMDCTSymExpHead module.
|
404 |
+
|
405 |
+
Args:
|
406 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
407 |
+
L is the sequence length, and H denotes the model dimension.
|
408 |
+
|
409 |
+
Returns:
|
410 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
411 |
+
"""
|
412 |
+
x = self.out(x)
|
413 |
+
x = symexp(x)
|
414 |
+
x = torch.clip(
|
415 |
+
x, min=-1e2, max=1e2
|
416 |
+
) # safeguard to prevent excessively large magnitudes
|
417 |
+
audio = self.imdct(x)
|
418 |
+
if self.clip_audio:
|
419 |
+
audio = torch.clip(x, min=-1.0, max=1.0)
|
420 |
+
|
421 |
+
return audio
|
422 |
+
|
423 |
+
|
424 |
+
class IMDCTCosHead(FourierHead):
|
425 |
+
"""
|
426 |
+
IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
|
427 |
+
|
428 |
+
Args:
|
429 |
+
dim (int): Hidden dimension of the model.
|
430 |
+
mdct_frame_len (int): Length of the MDCT frame.
|
431 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
432 |
+
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
433 |
+
"""
|
434 |
+
|
435 |
+
def __init__(
|
436 |
+
self,
|
437 |
+
dim: int,
|
438 |
+
mdct_frame_len: int,
|
439 |
+
padding: str = "same",
|
440 |
+
clip_audio: bool = False,
|
441 |
+
):
|
442 |
+
super().__init__()
|
443 |
+
self.clip_audio = clip_audio
|
444 |
+
self.out = nn.Linear(dim, mdct_frame_len)
|
445 |
+
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
446 |
+
|
447 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
448 |
+
"""
|
449 |
+
Forward pass of the IMDCTCosHead module.
|
450 |
+
|
451 |
+
Args:
|
452 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
453 |
+
L is the sequence length, and H denotes the model dimension.
|
454 |
+
|
455 |
+
Returns:
|
456 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
457 |
+
"""
|
458 |
+
x = self.out(x)
|
459 |
+
m, p = x.chunk(2, dim=2)
|
460 |
+
m = torch.exp(m).clip(
|
461 |
+
max=1e2
|
462 |
+
) # safeguard to prevent excessively large magnitudes
|
463 |
+
audio = self.imdct(m * torch.cos(p))
|
464 |
+
if self.clip_audio:
|
465 |
+
audio = torch.clip(x, min=-1.0, max=1.0)
|
466 |
+
return audio
|
467 |
+
|
468 |
+
|
469 |
+
class ConvNeXtBlock(nn.Module):
|
470 |
+
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
471 |
+
|
472 |
+
Args:
|
473 |
+
dim (int): Number of input channels.
|
474 |
+
intermediate_dim (int): Dimensionality of the intermediate layer.
|
475 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
476 |
+
Defaults to None.
|
477 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
478 |
+
None means non-conditional LayerNorm. Defaults to None.
|
479 |
+
"""
|
480 |
+
|
481 |
+
def __init__(
|
482 |
+
self,
|
483 |
+
dim: int,
|
484 |
+
intermediate_dim: int,
|
485 |
+
layer_scale_init_value: float,
|
486 |
+
adanorm_num_embeddings: Optional[int] = None,
|
487 |
+
):
|
488 |
+
super().__init__()
|
489 |
+
self.dwconv = nn.Conv1d(
|
490 |
+
dim, dim, kernel_size=7, padding=3, groups=dim
|
491 |
+
) # depthwise conv
|
492 |
+
self.adanorm = adanorm_num_embeddings is not None
|
493 |
+
if adanorm_num_embeddings:
|
494 |
+
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
495 |
+
else:
|
496 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
497 |
+
self.pwconv1 = nn.Linear(
|
498 |
+
dim, intermediate_dim
|
499 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
500 |
+
self.act = nn.GELU()
|
501 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
502 |
+
self.gamma = (
|
503 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
504 |
+
if layer_scale_init_value > 0
|
505 |
+
else None
|
506 |
+
)
|
507 |
+
|
508 |
+
def forward(
|
509 |
+
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
510 |
+
) -> torch.Tensor:
|
511 |
+
residual = x
|
512 |
+
x = self.dwconv(x)
|
513 |
+
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
514 |
+
if self.adanorm:
|
515 |
+
assert cond_embedding_id is not None
|
516 |
+
x = self.norm(x, cond_embedding_id)
|
517 |
+
else:
|
518 |
+
x = self.norm(x)
|
519 |
+
x = self.pwconv1(x)
|
520 |
+
x = self.act(x)
|
521 |
+
x = self.pwconv2(x)
|
522 |
+
if self.gamma is not None:
|
523 |
+
x = self.gamma * x
|
524 |
+
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
525 |
+
|
526 |
+
x = residual + x
|
527 |
+
return x
|
528 |
+
|
529 |
+
|
530 |
+
class AdaLayerNorm(nn.Module):
|
531 |
+
"""
|
532 |
+
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
|
533 |
+
|
534 |
+
Args:
|
535 |
+
num_embeddings (int): Number of embeddings.
|
536 |
+
embedding_dim (int): Dimension of the embeddings.
|
537 |
+
"""
|
538 |
+
|
539 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
|
540 |
+
super().__init__()
|
541 |
+
self.eps = eps
|
542 |
+
self.dim = embedding_dim
|
543 |
+
self.scale = nn.Embedding(
|
544 |
+
num_embeddings=num_embeddings, embedding_dim=embedding_dim
|
545 |
+
)
|
546 |
+
self.shift = nn.Embedding(
|
547 |
+
num_embeddings=num_embeddings, embedding_dim=embedding_dim
|
548 |
+
)
|
549 |
+
torch.nn.init.ones_(self.scale.weight)
|
550 |
+
torch.nn.init.zeros_(self.shift.weight)
|
551 |
+
|
552 |
+
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
|
553 |
+
scale = self.scale(cond_embedding_id)
|
554 |
+
shift = self.shift(cond_embedding_id)
|
555 |
+
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
|
556 |
+
x = x * scale + shift
|
557 |
+
return x
|
558 |
+
|
559 |
+
|
560 |
+
class ResBlock1(nn.Module):
|
561 |
+
"""
|
562 |
+
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
|
563 |
+
but without upsampling layers.
|
564 |
+
|
565 |
+
Args:
|
566 |
+
dim (int): Number of input channels.
|
567 |
+
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
|
568 |
+
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
|
569 |
+
Defaults to (1, 3, 5).
|
570 |
+
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
|
571 |
+
Defaults to 0.1.
|
572 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
573 |
+
Defaults to None.
|
574 |
+
"""
|
575 |
+
|
576 |
+
def __init__(
|
577 |
+
self,
|
578 |
+
dim: int,
|
579 |
+
kernel_size: int = 3,
|
580 |
+
dilation: Tuple[int, int, int] = (1, 3, 5),
|
581 |
+
lrelu_slope: float = 0.1,
|
582 |
+
layer_scale_init_value: Optional[float] = None,
|
583 |
+
):
|
584 |
+
super().__init__()
|
585 |
+
self.lrelu_slope = lrelu_slope
|
586 |
+
self.convs1 = nn.ModuleList(
|
587 |
+
[
|
588 |
+
weight_norm(
|
589 |
+
nn.Conv1d(
|
590 |
+
dim,
|
591 |
+
dim,
|
592 |
+
kernel_size,
|
593 |
+
1,
|
594 |
+
dilation=dilation[0],
|
595 |
+
padding=self.get_padding(kernel_size, dilation[0]),
|
596 |
+
)
|
597 |
+
),
|
598 |
+
weight_norm(
|
599 |
+
nn.Conv1d(
|
600 |
+
dim,
|
601 |
+
dim,
|
602 |
+
kernel_size,
|
603 |
+
1,
|
604 |
+
dilation=dilation[1],
|
605 |
+
padding=self.get_padding(kernel_size, dilation[1]),
|
606 |
+
)
|
607 |
+
),
|
608 |
+
weight_norm(
|
609 |
+
nn.Conv1d(
|
610 |
+
dim,
|
611 |
+
dim,
|
612 |
+
kernel_size,
|
613 |
+
1,
|
614 |
+
dilation=dilation[2],
|
615 |
+
padding=self.get_padding(kernel_size, dilation[2]),
|
616 |
+
)
|
617 |
+
),
|
618 |
+
]
|
619 |
+
)
|
620 |
+
|
621 |
+
self.convs2 = nn.ModuleList(
|
622 |
+
[
|
623 |
+
weight_norm(
|
624 |
+
nn.Conv1d(
|
625 |
+
dim,
|
626 |
+
dim,
|
627 |
+
kernel_size,
|
628 |
+
1,
|
629 |
+
dilation=1,
|
630 |
+
padding=self.get_padding(kernel_size, 1),
|
631 |
+
)
|
632 |
+
),
|
633 |
+
weight_norm(
|
634 |
+
nn.Conv1d(
|
635 |
+
dim,
|
636 |
+
dim,
|
637 |
+
kernel_size,
|
638 |
+
1,
|
639 |
+
dilation=1,
|
640 |
+
padding=self.get_padding(kernel_size, 1),
|
641 |
+
)
|
642 |
+
),
|
643 |
+
weight_norm(
|
644 |
+
nn.Conv1d(
|
645 |
+
dim,
|
646 |
+
dim,
|
647 |
+
kernel_size,
|
648 |
+
1,
|
649 |
+
dilation=1,
|
650 |
+
padding=self.get_padding(kernel_size, 1),
|
651 |
+
)
|
652 |
+
),
|
653 |
+
]
|
654 |
+
)
|
655 |
+
|
656 |
+
self.gamma = nn.ParameterList(
|
657 |
+
[
|
658 |
+
(
|
659 |
+
nn.Parameter(
|
660 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
661 |
+
)
|
662 |
+
if layer_scale_init_value is not None
|
663 |
+
else None
|
664 |
+
),
|
665 |
+
(
|
666 |
+
nn.Parameter(
|
667 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
668 |
+
)
|
669 |
+
if layer_scale_init_value is not None
|
670 |
+
else None
|
671 |
+
),
|
672 |
+
(
|
673 |
+
nn.Parameter(
|
674 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
675 |
+
)
|
676 |
+
if layer_scale_init_value is not None
|
677 |
+
else None
|
678 |
+
),
|
679 |
+
]
|
680 |
+
)
|
681 |
+
|
682 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
683 |
+
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
|
684 |
+
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
|
685 |
+
xt = c1(xt)
|
686 |
+
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
|
687 |
+
xt = c2(xt)
|
688 |
+
if gamma is not None:
|
689 |
+
xt = gamma * xt
|
690 |
+
x = xt + x
|
691 |
+
return x
|
692 |
+
|
693 |
+
def remove_weight_norm(self):
|
694 |
+
for l in self.convs1:
|
695 |
+
remove_weight_norm(l)
|
696 |
+
for l in self.convs2:
|
697 |
+
remove_weight_norm(l)
|
698 |
+
|
699 |
+
@staticmethod
|
700 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
701 |
+
return int((kernel_size * dilation - dilation) / 2)
|
702 |
+
|
703 |
+
|
704 |
+
class Backbone(nn.Module):
|
705 |
+
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
|
706 |
+
|
707 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
708 |
+
"""
|
709 |
+
Args:
|
710 |
+
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
|
711 |
+
C denotes output features, and L is the sequence length.
|
712 |
+
|
713 |
+
Returns:
|
714 |
+
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
|
715 |
+
and H denotes the model dimension.
|
716 |
+
"""
|
717 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
718 |
+
|
719 |
+
|
720 |
+
class VocosBackbone(Backbone):
|
721 |
+
"""
|
722 |
+
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
|
723 |
+
|
724 |
+
Args:
|
725 |
+
input_channels (int): Number of input features channels.
|
726 |
+
dim (int): Hidden dimension of the model.
|
727 |
+
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
|
728 |
+
num_layers (int): Number of ConvNeXtBlock layers.
|
729 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
|
730 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
731 |
+
None means non-conditional model. Defaults to None.
|
732 |
+
"""
|
733 |
+
|
734 |
+
def __init__(
|
735 |
+
self,
|
736 |
+
input_channels: int,
|
737 |
+
dim: int,
|
738 |
+
intermediate_dim: int,
|
739 |
+
num_layers: int,
|
740 |
+
layer_scale_init_value: Optional[float] = None,
|
741 |
+
adanorm_num_embeddings: Optional[int] = None,
|
742 |
+
):
|
743 |
+
super().__init__()
|
744 |
+
self.input_channels = input_channels
|
745 |
+
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
|
746 |
+
self.adanorm = adanorm_num_embeddings is not None
|
747 |
+
if adanorm_num_embeddings:
|
748 |
+
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
749 |
+
else:
|
750 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
751 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
|
752 |
+
self.convnext = nn.ModuleList(
|
753 |
+
[
|
754 |
+
ConvNeXtBlock(
|
755 |
+
dim=dim,
|
756 |
+
intermediate_dim=intermediate_dim,
|
757 |
+
layer_scale_init_value=layer_scale_init_value,
|
758 |
+
adanorm_num_embeddings=adanorm_num_embeddings,
|
759 |
+
)
|
760 |
+
for _ in range(num_layers)
|
761 |
+
]
|
762 |
+
)
|
763 |
+
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
764 |
+
self.apply(self._init_weights)
|
765 |
+
|
766 |
+
def _init_weights(self, m):
|
767 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
768 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
769 |
+
nn.init.constant_(m.bias, 0)
|
770 |
+
|
771 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
772 |
+
bandwidth_id = kwargs.get("bandwidth_id", None)
|
773 |
+
x = self.embed(x)
|
774 |
+
if self.adanorm:
|
775 |
+
assert bandwidth_id is not None
|
776 |
+
x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
|
777 |
+
else:
|
778 |
+
x = self.norm(x.transpose(1, 2))
|
779 |
+
x = x.transpose(1, 2)
|
780 |
+
for conv_block in self.convnext:
|
781 |
+
x = conv_block(x, cond_embedding_id=bandwidth_id)
|
782 |
+
x = self.final_layer_norm(x.transpose(1, 2))
|
783 |
+
return x
|
784 |
+
|
785 |
+
|
786 |
+
class VocosResNetBackbone(Backbone):
|
787 |
+
"""
|
788 |
+
Vocos backbone module built with ResBlocks.
|
789 |
+
|
790 |
+
Args:
|
791 |
+
input_channels (int): Number of input features channels.
|
792 |
+
dim (int): Hidden dimension of the model.
|
793 |
+
num_blocks (int): Number of ResBlock1 blocks.
|
794 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
|
795 |
+
"""
|
796 |
+
|
797 |
+
def __init__(
|
798 |
+
self,
|
799 |
+
input_channels,
|
800 |
+
dim,
|
801 |
+
num_blocks,
|
802 |
+
layer_scale_init_value=None,
|
803 |
+
):
|
804 |
+
super().__init__()
|
805 |
+
self.input_channels = input_channels
|
806 |
+
self.embed = weight_norm(
|
807 |
+
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
|
808 |
+
)
|
809 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
|
810 |
+
self.resnet = nn.Sequential(
|
811 |
+
*[
|
812 |
+
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
|
813 |
+
for _ in range(num_blocks)
|
814 |
+
]
|
815 |
+
)
|
816 |
+
|
817 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
818 |
+
x = self.embed(x)
|
819 |
+
x = self.resnet(x)
|
820 |
+
x = x.transpose(1, 2)
|
821 |
+
return x
|
822 |
+
|
823 |
+
|
824 |
+
class Vocos(nn.Module):
|
825 |
+
def __init__(
|
826 |
+
self,
|
827 |
+
input_channels: int = 256,
|
828 |
+
dim: int = 384,
|
829 |
+
intermediate_dim: int = 1152,
|
830 |
+
num_layers: int = 8,
|
831 |
+
n_fft: int = 800,
|
832 |
+
hop_size: int = 200,
|
833 |
+
padding: str = "same",
|
834 |
+
adanorm_num_embeddings=None,
|
835 |
+
cfg=None,
|
836 |
+
):
|
837 |
+
super().__init__()
|
838 |
+
|
839 |
+
input_channels = (
|
840 |
+
cfg.input_channels
|
841 |
+
if cfg is not None and hasattr(cfg, "input_channels")
|
842 |
+
else input_channels
|
843 |
+
)
|
844 |
+
dim = cfg.dim if cfg is not None and hasattr(cfg, "dim") else dim
|
845 |
+
intermediate_dim = (
|
846 |
+
cfg.intermediate_dim
|
847 |
+
if cfg is not None and hasattr(cfg, "intermediate_dim")
|
848 |
+
else intermediate_dim
|
849 |
+
)
|
850 |
+
num_layers = (
|
851 |
+
cfg.num_layers
|
852 |
+
if cfg is not None and hasattr(cfg, "num_layers")
|
853 |
+
else num_layers
|
854 |
+
)
|
855 |
+
adanorm_num_embeddings = (
|
856 |
+
cfg.adanorm_num_embeddings
|
857 |
+
if cfg is not None and hasattr(cfg, "adanorm_num_embeddings")
|
858 |
+
else adanorm_num_embeddings
|
859 |
+
)
|
860 |
+
n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
|
861 |
+
hop_size = (
|
862 |
+
cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
|
863 |
+
)
|
864 |
+
padding = (
|
865 |
+
cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
|
866 |
+
)
|
867 |
+
|
868 |
+
self.backbone = VocosBackbone(
|
869 |
+
input_channels=input_channels,
|
870 |
+
dim=dim,
|
871 |
+
intermediate_dim=intermediate_dim,
|
872 |
+
num_layers=num_layers,
|
873 |
+
adanorm_num_embeddings=adanorm_num_embeddings,
|
874 |
+
)
|
875 |
+
self.head = ISTFTHead(dim, n_fft, hop_size, padding)
|
876 |
+
|
877 |
+
def forward(self, x):
|
878 |
+
x = self.backbone(x)
|
879 |
+
x = self.head(x)
|
880 |
+
|
881 |
+
return x[:, None, :]
|
models/codec/codec_dataset.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from typing import Iterable
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
import torch.utils.data
|
10 |
+
from torch.nn.utils.rnn import pad_sequence
|
11 |
+
from utils.data_utils import *
|
12 |
+
from torch.utils.data import ConcatDataset, Dataset
|
13 |
+
|
14 |
+
|
15 |
+
class CodecDataset(torch.utils.data.Dataset):
|
16 |
+
def __init__(self, cfg, dataset, is_valid=False):
|
17 |
+
"""
|
18 |
+
Args:
|
19 |
+
cfg: config
|
20 |
+
dataset: dataset name
|
21 |
+
is_valid: whether to use train or valid dataset
|
22 |
+
"""
|
23 |
+
assert isinstance(dataset, str)
|
24 |
+
|
25 |
+
processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
|
26 |
+
|
27 |
+
meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
|
28 |
+
self.metafile_path = os.path.join(processed_data_dir, meta_file)
|
29 |
+
self.metadata = self.get_metadata()
|
30 |
+
|
31 |
+
self.data_root = processed_data_dir
|
32 |
+
self.cfg = cfg
|
33 |
+
|
34 |
+
if cfg.preprocess.use_audio:
|
35 |
+
self.utt2audio_path = {}
|
36 |
+
for utt_info in self.metadata:
|
37 |
+
dataset = utt_info["Dataset"]
|
38 |
+
uid = utt_info["Uid"]
|
39 |
+
utt = "{}_{}".format(dataset, uid)
|
40 |
+
|
41 |
+
self.utt2audio_path[utt] = os.path.join(
|
42 |
+
cfg.preprocess.processed_dir,
|
43 |
+
dataset,
|
44 |
+
cfg.preprocess.audio_dir,
|
45 |
+
uid + ".npy",
|
46 |
+
)
|
47 |
+
elif cfg.preprocess.use_label:
|
48 |
+
self.utt2label_path = {}
|
49 |
+
for utt_info in self.metadata:
|
50 |
+
dataset = utt_info["Dataset"]
|
51 |
+
uid = utt_info["Uid"]
|
52 |
+
utt = "{}_{}".format(dataset, uid)
|
53 |
+
|
54 |
+
self.utt2label_path[utt] = os.path.join(
|
55 |
+
cfg.preprocess.processed_dir,
|
56 |
+
dataset,
|
57 |
+
cfg.preprocess.label_dir,
|
58 |
+
uid + ".npy",
|
59 |
+
)
|
60 |
+
elif cfg.preprocess.use_one_hot:
|
61 |
+
self.utt2one_hot_path = {}
|
62 |
+
for utt_info in self.metadata:
|
63 |
+
dataset = utt_info["Dataset"]
|
64 |
+
uid = utt_info["Uid"]
|
65 |
+
utt = "{}_{}".format(dataset, uid)
|
66 |
+
|
67 |
+
self.utt2one_hot_path[utt] = os.path.join(
|
68 |
+
cfg.preprocess.processed_dir,
|
69 |
+
dataset,
|
70 |
+
cfg.preprocess.one_hot_dir,
|
71 |
+
uid + ".npy",
|
72 |
+
)
|
73 |
+
|
74 |
+
if cfg.preprocess.use_mel:
|
75 |
+
self.utt2mel_path = {}
|
76 |
+
for utt_info in self.metadata:
|
77 |
+
dataset = utt_info["Dataset"]
|
78 |
+
uid = utt_info["Uid"]
|
79 |
+
utt = "{}_{}".format(dataset, uid)
|
80 |
+
|
81 |
+
self.utt2mel_path[utt] = os.path.join(
|
82 |
+
cfg.preprocess.processed_dir,
|
83 |
+
dataset,
|
84 |
+
cfg.preprocess.mel_dir,
|
85 |
+
uid + ".npy",
|
86 |
+
)
|
87 |
+
|
88 |
+
if cfg.preprocess.use_frame_pitch:
|
89 |
+
self.utt2frame_pitch_path = {}
|
90 |
+
for utt_info in self.metadata:
|
91 |
+
dataset = utt_info["Dataset"]
|
92 |
+
uid = utt_info["Uid"]
|
93 |
+
utt = "{}_{}".format(dataset, uid)
|
94 |
+
|
95 |
+
self.utt2frame_pitch_path[utt] = os.path.join(
|
96 |
+
cfg.preprocess.processed_dir,
|
97 |
+
dataset,
|
98 |
+
cfg.preprocess.pitch_dir,
|
99 |
+
uid + ".npy",
|
100 |
+
)
|
101 |
+
|
102 |
+
if cfg.preprocess.use_uv:
|
103 |
+
self.utt2uv_path = {}
|
104 |
+
for utt_info in self.metadata:
|
105 |
+
dataset = utt_info["Dataset"]
|
106 |
+
uid = utt_info["Uid"]
|
107 |
+
utt = "{}_{}".format(dataset, uid)
|
108 |
+
self.utt2uv_path[utt] = os.path.join(
|
109 |
+
cfg.preprocess.processed_dir,
|
110 |
+
dataset,
|
111 |
+
cfg.preprocess.uv_dir,
|
112 |
+
uid + ".npy",
|
113 |
+
)
|
114 |
+
|
115 |
+
if cfg.preprocess.use_amplitude_phase:
|
116 |
+
self.utt2logamp_path = {}
|
117 |
+
self.utt2pha_path = {}
|
118 |
+
self.utt2rea_path = {}
|
119 |
+
self.utt2imag_path = {}
|
120 |
+
for utt_info in self.metadata:
|
121 |
+
dataset = utt_info["Dataset"]
|
122 |
+
uid = utt_info["Uid"]
|
123 |
+
utt = "{}_{}".format(dataset, uid)
|
124 |
+
self.utt2logamp_path[utt] = os.path.join(
|
125 |
+
cfg.preprocess.processed_dir,
|
126 |
+
dataset,
|
127 |
+
cfg.preprocess.log_amplitude_dir,
|
128 |
+
uid + ".npy",
|
129 |
+
)
|
130 |
+
self.utt2pha_path[utt] = os.path.join(
|
131 |
+
cfg.preprocess.processed_dir,
|
132 |
+
dataset,
|
133 |
+
cfg.preprocess.phase_dir,
|
134 |
+
uid + ".npy",
|
135 |
+
)
|
136 |
+
self.utt2rea_path[utt] = os.path.join(
|
137 |
+
cfg.preprocess.processed_dir,
|
138 |
+
dataset,
|
139 |
+
cfg.preprocess.real_dir,
|
140 |
+
uid + ".npy",
|
141 |
+
)
|
142 |
+
self.utt2imag_path[utt] = os.path.join(
|
143 |
+
cfg.preprocess.processed_dir,
|
144 |
+
dataset,
|
145 |
+
cfg.preprocess.imaginary_dir,
|
146 |
+
uid + ".npy",
|
147 |
+
)
|
148 |
+
|
149 |
+
def __getitem__(self, index):
|
150 |
+
utt_info = self.metadata[index]
|
151 |
+
|
152 |
+
dataset = utt_info["Dataset"]
|
153 |
+
uid = utt_info["Uid"]
|
154 |
+
utt = "{}_{}".format(dataset, uid)
|
155 |
+
|
156 |
+
single_feature = dict()
|
157 |
+
|
158 |
+
if self.cfg.preprocess.use_mel:
|
159 |
+
mel = np.load(self.utt2mel_path[utt])
|
160 |
+
assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
|
161 |
+
|
162 |
+
if "target_len" not in single_feature.keys():
|
163 |
+
single_feature["target_len"] = mel.shape[1]
|
164 |
+
|
165 |
+
single_feature["mel"] = mel
|
166 |
+
|
167 |
+
if self.cfg.preprocess.use_frame_pitch:
|
168 |
+
frame_pitch = np.load(self.utt2frame_pitch_path[utt])
|
169 |
+
|
170 |
+
if "target_len" not in single_feature.keys():
|
171 |
+
single_feature["target_len"] = len(frame_pitch)
|
172 |
+
|
173 |
+
aligned_frame_pitch = align_length(
|
174 |
+
frame_pitch, single_feature["target_len"]
|
175 |
+
)
|
176 |
+
|
177 |
+
single_feature["frame_pitch"] = aligned_frame_pitch
|
178 |
+
|
179 |
+
if self.cfg.preprocess.use_audio:
|
180 |
+
audio = np.load(self.utt2audio_path[utt])
|
181 |
+
|
182 |
+
single_feature["audio"] = audio
|
183 |
+
|
184 |
+
return single_feature
|
185 |
+
|
186 |
+
def get_metadata(self):
|
187 |
+
with open(self.metafile_path, "r", encoding="utf-8") as f:
|
188 |
+
metadata = json.load(f)
|
189 |
+
|
190 |
+
return metadata
|
191 |
+
|
192 |
+
def get_dataset_name(self):
|
193 |
+
return self.metadata[0]["Dataset"]
|
194 |
+
|
195 |
+
def __len__(self):
|
196 |
+
return len(self.metadata)
|
197 |
+
|
198 |
+
|
199 |
+
class CodecConcatDataset(ConcatDataset):
|
200 |
+
def __init__(self, datasets: Iterable[Dataset], full_audio_inference=False):
|
201 |
+
"""Concatenate a series of datasets with their random inference audio merged."""
|
202 |
+
super().__init__(datasets)
|
203 |
+
|
204 |
+
self.cfg = self.datasets[0].cfg
|
205 |
+
|
206 |
+
self.metadata = []
|
207 |
+
|
208 |
+
# Merge metadata
|
209 |
+
for dataset in self.datasets:
|
210 |
+
self.metadata += dataset.metadata
|
211 |
+
|
212 |
+
# Merge random inference features
|
213 |
+
if full_audio_inference:
|
214 |
+
self.eval_audios = []
|
215 |
+
self.eval_dataset_names = []
|
216 |
+
if self.cfg.preprocess.use_mel:
|
217 |
+
self.eval_mels = []
|
218 |
+
if self.cfg.preprocess.use_frame_pitch:
|
219 |
+
self.eval_pitchs = []
|
220 |
+
for dataset in self.datasets:
|
221 |
+
self.eval_audios.append(dataset.eval_audio)
|
222 |
+
self.eval_dataset_names.append(dataset.get_dataset_name())
|
223 |
+
if self.cfg.preprocess.use_mel:
|
224 |
+
self.eval_mels.append(dataset.eval_mel)
|
225 |
+
if self.cfg.preprocess.use_frame_pitch:
|
226 |
+
self.eval_pitchs.append(dataset.eval_pitch)
|
227 |
+
|
228 |
+
|
229 |
+
class CodecCollator(object):
|
230 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
231 |
+
|
232 |
+
def __init__(self, cfg):
|
233 |
+
self.cfg = cfg
|
234 |
+
|
235 |
+
def __call__(self, batch):
|
236 |
+
packed_batch_features = dict()
|
237 |
+
|
238 |
+
# mel: [b, n_mels, frame]
|
239 |
+
# frame_pitch: [b, frame]
|
240 |
+
# audios: [b, frame * hop_size]
|
241 |
+
|
242 |
+
for key in batch[0].keys():
|
243 |
+
if key == "target_len":
|
244 |
+
packed_batch_features["target_len"] = torch.LongTensor(
|
245 |
+
[b["target_len"] for b in batch]
|
246 |
+
)
|
247 |
+
masks = [
|
248 |
+
torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
|
249 |
+
]
|
250 |
+
packed_batch_features["mask"] = pad_sequence(
|
251 |
+
masks, batch_first=True, padding_value=0
|
252 |
+
)
|
253 |
+
elif key == "mel":
|
254 |
+
values = [torch.from_numpy(b[key]).T for b in batch]
|
255 |
+
packed_batch_features[key] = pad_sequence(
|
256 |
+
values, batch_first=True, padding_value=0
|
257 |
+
)
|
258 |
+
else:
|
259 |
+
values = [torch.from_numpy(b[key]) for b in batch]
|
260 |
+
packed_batch_features[key] = pad_sequence(
|
261 |
+
values, batch_first=True, padding_value=0
|
262 |
+
)
|
263 |
+
|
264 |
+
return packed_batch_features
|
models/codec/codec_inference.py
ADDED
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
import json
|
9 |
+
import json5
|
10 |
+
import time
|
11 |
+
import accelerate
|
12 |
+
import random
|
13 |
+
import numpy as np
|
14 |
+
import shutil
|
15 |
+
|
16 |
+
from pathlib import Path
|
17 |
+
from tqdm import tqdm
|
18 |
+
from glob import glob
|
19 |
+
from accelerate.logging import get_logger
|
20 |
+
from torch.utils.data import DataLoader
|
21 |
+
|
22 |
+
from models.vocoders.vocoder_dataset import (
|
23 |
+
VocoderDataset,
|
24 |
+
VocoderCollator,
|
25 |
+
VocoderConcatDataset,
|
26 |
+
)
|
27 |
+
|
28 |
+
from models.vocoders.gan.generator import bigvgan, hifigan, melgan, nsfhifigan, apnet
|
29 |
+
from models.vocoders.flow.waveglow import waveglow
|
30 |
+
from models.vocoders.diffusion.diffwave import diffwave
|
31 |
+
from models.vocoders.autoregressive.wavenet import wavenet
|
32 |
+
from models.vocoders.autoregressive.wavernn import wavernn
|
33 |
+
|
34 |
+
from models.vocoders.gan import gan_vocoder_inference
|
35 |
+
from models.vocoders.diffusion import diffusion_vocoder_inference
|
36 |
+
|
37 |
+
from utils.io import save_audio
|
38 |
+
|
39 |
+
_vocoders = {
|
40 |
+
"diffwave": diffwave.DiffWave,
|
41 |
+
"wavernn": wavernn.WaveRNN,
|
42 |
+
"wavenet": wavenet.WaveNet,
|
43 |
+
"waveglow": waveglow.WaveGlow,
|
44 |
+
"nsfhifigan": nsfhifigan.NSFHiFiGAN,
|
45 |
+
"bigvgan": bigvgan.BigVGAN,
|
46 |
+
"hifigan": hifigan.HiFiGAN,
|
47 |
+
"melgan": melgan.MelGAN,
|
48 |
+
"apnet": apnet.APNet,
|
49 |
+
}
|
50 |
+
|
51 |
+
# Forward call for generalized Inferencor
|
52 |
+
_vocoder_forward_funcs = {
|
53 |
+
# "world": world_inference.synthesis_audios,
|
54 |
+
# "wavernn": wavernn_inference.synthesis_audios,
|
55 |
+
# "wavenet": wavenet_inference.synthesis_audios,
|
56 |
+
"diffwave": diffusion_vocoder_inference.vocoder_inference,
|
57 |
+
"nsfhifigan": gan_vocoder_inference.vocoder_inference,
|
58 |
+
"bigvgan": gan_vocoder_inference.vocoder_inference,
|
59 |
+
"melgan": gan_vocoder_inference.vocoder_inference,
|
60 |
+
"hifigan": gan_vocoder_inference.vocoder_inference,
|
61 |
+
"apnet": gan_vocoder_inference.vocoder_inference,
|
62 |
+
}
|
63 |
+
|
64 |
+
# APIs for other tasks. e.g. SVC, TTS, TTA...
|
65 |
+
_vocoder_infer_funcs = {
|
66 |
+
# "world": world_inference.synthesis_audios,
|
67 |
+
# "wavernn": wavernn_inference.synthesis_audios,
|
68 |
+
# "wavenet": wavenet_inference.synthesis_audios,
|
69 |
+
"diffwave": diffusion_vocoder_inference.synthesis_audios,
|
70 |
+
"nsfhifigan": gan_vocoder_inference.synthesis_audios,
|
71 |
+
"bigvgan": gan_vocoder_inference.synthesis_audios,
|
72 |
+
"melgan": gan_vocoder_inference.synthesis_audios,
|
73 |
+
"hifigan": gan_vocoder_inference.synthesis_audios,
|
74 |
+
"apnet": gan_vocoder_inference.synthesis_audios,
|
75 |
+
}
|
76 |
+
|
77 |
+
|
78 |
+
class VocoderInference(object):
|
79 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
80 |
+
super().__init__()
|
81 |
+
|
82 |
+
start = time.monotonic_ns()
|
83 |
+
self.args = args
|
84 |
+
self.cfg = cfg
|
85 |
+
self.infer_type = infer_type
|
86 |
+
|
87 |
+
# Init accelerator
|
88 |
+
self.accelerator = accelerate.Accelerator()
|
89 |
+
self.accelerator.wait_for_everyone()
|
90 |
+
|
91 |
+
# Get logger
|
92 |
+
with self.accelerator.main_process_first():
|
93 |
+
self.logger = get_logger("inference", log_level=args.log_level)
|
94 |
+
|
95 |
+
# Log some info
|
96 |
+
self.logger.info("=" * 56)
|
97 |
+
self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
|
98 |
+
self.logger.info("=" * 56)
|
99 |
+
self.logger.info("\n")
|
100 |
+
|
101 |
+
self.vocoder_dir = args.vocoder_dir
|
102 |
+
self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
|
103 |
+
|
104 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
105 |
+
if os.path.exists(os.path.join(args.output_dir, "pred")):
|
106 |
+
shutil.rmtree(os.path.join(args.output_dir, "pred"))
|
107 |
+
if os.path.exists(os.path.join(args.output_dir, "gt")):
|
108 |
+
shutil.rmtree(os.path.join(args.output_dir, "gt"))
|
109 |
+
os.makedirs(os.path.join(args.output_dir, "pred"), exist_ok=True)
|
110 |
+
os.makedirs(os.path.join(args.output_dir, "gt"), exist_ok=True)
|
111 |
+
|
112 |
+
# Set random seed
|
113 |
+
with self.accelerator.main_process_first():
|
114 |
+
start = time.monotonic_ns()
|
115 |
+
self._set_random_seed(self.cfg.train.random_seed)
|
116 |
+
end = time.monotonic_ns()
|
117 |
+
self.logger.debug(
|
118 |
+
f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
|
119 |
+
)
|
120 |
+
self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
|
121 |
+
|
122 |
+
# Setup inference mode
|
123 |
+
if self.infer_type == "infer_from_dataset":
|
124 |
+
self.cfg.dataset = self.args.infer_datasets
|
125 |
+
elif self.infer_type == "infer_from_feature":
|
126 |
+
self._build_tmp_dataset_from_feature()
|
127 |
+
self.cfg.dataset = ["tmp"]
|
128 |
+
elif self.infer_type == "infer_from_audio":
|
129 |
+
self._build_tmp_dataset_from_audio()
|
130 |
+
self.cfg.dataset = ["tmp"]
|
131 |
+
|
132 |
+
# Setup data loader
|
133 |
+
with self.accelerator.main_process_first():
|
134 |
+
self.logger.info("Building dataset...")
|
135 |
+
start = time.monotonic_ns()
|
136 |
+
self.test_dataloader = self._build_dataloader()
|
137 |
+
end = time.monotonic_ns()
|
138 |
+
self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
|
139 |
+
|
140 |
+
# Build model
|
141 |
+
with self.accelerator.main_process_first():
|
142 |
+
self.logger.info("Building model...")
|
143 |
+
start = time.monotonic_ns()
|
144 |
+
self.model = self._build_model()
|
145 |
+
end = time.monotonic_ns()
|
146 |
+
self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
|
147 |
+
|
148 |
+
# Init with accelerate
|
149 |
+
self.logger.info("Initializing accelerate...")
|
150 |
+
start = time.monotonic_ns()
|
151 |
+
self.accelerator = accelerate.Accelerator()
|
152 |
+
(self.model, self.test_dataloader) = self.accelerator.prepare(
|
153 |
+
self.model, self.test_dataloader
|
154 |
+
)
|
155 |
+
end = time.monotonic_ns()
|
156 |
+
self.accelerator.wait_for_everyone()
|
157 |
+
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
|
158 |
+
|
159 |
+
with self.accelerator.main_process_first():
|
160 |
+
self.logger.info("Loading checkpoint...")
|
161 |
+
start = time.monotonic_ns()
|
162 |
+
if os.path.isdir(args.vocoder_dir):
|
163 |
+
if os.path.isdir(os.path.join(args.vocoder_dir, "checkpoint")):
|
164 |
+
self._load_model(os.path.join(args.vocoder_dir, "checkpoint"))
|
165 |
+
else:
|
166 |
+
self._load_model(os.path.join(args.vocoder_dir))
|
167 |
+
else:
|
168 |
+
self._load_model(os.path.join(args.vocoder_dir))
|
169 |
+
end = time.monotonic_ns()
|
170 |
+
self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
|
171 |
+
|
172 |
+
self.model.eval()
|
173 |
+
self.accelerator.wait_for_everyone()
|
174 |
+
|
175 |
+
def _build_tmp_dataset_from_feature(self):
|
176 |
+
if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
|
177 |
+
shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
|
178 |
+
|
179 |
+
utts = []
|
180 |
+
mels = glob(os.path.join(self.args.feature_folder, "mels", "*.npy"))
|
181 |
+
for i, mel in enumerate(mels):
|
182 |
+
uid = mel.split("/")[-1].split(".")[0]
|
183 |
+
utt = {"Dataset": "tmp", "Uid": uid, "index": i}
|
184 |
+
utts.append(utt)
|
185 |
+
|
186 |
+
os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
|
187 |
+
with open(
|
188 |
+
os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
|
189 |
+
) as f:
|
190 |
+
json.dump(utts, f)
|
191 |
+
|
192 |
+
meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}
|
193 |
+
|
194 |
+
with open(
|
195 |
+
os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
|
196 |
+
"w",
|
197 |
+
) as f:
|
198 |
+
json.dump(meta_info, f)
|
199 |
+
|
200 |
+
features = glob(os.path.join(self.args.feature_folder, "*"))
|
201 |
+
for feature in features:
|
202 |
+
feature_name = feature.split("/")[-1]
|
203 |
+
if os.path.isfile(feature):
|
204 |
+
continue
|
205 |
+
shutil.copytree(
|
206 |
+
os.path.join(self.args.feature_folder, feature_name),
|
207 |
+
os.path.join(self.cfg.preprocess.processed_dir, "tmp", feature_name),
|
208 |
+
)
|
209 |
+
|
210 |
+
def _build_tmp_dataset_from_audio(self):
|
211 |
+
if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
|
212 |
+
shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
|
213 |
+
|
214 |
+
utts = []
|
215 |
+
audios = glob(os.path.join(self.args.audio_folder, "*"))
|
216 |
+
for i, audio in enumerate(audios):
|
217 |
+
uid = audio.split("/")[-1].split(".")[0]
|
218 |
+
utt = {"Dataset": "tmp", "Uid": uid, "index": i, "Path": audio}
|
219 |
+
utts.append(utt)
|
220 |
+
|
221 |
+
os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
|
222 |
+
with open(
|
223 |
+
os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
|
224 |
+
) as f:
|
225 |
+
json.dump(utts, f)
|
226 |
+
|
227 |
+
meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}
|
228 |
+
|
229 |
+
with open(
|
230 |
+
os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
|
231 |
+
"w",
|
232 |
+
) as f:
|
233 |
+
json.dump(meta_info, f)
|
234 |
+
|
235 |
+
from processors import acoustic_extractor
|
236 |
+
|
237 |
+
acoustic_extractor.extract_utt_acoustic_features_serial(
|
238 |
+
utts, os.path.join(self.cfg.preprocess.processed_dir, "tmp"), self.cfg
|
239 |
+
)
|
240 |
+
|
241 |
+
def _build_test_dataset(self):
|
242 |
+
return VocoderDataset, VocoderCollator
|
243 |
+
|
244 |
+
def _build_model(self):
|
245 |
+
model = _vocoders[self.cfg.model.generator](self.cfg)
|
246 |
+
return model
|
247 |
+
|
248 |
+
def _build_dataloader(self):
|
249 |
+
"""Build dataloader which merges a series of datasets."""
|
250 |
+
Dataset, Collator = self._build_test_dataset()
|
251 |
+
|
252 |
+
datasets_list = []
|
253 |
+
for dataset in self.cfg.dataset:
|
254 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=True)
|
255 |
+
datasets_list.append(subdataset)
|
256 |
+
test_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=False)
|
257 |
+
test_collate = Collator(self.cfg)
|
258 |
+
test_batch_size = min(self.cfg.inference.batch_size, len(test_dataset))
|
259 |
+
test_dataloader = DataLoader(
|
260 |
+
test_dataset,
|
261 |
+
collate_fn=test_collate,
|
262 |
+
num_workers=1,
|
263 |
+
batch_size=test_batch_size,
|
264 |
+
shuffle=False,
|
265 |
+
)
|
266 |
+
self.test_batch_size = test_batch_size
|
267 |
+
self.test_dataset = test_dataset
|
268 |
+
return test_dataloader
|
269 |
+
|
270 |
+
def _load_model(self, checkpoint_dir, from_multi_gpu=False):
|
271 |
+
"""Load model from checkpoint. If a folder is given, it will
|
272 |
+
load the latest checkpoint in checkpoint_dir. If a path is given
|
273 |
+
it will load the checkpoint specified by checkpoint_path.
|
274 |
+
**Only use this method after** ``accelerator.prepare()``.
|
275 |
+
"""
|
276 |
+
if os.path.isdir(checkpoint_dir):
|
277 |
+
if "epoch" in checkpoint_dir and "step" in checkpoint_dir:
|
278 |
+
checkpoint_path = checkpoint_dir
|
279 |
+
else:
|
280 |
+
# Load the latest accelerator state dicts
|
281 |
+
ls = [
|
282 |
+
str(i)
|
283 |
+
for i in Path(checkpoint_dir).glob("*")
|
284 |
+
if not "audio" in str(i)
|
285 |
+
]
|
286 |
+
ls.sort(
|
287 |
+
key=lambda x: int(x.split("/")[-1].split("_")[0].split("-")[-1]),
|
288 |
+
reverse=True,
|
289 |
+
)
|
290 |
+
checkpoint_path = ls[0]
|
291 |
+
accelerate.load_checkpoint_and_dispatch(
|
292 |
+
self.accelerator.unwrap_model(self.model),
|
293 |
+
os.path.join(checkpoint_path, "pytorch_model.bin"),
|
294 |
+
)
|
295 |
+
return str(checkpoint_path)
|
296 |
+
else:
|
297 |
+
# Load old .pt checkpoints
|
298 |
+
if self.cfg.model.generator in [
|
299 |
+
"bigvgan",
|
300 |
+
"hifigan",
|
301 |
+
"melgan",
|
302 |
+
"nsfhifigan",
|
303 |
+
]:
|
304 |
+
ckpt = torch.load(
|
305 |
+
checkpoint_dir,
|
306 |
+
map_location=(
|
307 |
+
torch.device("cuda")
|
308 |
+
if torch.cuda.is_available()
|
309 |
+
else torch.device("cpu")
|
310 |
+
),
|
311 |
+
)
|
312 |
+
if from_multi_gpu:
|
313 |
+
pretrained_generator_dict = ckpt["generator_state_dict"]
|
314 |
+
generator_dict = self.model.state_dict()
|
315 |
+
|
316 |
+
new_generator_dict = {
|
317 |
+
k.split("module.")[-1]: v
|
318 |
+
for k, v in pretrained_generator_dict.items()
|
319 |
+
if (
|
320 |
+
k.split("module.")[-1] in generator_dict
|
321 |
+
and v.shape == generator_dict[k.split("module.")[-1]].shape
|
322 |
+
)
|
323 |
+
}
|
324 |
+
|
325 |
+
generator_dict.update(new_generator_dict)
|
326 |
+
|
327 |
+
self.model.load_state_dict(generator_dict)
|
328 |
+
else:
|
329 |
+
self.model.load_state_dict(ckpt["generator_state_dict"])
|
330 |
+
else:
|
331 |
+
self.model.load_state_dict(torch.load(checkpoint_dir)["state_dict"])
|
332 |
+
return str(checkpoint_dir)
|
333 |
+
|
334 |
+
def inference(self):
|
335 |
+
"""Inference via batches"""
|
336 |
+
for i, batch in tqdm(enumerate(self.test_dataloader)):
|
337 |
+
if self.cfg.preprocess.use_frame_pitch:
|
338 |
+
audio_pred = _vocoder_forward_funcs[self.cfg.model.generator](
|
339 |
+
self.cfg,
|
340 |
+
self.model,
|
341 |
+
batch["mel"].transpose(-1, -2),
|
342 |
+
f0s=batch["frame_pitch"].float(),
|
343 |
+
device=next(self.model.parameters()).device,
|
344 |
+
)
|
345 |
+
else:
|
346 |
+
audio_pred = _vocoder_forward_funcs[self.cfg.model.generator](
|
347 |
+
self.cfg,
|
348 |
+
self.model,
|
349 |
+
batch["mel"].transpose(-1, -2),
|
350 |
+
device=next(self.model.parameters()).device,
|
351 |
+
)
|
352 |
+
audio_ls = audio_pred.chunk(self.test_batch_size)
|
353 |
+
audio_gt_ls = batch["audio"].cpu().chunk(self.test_batch_size)
|
354 |
+
length_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
|
355 |
+
j = 0
|
356 |
+
for it, it_gt, l in zip(audio_ls, audio_gt_ls, length_ls):
|
357 |
+
l = l.item()
|
358 |
+
it = it.squeeze(0).squeeze(0)[: l * self.cfg.preprocess.hop_size]
|
359 |
+
it_gt = it_gt.squeeze(0)[: l * self.cfg.preprocess.hop_size]
|
360 |
+
uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
|
361 |
+
save_audio(
|
362 |
+
os.path.join(self.args.output_dir, "pred", "{}.wav").format(uid),
|
363 |
+
it,
|
364 |
+
self.cfg.preprocess.sample_rate,
|
365 |
+
)
|
366 |
+
save_audio(
|
367 |
+
os.path.join(self.args.output_dir, "gt", "{}.wav").format(uid),
|
368 |
+
it_gt,
|
369 |
+
self.cfg.preprocess.sample_rate,
|
370 |
+
)
|
371 |
+
j += 1
|
372 |
+
|
373 |
+
if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
|
374 |
+
shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
|
375 |
+
|
376 |
+
def _set_random_seed(self, seed):
|
377 |
+
"""Set random seed for all possible random modules."""
|
378 |
+
random.seed(seed)
|
379 |
+
np.random.seed(seed)
|
380 |
+
torch.random.manual_seed(seed)
|
381 |
+
|
382 |
+
def _count_parameters(self, model):
|
383 |
+
return sum(p.numel() for p in model.parameters())
|
384 |
+
|
385 |
+
def _dump_cfg(self, path):
|
386 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
387 |
+
json5.dump(
|
388 |
+
self.cfg,
|
389 |
+
open(path, "w"),
|
390 |
+
indent=4,
|
391 |
+
sort_keys=True,
|
392 |
+
ensure_ascii=False,
|
393 |
+
quote_keys=True,
|
394 |
+
)
|
395 |
+
|
396 |
+
|
397 |
+
def load_nnvocoder(
|
398 |
+
cfg,
|
399 |
+
vocoder_name,
|
400 |
+
weights_file,
|
401 |
+
from_multi_gpu=False,
|
402 |
+
):
|
403 |
+
"""Load the specified vocoder.
|
404 |
+
cfg: the vocoder config filer.
|
405 |
+
weights_file: a folder or a .pt path.
|
406 |
+
from_multi_gpu: automatically remove the "module" string in state dicts if "True".
|
407 |
+
"""
|
408 |
+
print("Loading Vocoder from Weights file: {}".format(weights_file))
|
409 |
+
|
410 |
+
# Build model
|
411 |
+
model = _vocoders[vocoder_name](cfg)
|
412 |
+
if not os.path.isdir(weights_file):
|
413 |
+
# Load from .pt file
|
414 |
+
if vocoder_name in ["bigvgan", "hifigan", "melgan", "nsfhifigan"]:
|
415 |
+
ckpt = torch.load(
|
416 |
+
weights_file,
|
417 |
+
map_location=(
|
418 |
+
torch.device("cuda")
|
419 |
+
if torch.cuda.is_available()
|
420 |
+
else torch.device("cpu")
|
421 |
+
),
|
422 |
+
)
|
423 |
+
if from_multi_gpu:
|
424 |
+
pretrained_generator_dict = ckpt["generator_state_dict"]
|
425 |
+
generator_dict = model.state_dict()
|
426 |
+
|
427 |
+
new_generator_dict = {
|
428 |
+
k.split("module.")[-1]: v
|
429 |
+
for k, v in pretrained_generator_dict.items()
|
430 |
+
if (
|
431 |
+
k.split("module.")[-1] in generator_dict
|
432 |
+
and v.shape == generator_dict[k.split("module.")[-1]].shape
|
433 |
+
)
|
434 |
+
}
|
435 |
+
|
436 |
+
generator_dict.update(new_generator_dict)
|
437 |
+
|
438 |
+
model.load_state_dict(generator_dict)
|
439 |
+
else:
|
440 |
+
model.load_state_dict(ckpt["generator_state_dict"])
|
441 |
+
else:
|
442 |
+
model.load_state_dict(torch.load(weights_file)["state_dict"])
|
443 |
+
else:
|
444 |
+
# Load from accelerator state dict
|
445 |
+
weights_file = os.path.join(weights_file, "checkpoint")
|
446 |
+
ls = [str(i) for i in Path(weights_file).glob("*") if not "audio" in str(i)]
|
447 |
+
ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
|
448 |
+
checkpoint_path = ls[0]
|
449 |
+
accelerator = accelerate.Accelerator()
|
450 |
+
model = accelerator.prepare(model)
|
451 |
+
accelerator.load_state(checkpoint_path)
|
452 |
+
|
453 |
+
if torch.cuda.is_available():
|
454 |
+
model = model.cuda()
|
455 |
+
|
456 |
+
model = model.eval()
|
457 |
+
return model
|
458 |
+
|
459 |
+
|
460 |
+
def tensorize(data, device, n_samples):
|
461 |
+
"""
|
462 |
+
data: a list of numpy array
|
463 |
+
"""
|
464 |
+
assert type(data) == list
|
465 |
+
if n_samples:
|
466 |
+
data = data[:n_samples]
|
467 |
+
data = [torch.as_tensor(x, device=device) for x in data]
|
468 |
+
return data
|
469 |
+
|
470 |
+
|
471 |
+
def synthesis(
|
472 |
+
cfg,
|
473 |
+
vocoder_weight_file,
|
474 |
+
n_samples,
|
475 |
+
pred,
|
476 |
+
f0s=None,
|
477 |
+
batch_size=64,
|
478 |
+
fast_inference=False,
|
479 |
+
):
|
480 |
+
"""Synthesis audios from a given vocoder and series of given features.
|
481 |
+
cfg: vocoder config.
|
482 |
+
vocoder_weight_file: a folder of accelerator state dict or a path to the .pt file.
|
483 |
+
pred: a list of numpy arrays. [(seq_len1, acoustic_features_dim), (seq_len2, acoustic_features_dim), ...]
|
484 |
+
"""
|
485 |
+
|
486 |
+
vocoder_name = cfg.model.generator
|
487 |
+
|
488 |
+
print("Synthesis audios using {} vocoder...".format(vocoder_name))
|
489 |
+
|
490 |
+
###### TODO: World Vocoder Refactor ######
|
491 |
+
# if vocoder_name == "world":
|
492 |
+
# world_inference.synthesis_audios(
|
493 |
+
# cfg, dataset_name, split, n_samples, pred, save_dir, tag
|
494 |
+
# )
|
495 |
+
# return
|
496 |
+
|
497 |
+
# ====== Loading neural vocoder model ======
|
498 |
+
vocoder = load_nnvocoder(
|
499 |
+
cfg, vocoder_name, weights_file=vocoder_weight_file, from_multi_gpu=True
|
500 |
+
)
|
501 |
+
device = next(vocoder.parameters()).device
|
502 |
+
|
503 |
+
# ====== Inference for predicted acoustic features ======
|
504 |
+
# pred: (frame_len, n_mels) -> (n_mels, frame_len)
|
505 |
+
mels_pred = tensorize([p.T for p in pred], device, n_samples)
|
506 |
+
print("For predicted mels, #sample = {}...".format(len(mels_pred)))
|
507 |
+
audios_pred = _vocoder_infer_funcs[vocoder_name](
|
508 |
+
cfg,
|
509 |
+
vocoder,
|
510 |
+
mels_pred,
|
511 |
+
f0s=f0s,
|
512 |
+
batch_size=batch_size,
|
513 |
+
fast_inference=fast_inference,
|
514 |
+
)
|
515 |
+
return audios_pred
|
models/codec/codec_sampler.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
|
9 |
+
from torch.utils.data import ConcatDataset, Dataset
|
10 |
+
from torch.utils.data.sampler import (
|
11 |
+
BatchSampler,
|
12 |
+
RandomSampler,
|
13 |
+
Sampler,
|
14 |
+
SequentialSampler,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
class ScheduledSampler(Sampler):
|
19 |
+
"""A sampler that samples data from a given concat-dataset.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
|
23 |
+
batch_size (int): batch size
|
24 |
+
holistic_shuffle (bool): whether to shuffle the whole dataset or not
|
25 |
+
logger (logging.Logger): logger to print warning message
|
26 |
+
|
27 |
+
Usage:
|
28 |
+
For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
|
29 |
+
>>> list(ScheduledSampler(ConcatDataset([0, 1, 2], [3, 4, 5], [6, 7, 8]])))
|
30 |
+
[3, 4, 5, 0, 1, 2, 6, 7, 8]
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
|
35 |
+
):
|
36 |
+
if not isinstance(concat_dataset, ConcatDataset):
|
37 |
+
raise ValueError(
|
38 |
+
"concat_dataset must be an instance of ConcatDataset, but got {}".format(
|
39 |
+
type(concat_dataset)
|
40 |
+
)
|
41 |
+
)
|
42 |
+
if not isinstance(batch_size, int):
|
43 |
+
raise ValueError(
|
44 |
+
"batch_size must be an integer, but got {}".format(type(batch_size))
|
45 |
+
)
|
46 |
+
if not isinstance(holistic_shuffle, bool):
|
47 |
+
raise ValueError(
|
48 |
+
"holistic_shuffle must be a boolean, but got {}".format(
|
49 |
+
type(holistic_shuffle)
|
50 |
+
)
|
51 |
+
)
|
52 |
+
|
53 |
+
self.concat_dataset = concat_dataset
|
54 |
+
self.batch_size = batch_size
|
55 |
+
self.holistic_shuffle = holistic_shuffle
|
56 |
+
|
57 |
+
affected_dataset_name = []
|
58 |
+
affected_dataset_len = []
|
59 |
+
for dataset in concat_dataset.datasets:
|
60 |
+
dataset_len = len(dataset)
|
61 |
+
dataset_name = dataset.get_dataset_name()
|
62 |
+
if dataset_len < batch_size:
|
63 |
+
affected_dataset_name.append(dataset_name)
|
64 |
+
affected_dataset_len.append(dataset_len)
|
65 |
+
|
66 |
+
self.type = type
|
67 |
+
for dataset_name, dataset_len in zip(
|
68 |
+
affected_dataset_name, affected_dataset_len
|
69 |
+
):
|
70 |
+
if not type == "valid":
|
71 |
+
logger.warning(
|
72 |
+
"The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
|
73 |
+
type, dataset_name, dataset_len, batch_size
|
74 |
+
)
|
75 |
+
)
|
76 |
+
|
77 |
+
def __len__(self):
|
78 |
+
# the number of batches with drop last
|
79 |
+
num_of_batches = sum(
|
80 |
+
[
|
81 |
+
math.floor(len(dataset) / self.batch_size)
|
82 |
+
for dataset in self.concat_dataset.datasets
|
83 |
+
]
|
84 |
+
)
|
85 |
+
return num_of_batches * self.batch_size
|
86 |
+
|
87 |
+
def __iter__(self):
|
88 |
+
iters = []
|
89 |
+
for dataset in self.concat_dataset.datasets:
|
90 |
+
iters.append(
|
91 |
+
SequentialSampler(dataset).__iter__()
|
92 |
+
if self.holistic_shuffle
|
93 |
+
else RandomSampler(dataset).__iter__()
|
94 |
+
)
|
95 |
+
init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
|
96 |
+
output_batches = []
|
97 |
+
for dataset_idx in range(len(self.concat_dataset.datasets)):
|
98 |
+
cur_batch = []
|
99 |
+
for idx in iters[dataset_idx]:
|
100 |
+
cur_batch.append(idx + init_indices[dataset_idx])
|
101 |
+
if len(cur_batch) == self.batch_size:
|
102 |
+
output_batches.append(cur_batch)
|
103 |
+
cur_batch = []
|
104 |
+
if self.type == "valid" and len(cur_batch) > 0:
|
105 |
+
output_batches.append(cur_batch)
|
106 |
+
cur_batch = []
|
107 |
+
# force drop last in training
|
108 |
+
random.shuffle(output_batches)
|
109 |
+
output_indices = [item for sublist in output_batches for item in sublist]
|
110 |
+
return iter(output_indices)
|
111 |
+
|
112 |
+
|
113 |
+
def build_samplers(concat_dataset: Dataset, cfg, logger, type):
|
114 |
+
sampler = ScheduledSampler(
|
115 |
+
concat_dataset,
|
116 |
+
cfg.train.batch_size,
|
117 |
+
cfg.train.sampler.holistic_shuffle,
|
118 |
+
logger,
|
119 |
+
type,
|
120 |
+
)
|
121 |
+
batch_sampler = BatchSampler(
|
122 |
+
sampler,
|
123 |
+
cfg.train.batch_size,
|
124 |
+
cfg.train.sampler.drop_last if not type == "valid" else False,
|
125 |
+
)
|
126 |
+
return sampler, batch_sampler
|
models/codec/codec_trainer.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
from pathlib import Path
|
9 |
+
import re
|
10 |
+
|
11 |
+
import accelerate
|
12 |
+
import json5
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
from accelerate.utils import ProjectConfiguration
|
16 |
+
from torch.utils.data import DataLoader
|
17 |
+
from tqdm import tqdm
|
18 |
+
|
19 |
+
from models.codec.codec_sampler import build_samplers
|
20 |
+
|
21 |
+
|
22 |
+
class CodecTrainer:
|
23 |
+
def __init__(self):
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
def _init_accelerator(self):
|
27 |
+
"""Initialize the accelerator components."""
|
28 |
+
self.exp_dir = os.path.join(
|
29 |
+
os.path.abspath(self.cfg.log_dir), self.args.exp_name
|
30 |
+
)
|
31 |
+
project_config = ProjectConfiguration(
|
32 |
+
project_dir=self.exp_dir, logging_dir=os.path.join(self.exp_dir, "log")
|
33 |
+
)
|
34 |
+
self.accelerator = accelerate.Accelerator(
|
35 |
+
gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
|
36 |
+
log_with=self.cfg.train.tracker,
|
37 |
+
project_config=project_config,
|
38 |
+
)
|
39 |
+
if self.accelerator.is_main_process:
|
40 |
+
os.makedirs(project_config.project_dir, exist_ok=True)
|
41 |
+
os.makedirs(project_config.logging_dir, exist_ok=True)
|
42 |
+
with self.accelerator.main_process_first():
|
43 |
+
self.accelerator.init_trackers(self.args.exp_name)
|
44 |
+
|
45 |
+
def _build_dataset(self):
|
46 |
+
pass
|
47 |
+
|
48 |
+
def _build_criterion(self):
|
49 |
+
pass
|
50 |
+
|
51 |
+
def _build_model(self):
|
52 |
+
pass
|
53 |
+
|
54 |
+
def _build_dataloader(self):
|
55 |
+
"""Build dataloader which merges a series of datasets."""
|
56 |
+
# Build dataset instance for each dataset and combine them by ConcatDataset
|
57 |
+
Dataset, Collator = self._build_dataset()
|
58 |
+
|
59 |
+
# Build train set
|
60 |
+
train_dataset = Dataset(self.cfg, self.cfg.dataset, is_valid=False)
|
61 |
+
train_collate = Collator(self.cfg)
|
62 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
63 |
+
train_dataset,
|
64 |
+
num_replicas=self.accelerator.num_processes,
|
65 |
+
rank=self.accelerator.local_process_index,
|
66 |
+
shuffle=True,
|
67 |
+
seed=self.cfg.train.random_seed,
|
68 |
+
)
|
69 |
+
train_loader = DataLoader(
|
70 |
+
train_dataset,
|
71 |
+
batch_size=self.cfg.train.batch_size,
|
72 |
+
collate_fn=train_collate,
|
73 |
+
sampler=sampler,
|
74 |
+
num_workers=self.cfg.train.dataloader.num_worker,
|
75 |
+
pin_memory=self.cfg.train.dataloader.pin_memory,
|
76 |
+
)
|
77 |
+
return train_loader, None
|
78 |
+
|
79 |
+
def _build_optimizer(self):
|
80 |
+
pass
|
81 |
+
|
82 |
+
def _build_scheduler(self):
|
83 |
+
pass
|
84 |
+
|
85 |
+
def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
|
86 |
+
"""Load model from checkpoint. If a folder is given, it will
|
87 |
+
load the latest checkpoint in checkpoint_dir. If a path is given
|
88 |
+
it will load the checkpoint specified by checkpoint_path.
|
89 |
+
**Only use this method after** ``accelerator.prepare()``.
|
90 |
+
"""
|
91 |
+
if checkpoint_path is None:
|
92 |
+
ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
|
93 |
+
ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
|
94 |
+
checkpoint_path = ls[0]
|
95 |
+
if resume_type == "resume":
|
96 |
+
self.accelerator.load_state(checkpoint_path)
|
97 |
+
elif resume_type == "finetune":
|
98 |
+
accelerate.load_checkpoint_and_dispatch(
|
99 |
+
self.accelerator.unwrap_model(self.model),
|
100 |
+
os.path.join(checkpoint_path, "pytorch_model.bin"),
|
101 |
+
)
|
102 |
+
self.logger.info("Load model weights for finetune SUCCESS!")
|
103 |
+
else:
|
104 |
+
raise ValueError("Unsupported resume type: {}".format(resume_type))
|
105 |
+
self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
|
106 |
+
self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
|
107 |
+
return checkpoint_path
|
108 |
+
|
109 |
+
def train_loop(self):
|
110 |
+
pass
|
111 |
+
|
112 |
+
def _train_epoch(self):
|
113 |
+
pass
|
114 |
+
|
115 |
+
def _valid_epoch(self):
|
116 |
+
pass
|
117 |
+
|
118 |
+
def _train_step(self):
|
119 |
+
pass
|
120 |
+
|
121 |
+
def _valid_step(self):
|
122 |
+
pass
|
123 |
+
|
124 |
+
def _inference(self):
|
125 |
+
pass
|
126 |
+
|
127 |
+
def _set_random_seed(self, seed):
|
128 |
+
"""Set random seed for all possible random modules."""
|
129 |
+
random.seed(seed)
|
130 |
+
np.random.seed(seed)
|
131 |
+
torch.random.manual_seed(seed)
|
132 |
+
|
133 |
+
def _check_nan(self, loss):
|
134 |
+
if torch.any(torch.isnan(loss)):
|
135 |
+
self.logger.fatal("Fatal Error: NaN!")
|
136 |
+
self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
|
137 |
+
|
138 |
+
def _check_basic_configs(self):
|
139 |
+
if self.cfg.train.gradient_accumulation_step <= 0:
|
140 |
+
self.logger.fatal("Invalid gradient_accumulation_step value!")
|
141 |
+
self.logger.error(
|
142 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
143 |
+
)
|
144 |
+
self.accelerator.end_training()
|
145 |
+
raise ValueError(
|
146 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
147 |
+
)
|
148 |
+
|
149 |
+
def _count_parameters(self):
|
150 |
+
pass
|
151 |
+
|
152 |
+
def _dump_cfg(self, path):
|
153 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
154 |
+
json5.dump(
|
155 |
+
self.cfg,
|
156 |
+
open(path, "w"),
|
157 |
+
indent=4,
|
158 |
+
sort_keys=True,
|
159 |
+
ensure_ascii=False,
|
160 |
+
quote_keys=True,
|
161 |
+
)
|
162 |
+
|
163 |
+
def _is_valid_pattern(self, directory_name):
|
164 |
+
directory_name = str(directory_name)
|
165 |
+
pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}"
|
166 |
+
return re.match(pattern, directory_name) is not None
|
models/codec/facodec/__init__.py
ADDED
File without changes
|
models/codec/facodec/alias_free_torch/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
|
3 |
+
from .filter import *
|
4 |
+
from .resample import *
|
5 |
+
from .act import *
|
models/codec/facodec/alias_free_torch/act.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
from .resample import UpSample1d, DownSample1d
|
5 |
+
|
6 |
+
|
7 |
+
class Activation1d(nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
activation,
|
11 |
+
up_ratio: int = 2,
|
12 |
+
down_ratio: int = 2,
|
13 |
+
up_kernel_size: int = 12,
|
14 |
+
down_kernel_size: int = 12,
|
15 |
+
):
|
16 |
+
super().__init__()
|
17 |
+
self.up_ratio = up_ratio
|
18 |
+
self.down_ratio = down_ratio
|
19 |
+
self.act = activation
|
20 |
+
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
21 |
+
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
22 |
+
|
23 |
+
# x: [B,C,T]
|
24 |
+
def forward(self, x):
|
25 |
+
x = self.upsample(x)
|
26 |
+
x = self.act(x)
|
27 |
+
x = self.downsample(x)
|
28 |
+
|
29 |
+
return x
|
models/codec/facodec/alias_free_torch/filter.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import math
|
7 |
+
|
8 |
+
if "sinc" in dir(torch):
|
9 |
+
sinc = torch.sinc
|
10 |
+
else:
|
11 |
+
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
12 |
+
# https://adefossez.github.io/julius/julius/core.html
|
13 |
+
def sinc(x: torch.Tensor):
|
14 |
+
"""
|
15 |
+
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
16 |
+
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
17 |
+
"""
|
18 |
+
return torch.where(
|
19 |
+
x == 0,
|
20 |
+
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
21 |
+
torch.sin(math.pi * x) / math.pi / x,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
26 |
+
# https://adefossez.github.io/julius/julius/lowpass.html
|
27 |
+
def kaiser_sinc_filter1d(
|
28 |
+
cutoff, half_width, kernel_size
|
29 |
+
): # return filter [1,1,kernel_size]
|
30 |
+
even = kernel_size % 2 == 0
|
31 |
+
half_size = kernel_size // 2
|
32 |
+
|
33 |
+
# For kaiser window
|
34 |
+
delta_f = 4 * half_width
|
35 |
+
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
36 |
+
if A > 50.0:
|
37 |
+
beta = 0.1102 * (A - 8.7)
|
38 |
+
elif A >= 21.0:
|
39 |
+
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
40 |
+
else:
|
41 |
+
beta = 0.0
|
42 |
+
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
43 |
+
|
44 |
+
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
45 |
+
if even:
|
46 |
+
time = torch.arange(-half_size, half_size) + 0.5
|
47 |
+
else:
|
48 |
+
time = torch.arange(kernel_size) - half_size
|
49 |
+
if cutoff == 0:
|
50 |
+
filter_ = torch.zeros_like(time)
|
51 |
+
else:
|
52 |
+
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
53 |
+
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
54 |
+
# of the constant component in the input signal.
|
55 |
+
filter_ /= filter_.sum()
|
56 |
+
filter = filter_.view(1, 1, kernel_size)
|
57 |
+
|
58 |
+
return filter
|
59 |
+
|
60 |
+
|
61 |
+
class LowPassFilter1d(nn.Module):
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
cutoff=0.5,
|
65 |
+
half_width=0.6,
|
66 |
+
stride: int = 1,
|
67 |
+
padding: bool = True,
|
68 |
+
padding_mode: str = "replicate",
|
69 |
+
kernel_size: int = 12,
|
70 |
+
):
|
71 |
+
# kernel_size should be even number for stylegan3 setup,
|
72 |
+
# in this implementation, odd number is also possible.
|
73 |
+
super().__init__()
|
74 |
+
if cutoff < -0.0:
|
75 |
+
raise ValueError("Minimum cutoff must be larger than zero.")
|
76 |
+
if cutoff > 0.5:
|
77 |
+
raise ValueError("A cutoff above 0.5 does not make sense.")
|
78 |
+
self.kernel_size = kernel_size
|
79 |
+
self.even = kernel_size % 2 == 0
|
80 |
+
self.pad_left = kernel_size // 2 - int(self.even)
|
81 |
+
self.pad_right = kernel_size // 2
|
82 |
+
self.stride = stride
|
83 |
+
self.padding = padding
|
84 |
+
self.padding_mode = padding_mode
|
85 |
+
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
86 |
+
self.register_buffer("filter", filter)
|
87 |
+
|
88 |
+
# input [B, C, T]
|
89 |
+
def forward(self, x):
|
90 |
+
_, C, _ = x.shape
|
91 |
+
|
92 |
+
if self.padding:
|
93 |
+
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
94 |
+
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
95 |
+
|
96 |
+
return out
|
models/codec/facodec/alias_free_torch/resample.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from .filter import LowPassFilter1d
|
6 |
+
from .filter import kaiser_sinc_filter1d
|
7 |
+
|
8 |
+
|
9 |
+
class UpSample1d(nn.Module):
|
10 |
+
def __init__(self, ratio=2, kernel_size=None):
|
11 |
+
super().__init__()
|
12 |
+
self.ratio = ratio
|
13 |
+
self.kernel_size = (
|
14 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
15 |
+
)
|
16 |
+
self.stride = ratio
|
17 |
+
self.pad = self.kernel_size // ratio - 1
|
18 |
+
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
19 |
+
self.pad_right = (
|
20 |
+
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
21 |
+
)
|
22 |
+
filter = kaiser_sinc_filter1d(
|
23 |
+
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
24 |
+
)
|
25 |
+
self.register_buffer("filter", filter)
|
26 |
+
|
27 |
+
# x: [B, C, T]
|
28 |
+
def forward(self, x):
|
29 |
+
_, C, _ = x.shape
|
30 |
+
|
31 |
+
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
32 |
+
x = self.ratio * F.conv_transpose1d(
|
33 |
+
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
34 |
+
)
|
35 |
+
x = x[..., self.pad_left : -self.pad_right]
|
36 |
+
|
37 |
+
return x
|
38 |
+
|
39 |
+
|
40 |
+
class DownSample1d(nn.Module):
|
41 |
+
def __init__(self, ratio=2, kernel_size=None):
|
42 |
+
super().__init__()
|
43 |
+
self.ratio = ratio
|
44 |
+
self.kernel_size = (
|
45 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
46 |
+
)
|
47 |
+
self.lowpass = LowPassFilter1d(
|
48 |
+
cutoff=0.5 / ratio,
|
49 |
+
half_width=0.6 / ratio,
|
50 |
+
stride=ratio,
|
51 |
+
kernel_size=self.kernel_size,
|
52 |
+
)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
xx = self.lowpass(x)
|
56 |
+
|
57 |
+
return xx
|
models/codec/facodec/facodec_dataset.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import random
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
import torchaudio
|
12 |
+
import librosa
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
from torch.nn.utils.rnn import pad_sequence
|
16 |
+
from utils.data_utils import *
|
17 |
+
from models.codec.codec_dataset import CodecDataset
|
18 |
+
|
19 |
+
|
20 |
+
class FAcodecDataset(torch.utils.data.Dataset):
|
21 |
+
def __init__(self, cfg, dataset, is_valid=False):
|
22 |
+
"""
|
23 |
+
Args:
|
24 |
+
cfg: config
|
25 |
+
dataset: dataset name
|
26 |
+
is_valid: whether to use train or valid dataset
|
27 |
+
"""
|
28 |
+
self.data_root_dir = cfg.dataset
|
29 |
+
self.data_list = []
|
30 |
+
# walk through the dataset directory recursively, save all files ends with .wav/.mp3/.opus/.flac/.m4a
|
31 |
+
for root, _, files in os.walk(self.data_root_dir):
|
32 |
+
for file in files:
|
33 |
+
if file.endswith((".wav", ".mp3", ".opus", ".flac", ".m4a")):
|
34 |
+
self.data_list.append(os.path.join(root, file))
|
35 |
+
self.sr = cfg.preprocess_params.sr
|
36 |
+
self.duration_range = cfg.preprocess_params.duration_range
|
37 |
+
self.to_mel = torchaudio.transforms.MelSpectrogram(
|
38 |
+
n_mels=cfg.preprocess_params.spect_params.n_mels,
|
39 |
+
n_fft=cfg.preprocess_params.spect_params.n_fft,
|
40 |
+
win_length=cfg.preprocess_params.spect_params.win_length,
|
41 |
+
hop_length=cfg.preprocess_params.spect_params.hop_length,
|
42 |
+
)
|
43 |
+
self.mean, self.std = -4, 4
|
44 |
+
|
45 |
+
def preprocess(self, wave):
|
46 |
+
wave_tensor = (
|
47 |
+
torch.from_numpy(wave).float() if isinstance(wave, np.ndarray) else wave
|
48 |
+
)
|
49 |
+
mel_tensor = self.to_mel(wave_tensor)
|
50 |
+
mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - self.mean) / self.std
|
51 |
+
return mel_tensor
|
52 |
+
|
53 |
+
def __len__(self):
|
54 |
+
# return len(self.data_list)
|
55 |
+
return len(self.data_list) # return a fixed number for testing
|
56 |
+
|
57 |
+
def __getitem__(self, index):
|
58 |
+
wave, _ = librosa.load(self.data_list[index], sr=self.sr)
|
59 |
+
wave = np.random.randn(self.sr * random.randint(*self.duration_range))
|
60 |
+
wave = wave / np.max(np.abs(wave))
|
61 |
+
mel = self.preprocess(wave).squeeze(0)
|
62 |
+
wave = torch.from_numpy(wave).float()
|
63 |
+
return wave, mel
|
64 |
+
|
65 |
+
|
66 |
+
class FAcodecCollator(object):
|
67 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
68 |
+
|
69 |
+
def __init__(self, cfg):
|
70 |
+
self.cfg = cfg
|
71 |
+
|
72 |
+
def __call__(self, batch):
|
73 |
+
# batch[0] = wave, mel, text, f0, speakerid
|
74 |
+
batch_size = len(batch)
|
75 |
+
|
76 |
+
# sort by mel length
|
77 |
+
lengths = [b[1].shape[1] for b in batch]
|
78 |
+
batch_indexes = np.argsort(lengths)[::-1]
|
79 |
+
batch = [batch[bid] for bid in batch_indexes]
|
80 |
+
|
81 |
+
nmels = batch[0][1].size(0)
|
82 |
+
max_mel_length = max([b[1].shape[1] for b in batch])
|
83 |
+
max_wave_length = max([b[0].size(0) for b in batch])
|
84 |
+
|
85 |
+
mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10
|
86 |
+
waves = torch.zeros((batch_size, max_wave_length)).float()
|
87 |
+
|
88 |
+
mel_lengths = torch.zeros(batch_size).long()
|
89 |
+
wave_lengths = torch.zeros(batch_size).long()
|
90 |
+
|
91 |
+
for bid, (wave, mel) in enumerate(batch):
|
92 |
+
mel_size = mel.size(1)
|
93 |
+
mels[bid, :, :mel_size] = mel
|
94 |
+
waves[bid, : wave.size(0)] = wave
|
95 |
+
mel_lengths[bid] = mel_size
|
96 |
+
wave_lengths[bid] = wave.size(0)
|
97 |
+
|
98 |
+
return waves, mels, wave_lengths, mel_lengths
|
models/codec/facodec/facodec_inference.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import shutil
|
7 |
+
import warnings
|
8 |
+
import argparse
|
9 |
+
import torch
|
10 |
+
import os
|
11 |
+
import yaml
|
12 |
+
|
13 |
+
warnings.simplefilter("ignore")
|
14 |
+
|
15 |
+
from .modules.commons import *
|
16 |
+
import time
|
17 |
+
|
18 |
+
import torchaudio
|
19 |
+
import librosa
|
20 |
+
from collections import OrderedDict
|
21 |
+
|
22 |
+
|
23 |
+
class FAcodecInference(object):
|
24 |
+
def __init__(self, args=None, cfg=None):
|
25 |
+
self.args = args
|
26 |
+
self.cfg = cfg
|
27 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
28 |
+
self.model = self._build_model()
|
29 |
+
self._load_checkpoint()
|
30 |
+
|
31 |
+
def _build_model(self):
|
32 |
+
model = build_model(self.cfg.model_params)
|
33 |
+
_ = [model[key].to(self.device) for key in model]
|
34 |
+
return model
|
35 |
+
|
36 |
+
def _load_checkpoint(self):
|
37 |
+
sd = torch.load(self.args.checkpoint_path, map_location="cpu")
|
38 |
+
sd = sd["net"] if "net" in sd else sd
|
39 |
+
new_params = dict()
|
40 |
+
for key, state_dict in sd.items():
|
41 |
+
new_state_dict = OrderedDict()
|
42 |
+
for k, v in state_dict.items():
|
43 |
+
if k.startswith("module."):
|
44 |
+
k = k[7:]
|
45 |
+
new_state_dict[k] = v
|
46 |
+
new_params[key] = new_state_dict
|
47 |
+
for key in new_params:
|
48 |
+
if key in self.model:
|
49 |
+
self.model[key].load_state_dict(new_params[key])
|
50 |
+
_ = [self.model[key].eval() for key in self.model]
|
51 |
+
|
52 |
+
@torch.no_grad()
|
53 |
+
def inference(self, source, output_dir):
|
54 |
+
source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
|
55 |
+
source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)
|
56 |
+
|
57 |
+
z = self.model.encoder(source_audio[None, ...].to(self.device).float())
|
58 |
+
(
|
59 |
+
z,
|
60 |
+
quantized,
|
61 |
+
commitment_loss,
|
62 |
+
codebook_loss,
|
63 |
+
timbre,
|
64 |
+
codes,
|
65 |
+
) = self.model.quantizer(
|
66 |
+
z,
|
67 |
+
source_audio[None, ...].to(self.device).float(),
|
68 |
+
n_c=self.cfg.model_params.n_c_codebooks,
|
69 |
+
return_codes=True,
|
70 |
+
)
|
71 |
+
|
72 |
+
full_pred_wave = self.model.decoder(z)
|
73 |
+
|
74 |
+
os.makedirs(output_dir, exist_ok=True)
|
75 |
+
source_name = source.split("/")[-1].split(".")[0]
|
76 |
+
torchaudio.save(
|
77 |
+
f"{output_dir}/reconstructed_{source_name}.wav",
|
78 |
+
full_pred_wave[0].cpu(),
|
79 |
+
self.cfg.preprocess_params.sr,
|
80 |
+
)
|
81 |
+
|
82 |
+
print(
|
83 |
+
"Reconstructed audio saved as: ",
|
84 |
+
f"{output_dir}/reconstructed_{source_name}.wav",
|
85 |
+
)
|
86 |
+
|
87 |
+
return quantized, codes
|
88 |
+
|
89 |
+
@torch.no_grad()
|
90 |
+
def voice_conversion(self, source, reference, output_dir):
|
91 |
+
source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
|
92 |
+
source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)
|
93 |
+
|
94 |
+
reference_audio = librosa.load(reference, sr=self.cfg.preprocess_params.sr)[0]
|
95 |
+
reference_audio = (
|
96 |
+
torch.tensor(reference_audio).unsqueeze(0).float().to(self.device)
|
97 |
+
)
|
98 |
+
|
99 |
+
z = self.model.encoder(source_audio[None, ...].to(self.device).float())
|
100 |
+
z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
|
101 |
+
z,
|
102 |
+
source_audio[None, ...].to(self.device).float(),
|
103 |
+
n_c=self.cfg.model_params.n_c_codebooks,
|
104 |
+
)
|
105 |
+
|
106 |
+
z_ref = self.model.encoder(reference_audio[None, ...].to(self.device).float())
|
107 |
+
(
|
108 |
+
z_ref,
|
109 |
+
quantized_ref,
|
110 |
+
commitment_loss_ref,
|
111 |
+
codebook_loss_ref,
|
112 |
+
timbre_ref,
|
113 |
+
) = self.model.quantizer(
|
114 |
+
z_ref,
|
115 |
+
reference_audio[None, ...].to(self.device).float(),
|
116 |
+
n_c=self.cfg.model_params.n_c_codebooks,
|
117 |
+
)
|
118 |
+
|
119 |
+
z_conv = self.model.quantizer.voice_conversion(
|
120 |
+
quantized[0] + quantized[1],
|
121 |
+
reference_audio[None, ...].to(self.device).float(),
|
122 |
+
)
|
123 |
+
full_pred_wave = self.model.decoder(z_conv)
|
124 |
+
|
125 |
+
os.makedirs(output_dir, exist_ok=True)
|
126 |
+
source_name = source.split("/")[-1].split(".")[0]
|
127 |
+
reference_name = reference.split("/")[-1].split(".")[0]
|
128 |
+
torchaudio.save(
|
129 |
+
f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
|
130 |
+
full_pred_wave[0].cpu(),
|
131 |
+
self.cfg.preprocess_params.sr,
|
132 |
+
)
|
133 |
+
|
134 |
+
print(
|
135 |
+
"Voice conversion results saved as: ",
|
136 |
+
f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
|
137 |
+
)
|
models/codec/facodec/facodec_trainer.py
ADDED
@@ -0,0 +1,776 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
import random
|
9 |
+
from pathlib import Path
|
10 |
+
import re
|
11 |
+
import glob
|
12 |
+
|
13 |
+
import accelerate
|
14 |
+
import json
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from accelerate.utils import ProjectConfiguration
|
18 |
+
from torch.utils.data import DataLoader
|
19 |
+
from tqdm import tqdm
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
import torchaudio
|
24 |
+
|
25 |
+
from accelerate.logging import get_logger
|
26 |
+
|
27 |
+
from models.codec.facodec.facodec_dataset import FAcodecDataset, FAcodecCollator
|
28 |
+
from models.codec.codec_sampler import build_samplers
|
29 |
+
from models.codec.codec_trainer import CodecTrainer
|
30 |
+
|
31 |
+
from modules.dac.nn.loss import (
|
32 |
+
MultiScaleSTFTLoss,
|
33 |
+
MelSpectrogramLoss,
|
34 |
+
GANLoss,
|
35 |
+
L1Loss,
|
36 |
+
FocalLoss,
|
37 |
+
)
|
38 |
+
from audiotools import AudioSignal
|
39 |
+
|
40 |
+
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
41 |
+
|
42 |
+
try:
|
43 |
+
import nemo.collections.asr as nemo_asr
|
44 |
+
except ImportError:
|
45 |
+
print(
|
46 |
+
"Unable to import nemo_asr, titanet outputs will be set to random values, you may only run debugging mode. DO NOT USE THIS FOR TRAINING"
|
47 |
+
)
|
48 |
+
nemo_asr = None
|
49 |
+
|
50 |
+
from models.codec.facodec.modules.commons import (
|
51 |
+
build_model,
|
52 |
+
load_checkpoint,
|
53 |
+
load_F0_models,
|
54 |
+
log_norm,
|
55 |
+
)
|
56 |
+
from models.codec.facodec.optimizer import build_optimizer
|
57 |
+
|
58 |
+
|
59 |
+
class FAcodecTrainer(CodecTrainer):
|
60 |
+
def __init__(self, args, cfg):
|
61 |
+
super().__init__()
|
62 |
+
|
63 |
+
self.args = args
|
64 |
+
self.cfg = cfg
|
65 |
+
|
66 |
+
cfg.exp_name = args.exp_name
|
67 |
+
|
68 |
+
# Init accelerator
|
69 |
+
self._init_accelerator()
|
70 |
+
self.accelerator.wait_for_everyone()
|
71 |
+
|
72 |
+
# Init logger
|
73 |
+
with self.accelerator.main_process_first():
|
74 |
+
self.logger = get_logger(args.exp_name, log_level=args.log_level)
|
75 |
+
|
76 |
+
self.logger.info("=" * 56)
|
77 |
+
self.logger.info("||\t\t" + "New training process started." + "\t\t||")
|
78 |
+
self.logger.info("=" * 56)
|
79 |
+
self.logger.info("\n")
|
80 |
+
self.logger.debug(f"Using {args.log_level.upper()} logging level.")
|
81 |
+
self.logger.info(f"Experiment name: {args.exp_name}")
|
82 |
+
self.logger.info(f"Experiment directory: {self.exp_dir}")
|
83 |
+
self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
|
84 |
+
if self.accelerator.is_main_process:
|
85 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
86 |
+
self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
|
87 |
+
|
88 |
+
# Init training status
|
89 |
+
self.batch_count: int = 0
|
90 |
+
self.step: int = 0
|
91 |
+
self.epoch: int = 0
|
92 |
+
|
93 |
+
self.max_epoch = (
|
94 |
+
self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
|
95 |
+
)
|
96 |
+
self.logger.info(
|
97 |
+
"Max epoch: {}".format(
|
98 |
+
self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
|
99 |
+
)
|
100 |
+
)
|
101 |
+
|
102 |
+
# Check potential erorrs
|
103 |
+
if self.accelerator.is_main_process:
|
104 |
+
self._check_basic_configs()
|
105 |
+
self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
|
106 |
+
self.checkpoints_path = [
|
107 |
+
[] for _ in range(len(self.save_checkpoint_stride))
|
108 |
+
]
|
109 |
+
self.run_eval = self.cfg.train.run_eval
|
110 |
+
|
111 |
+
# Set random seed
|
112 |
+
with self.accelerator.main_process_first():
|
113 |
+
start = time.monotonic_ns()
|
114 |
+
self._set_random_seed(self.cfg.train.random_seed)
|
115 |
+
end = time.monotonic_ns()
|
116 |
+
self.logger.debug(
|
117 |
+
f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
|
118 |
+
)
|
119 |
+
self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
|
120 |
+
|
121 |
+
# Build dataloader
|
122 |
+
with self.accelerator.main_process_first():
|
123 |
+
self.logger.info("Building dataset...")
|
124 |
+
start = time.monotonic_ns()
|
125 |
+
self.train_dataloader, self.valid_dataloader = self._build_dataloader()
|
126 |
+
end = time.monotonic_ns()
|
127 |
+
self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
|
128 |
+
|
129 |
+
# Build model
|
130 |
+
with self.accelerator.main_process_first():
|
131 |
+
self.logger.info("Building model...")
|
132 |
+
start = time.monotonic_ns()
|
133 |
+
self.model = self._build_model()
|
134 |
+
end = time.monotonic_ns()
|
135 |
+
for _, model in self.model.items():
|
136 |
+
self.logger.debug(model)
|
137 |
+
self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
|
138 |
+
self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M")
|
139 |
+
|
140 |
+
# Build optimizers and schedulers
|
141 |
+
with self.accelerator.main_process_first():
|
142 |
+
self.logger.info("Building optimizer and scheduler...")
|
143 |
+
start = time.monotonic_ns()
|
144 |
+
self.optimizer = self._build_optimizer()
|
145 |
+
end = time.monotonic_ns()
|
146 |
+
self.logger.info(
|
147 |
+
f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
|
148 |
+
)
|
149 |
+
|
150 |
+
# Build helper models
|
151 |
+
with self.accelerator.main_process_first():
|
152 |
+
self.logger.info("Building helper models...")
|
153 |
+
start = time.monotonic_ns()
|
154 |
+
self._built_helper_model()
|
155 |
+
end = time.monotonic_ns()
|
156 |
+
self.logger.info(
|
157 |
+
f"Building helper models done in {(end - start) / 1e6:.2f}ms"
|
158 |
+
)
|
159 |
+
|
160 |
+
# Accelerator preparing
|
161 |
+
self.logger.info("Initializing accelerate...")
|
162 |
+
start = time.monotonic_ns()
|
163 |
+
for k in self.model:
|
164 |
+
self.model[k] = self.accelerator.prepare(self.model[k])
|
165 |
+
for k, v in self.optimizer.optimizers.items():
|
166 |
+
self.optimizer.optimizers[k] = self.accelerator.prepare(
|
167 |
+
self.optimizer.optimizers[k]
|
168 |
+
)
|
169 |
+
self.optimizer.schedulers[k] = self.accelerator.prepare(
|
170 |
+
self.optimizer.schedulers[k]
|
171 |
+
)
|
172 |
+
end = time.monotonic_ns()
|
173 |
+
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
|
174 |
+
|
175 |
+
# Build criterions
|
176 |
+
with self.accelerator.main_process_first():
|
177 |
+
self.logger.info("Building criterion...")
|
178 |
+
start = time.monotonic_ns()
|
179 |
+
self.criterions = self._build_criterion()
|
180 |
+
end = time.monotonic_ns()
|
181 |
+
self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
|
182 |
+
|
183 |
+
# Resume checkpoints
|
184 |
+
with self.accelerator.main_process_first():
|
185 |
+
self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
|
186 |
+
if args.resume_type:
|
187 |
+
self.logger.info("Resuming from checkpoint...")
|
188 |
+
start = time.monotonic_ns()
|
189 |
+
ckpt_path = Path(args.checkpoint)
|
190 |
+
if self._is_valid_pattern(ckpt_path.parts[-1]):
|
191 |
+
ckpt_path = self._load_model(args.checkpoint, args.resume_type)
|
192 |
+
else:
|
193 |
+
ckpt_path = self._load_model(
|
194 |
+
args.checkpoint, resume_type=args.resume_type
|
195 |
+
)
|
196 |
+
end = time.monotonic_ns()
|
197 |
+
self.logger.info(
|
198 |
+
f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
|
199 |
+
)
|
200 |
+
self.checkpoints_path = json.load(
|
201 |
+
open(os.path.join(ckpt_path, "ckpts.json"), "r")
|
202 |
+
)
|
203 |
+
|
204 |
+
if self.accelerator.is_main_process:
|
205 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
206 |
+
self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
|
207 |
+
|
208 |
+
# Save config
|
209 |
+
self.config_save_path = os.path.join(self.exp_dir, "args.json")
|
210 |
+
|
211 |
+
def _build_dataset(self):
|
212 |
+
return FAcodecDataset, FAcodecCollator
|
213 |
+
|
214 |
+
def _build_criterion(self):
|
215 |
+
criterions = dict()
|
216 |
+
stft_criterion = MultiScaleSTFTLoss()
|
217 |
+
mel_criterion = MelSpectrogramLoss(
|
218 |
+
n_mels=[5, 10, 20, 40, 80, 160, 320],
|
219 |
+
window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
|
220 |
+
mel_fmin=[0, 0, 0, 0, 0, 0, 0],
|
221 |
+
mel_fmax=[None, None, None, None, None, None, None],
|
222 |
+
pow=1.0,
|
223 |
+
mag_weight=0.0,
|
224 |
+
clamp_eps=1e-5,
|
225 |
+
)
|
226 |
+
content_criterion = FocalLoss(gamma=2)
|
227 |
+
l1_criterion = L1Loss()
|
228 |
+
criterions["stft"] = stft_criterion
|
229 |
+
criterions["mel"] = mel_criterion
|
230 |
+
criterions["l1"] = l1_criterion
|
231 |
+
criterions["content"] = content_criterion
|
232 |
+
|
233 |
+
return criterions
|
234 |
+
|
235 |
+
def _build_model(self):
|
236 |
+
model = build_model(self.cfg.model_params)
|
237 |
+
_ = [model[key].to(self.accelerator.device) for key in model]
|
238 |
+
return model
|
239 |
+
|
240 |
+
def _built_helper_model(self):
|
241 |
+
device = self.accelerator.device
|
242 |
+
self.pitch_extractor = load_F0_models(self.cfg.F0_path).to(device)
|
243 |
+
|
244 |
+
# load model and processor
|
245 |
+
self.w2v_processor = Wav2Vec2Processor.from_pretrained(
|
246 |
+
"facebook/wav2vec2-xlsr-53-espeak-cv-ft"
|
247 |
+
)
|
248 |
+
self.w2v_model = Wav2Vec2ForCTC.from_pretrained(
|
249 |
+
"facebook/wav2vec2-xlsr-53-espeak-cv-ft"
|
250 |
+
).to(device)
|
251 |
+
self.w2v_model.eval()
|
252 |
+
|
253 |
+
if nemo_asr is None:
|
254 |
+
self.speaker_model = None
|
255 |
+
else:
|
256 |
+
self.speaker_model = (
|
257 |
+
nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
|
258 |
+
"nvidia/speakerverification_en_titanet_large"
|
259 |
+
)
|
260 |
+
)
|
261 |
+
self.speaker_model = self.speaker_model.to(device)
|
262 |
+
self.speaker_model.eval()
|
263 |
+
|
264 |
+
def _build_optimizer(self):
|
265 |
+
scheduler_params = {
|
266 |
+
"warmup_steps": self.cfg.loss_params.warmup_steps,
|
267 |
+
"base_lr": self.cfg.loss_params.base_lr,
|
268 |
+
}
|
269 |
+
optimizer = build_optimizer(
|
270 |
+
{key: self.model[key] for key in self.model},
|
271 |
+
scheduler_params_dict={key: scheduler_params.copy() for key in self.model},
|
272 |
+
lr=float(scheduler_params["base_lr"]),
|
273 |
+
)
|
274 |
+
|
275 |
+
return optimizer
|
276 |
+
|
277 |
+
def train_loop(self):
|
278 |
+
"""Training process"""
|
279 |
+
self.accelerator.wait_for_everyone()
|
280 |
+
|
281 |
+
# Dump config
|
282 |
+
if self.accelerator.is_main_process:
|
283 |
+
self._dump_cfg(self.config_save_path)
|
284 |
+
_ = [self.model[key].train() for key in self.model]
|
285 |
+
self.optimizer.zero_grad()
|
286 |
+
|
287 |
+
# Sync and start training
|
288 |
+
self.accelerator.wait_for_everyone()
|
289 |
+
while self.epoch < self.max_epoch:
|
290 |
+
self.logger.info("\n")
|
291 |
+
self.logger.info("-" * 32)
|
292 |
+
self.logger.info("Epoch {}: ".format(self.epoch))
|
293 |
+
|
294 |
+
# Train and Validate
|
295 |
+
train_total_loss, train_losses = self._train_epoch()
|
296 |
+
for key, loss in train_losses.items():
|
297 |
+
self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
|
298 |
+
self.accelerator.log(
|
299 |
+
{"Epoch/Train {} Loss".format(key): loss},
|
300 |
+
step=self.epoch,
|
301 |
+
)
|
302 |
+
self.accelerator.log(
|
303 |
+
{
|
304 |
+
"Epoch/Train Total Loss": train_total_loss,
|
305 |
+
},
|
306 |
+
step=self.epoch,
|
307 |
+
)
|
308 |
+
|
309 |
+
# Update scheduler
|
310 |
+
self.accelerator.wait_for_everyone()
|
311 |
+
|
312 |
+
# Check save checkpoint interval
|
313 |
+
run_eval = False
|
314 |
+
if self.accelerator.is_main_process:
|
315 |
+
save_checkpoint = False
|
316 |
+
for i, num in enumerate(self.save_checkpoint_stride):
|
317 |
+
if self.epoch % num == 0:
|
318 |
+
save_checkpoint = True
|
319 |
+
run_eval |= self.run_eval[i]
|
320 |
+
|
321 |
+
# Save checkpoints
|
322 |
+
self.accelerator.wait_for_everyone()
|
323 |
+
if self.accelerator.is_main_process and save_checkpoint:
|
324 |
+
print("Saving..")
|
325 |
+
state = {
|
326 |
+
"net": {key: self.model[key].state_dict() for key in self.model},
|
327 |
+
"optimizer": self.optimizer.state_dict(),
|
328 |
+
"scheduler": self.optimizer.scheduler_state_dict(),
|
329 |
+
"iters": self.step,
|
330 |
+
"epoch": self.epoch,
|
331 |
+
}
|
332 |
+
save_path = os.path.join(
|
333 |
+
self.checkpoint_dir,
|
334 |
+
"FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.iters),
|
335 |
+
)
|
336 |
+
torch.save(state, save_path)
|
337 |
+
json.dump(
|
338 |
+
self.checkpoints_path,
|
339 |
+
open(os.path.join(self.checkpoint_dir, "ckpts.json"), "w"),
|
340 |
+
ensure_ascii=False,
|
341 |
+
indent=4,
|
342 |
+
)
|
343 |
+
|
344 |
+
self.accelerator.wait_for_everyone()
|
345 |
+
|
346 |
+
self.epoch += 1
|
347 |
+
|
348 |
+
# Finish training
|
349 |
+
self.accelerator.wait_for_everyone()
|
350 |
+
if self.accelerator.is_main_process:
|
351 |
+
path = os.path.join(
|
352 |
+
self.checkpoint_dir,
|
353 |
+
"epoch-{:04d}_step-{:07d}".format(
|
354 |
+
self.epoch,
|
355 |
+
self.step,
|
356 |
+
),
|
357 |
+
)
|
358 |
+
print("Saving..")
|
359 |
+
state = {
|
360 |
+
"net": {key: self.model[key].state_dict() for key in self.model},
|
361 |
+
"optimizer": self.optimizer.state_dict(),
|
362 |
+
"scheduler": self.optimizer.scheduler_state_dict(),
|
363 |
+
"iters": self.step,
|
364 |
+
"epoch": self.epoch,
|
365 |
+
}
|
366 |
+
save_path = os.path.join(
|
367 |
+
self.checkpoint_dir,
|
368 |
+
"FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.iters),
|
369 |
+
)
|
370 |
+
torch.save(state, save_path)
|
371 |
+
|
372 |
+
def _train_epoch(self):
|
373 |
+
"""Training epoch. Should return average loss of a batch (sample) over
|
374 |
+
one epoch. See ``train_loop`` for usage.
|
375 |
+
"""
|
376 |
+
_ = [self.model[key].train() for key in self.model]
|
377 |
+
|
378 |
+
epoch_losses: dict = {}
|
379 |
+
epoch_total_loss: int = 0
|
380 |
+
|
381 |
+
for batch in tqdm(
|
382 |
+
self.train_dataloader,
|
383 |
+
desc=f"Training Epoch {self.epoch}",
|
384 |
+
unit="batch",
|
385 |
+
colour="GREEN",
|
386 |
+
leave=False,
|
387 |
+
dynamic_ncols=True,
|
388 |
+
smoothing=0.04,
|
389 |
+
disable=not self.accelerator.is_main_process,
|
390 |
+
):
|
391 |
+
# Get losses
|
392 |
+
total_loss, losses = self._train_step(batch)
|
393 |
+
self.batch_count += 1
|
394 |
+
|
395 |
+
# Log info
|
396 |
+
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
|
397 |
+
self.accelerator.log(
|
398 |
+
{
|
399 |
+
"Step/Learning Rate": (
|
400 |
+
self.optimizer.schedulers["encoder"].get_last_lr()[0]
|
401 |
+
if self.step != 0
|
402 |
+
else 0
|
403 |
+
)
|
404 |
+
},
|
405 |
+
step=self.step,
|
406 |
+
)
|
407 |
+
for key, _ in losses.items():
|
408 |
+
self.accelerator.log(
|
409 |
+
{
|
410 |
+
"Step/Train {} Loss".format(key): losses[key],
|
411 |
+
},
|
412 |
+
step=self.step,
|
413 |
+
)
|
414 |
+
|
415 |
+
if not epoch_losses:
|
416 |
+
epoch_losses = losses
|
417 |
+
else:
|
418 |
+
for key, value in losses.items():
|
419 |
+
epoch_losses[key] += value
|
420 |
+
epoch_total_loss += total_loss
|
421 |
+
self.step += 1
|
422 |
+
|
423 |
+
# Get and log total losses
|
424 |
+
self.accelerator.wait_for_everyone()
|
425 |
+
epoch_total_loss = (
|
426 |
+
epoch_total_loss
|
427 |
+
/ len(self.train_dataloader)
|
428 |
+
* self.cfg.train.gradient_accumulation_step
|
429 |
+
)
|
430 |
+
for key in epoch_losses.keys():
|
431 |
+
epoch_losses[key] = (
|
432 |
+
epoch_losses[key]
|
433 |
+
/ len(self.train_dataloader)
|
434 |
+
* self.cfg.train.gradient_accumulation_step
|
435 |
+
)
|
436 |
+
return epoch_total_loss, epoch_losses
|
437 |
+
|
438 |
+
def _train_step(self, data):
|
439 |
+
"""Training forward step. Should return average loss of a sample over
|
440 |
+
one batch. Provoke ``_forward_step`` is recommended except for special case.
|
441 |
+
See ``_train_epoch`` for usage.
|
442 |
+
"""
|
443 |
+
# Init losses
|
444 |
+
train_losses = {}
|
445 |
+
total_loss = 0
|
446 |
+
|
447 |
+
# Use input feature to get predictions
|
448 |
+
data = [b.to(self.accelerator.device, non_blocking=True) for b in data]
|
449 |
+
waves, mels, wave_lengths, mel_input_length = data
|
450 |
+
|
451 |
+
# extract semantic latent with w2v model
|
452 |
+
waves_16k = torchaudio.functional.resample(waves, 24000, 16000)
|
453 |
+
w2v_input = self.w2v_processor(
|
454 |
+
waves_16k, sampling_rate=16000, return_tensors="pt"
|
455 |
+
).input_values.to(self.accelerator.device)
|
456 |
+
with torch.no_grad():
|
457 |
+
w2v_outputs = self.w2v_model(w2v_input.squeeze(0)).logits
|
458 |
+
predicted_ids = torch.argmax(w2v_outputs, dim=-1)
|
459 |
+
phone_ids = (
|
460 |
+
F.interpolate(
|
461 |
+
predicted_ids.unsqueeze(0).float(), mels.size(-1), mode="nearest"
|
462 |
+
)
|
463 |
+
.long()
|
464 |
+
.squeeze(0)
|
465 |
+
)
|
466 |
+
|
467 |
+
# get clips
|
468 |
+
mel_seg_len = min(
|
469 |
+
[int(mel_input_length.min().item()), self.cfg.train.max_frame_len]
|
470 |
+
)
|
471 |
+
|
472 |
+
gt_mel_seg = []
|
473 |
+
wav_seg = []
|
474 |
+
w2v_seg = []
|
475 |
+
|
476 |
+
for bib in range(len(mel_input_length)):
|
477 |
+
mel_length = int(mel_input_length[bib].item())
|
478 |
+
|
479 |
+
random_start = (
|
480 |
+
np.random.randint(0, mel_length - mel_seg_len)
|
481 |
+
if mel_length != mel_seg_len
|
482 |
+
else 0
|
483 |
+
)
|
484 |
+
gt_mel_seg.append(mels[bib, :, random_start : random_start + mel_seg_len])
|
485 |
+
|
486 |
+
# w2v_seg.append(w2v_latent[bib, :, random_start:random_start + mel_seg_len])
|
487 |
+
w2v_seg.append(phone_ids[bib, random_start : random_start + mel_seg_len])
|
488 |
+
|
489 |
+
y = waves[bib][random_start * 300 : (random_start + mel_seg_len) * 300]
|
490 |
+
|
491 |
+
wav_seg.append(y.to(self.accelerator.device))
|
492 |
+
|
493 |
+
gt_mel_seg = torch.stack(gt_mel_seg).detach()
|
494 |
+
|
495 |
+
wav_seg = torch.stack(wav_seg).float().detach().unsqueeze(1)
|
496 |
+
w2v_seg = torch.stack(w2v_seg).float().detach()
|
497 |
+
|
498 |
+
with torch.no_grad():
|
499 |
+
real_norm = log_norm(gt_mel_seg.unsqueeze(1)).squeeze(1).detach()
|
500 |
+
F0_real, _, _ = self.pitch_extractor(gt_mel_seg.unsqueeze(1))
|
501 |
+
|
502 |
+
# normalize f0
|
503 |
+
# Remove unvoiced frames (replace with -1)
|
504 |
+
gt_glob_f0s = []
|
505 |
+
f0_targets = []
|
506 |
+
for bib in range(len(F0_real)):
|
507 |
+
voiced_indices = F0_real[bib] > 5.0
|
508 |
+
f0_voiced = F0_real[bib][voiced_indices]
|
509 |
+
|
510 |
+
if len(f0_voiced) != 0:
|
511 |
+
# Convert to log scale
|
512 |
+
log_f0 = f0_voiced.log2()
|
513 |
+
|
514 |
+
# Calculate mean and standard deviation
|
515 |
+
mean_f0 = log_f0.mean()
|
516 |
+
std_f0 = log_f0.std()
|
517 |
+
|
518 |
+
# Normalize the F0 sequence
|
519 |
+
normalized_f0 = (log_f0 - mean_f0) / std_f0
|
520 |
+
|
521 |
+
# Create the normalized F0 sequence with unvoiced frames
|
522 |
+
normalized_sequence = torch.zeros_like(F0_real[bib])
|
523 |
+
normalized_sequence[voiced_indices] = normalized_f0
|
524 |
+
normalized_sequence[~voiced_indices] = (
|
525 |
+
-10
|
526 |
+
) # Assign -10 to unvoiced frames
|
527 |
+
|
528 |
+
gt_glob_f0s.append(mean_f0)
|
529 |
+
else:
|
530 |
+
normalized_sequence = torch.zeros_like(F0_real[bib]) - 10.0
|
531 |
+
gt_glob_f0s.append(torch.tensor(0.0).to(self.accelerator.device))
|
532 |
+
|
533 |
+
# f0_targets.append(normalized_sequence[single_side_context // 200:-single_side_context // 200])
|
534 |
+
f0_targets.append(normalized_sequence)
|
535 |
+
f0_targets = torch.stack(f0_targets).to(self.accelerator.device)
|
536 |
+
# fill nan with -10
|
537 |
+
f0_targets[torch.isnan(f0_targets)] = -10.0
|
538 |
+
# fill inf with -10
|
539 |
+
f0_targets[torch.isinf(f0_targets)] = -10.0
|
540 |
+
# if frame_rate not equal to 80, interpolate f0 from frame rate of 80 to target frame rate
|
541 |
+
if self.cfg.preprocess_params.frame_rate != 80:
|
542 |
+
f0_targets = F.interpolate(
|
543 |
+
f0_targets.unsqueeze(1),
|
544 |
+
mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate,
|
545 |
+
mode="nearest",
|
546 |
+
).squeeze(1)
|
547 |
+
w2v_seg = F.interpolate(
|
548 |
+
w2v_seg,
|
549 |
+
mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate,
|
550 |
+
mode="nearest",
|
551 |
+
)
|
552 |
+
|
553 |
+
wav_seg_input = wav_seg
|
554 |
+
wav_seg_target = wav_seg
|
555 |
+
|
556 |
+
z = self.model.encoder(wav_seg_input)
|
557 |
+
z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
|
558 |
+
z, wav_seg_input, n_c=2, full_waves=waves, wave_lens=wave_lengths
|
559 |
+
)
|
560 |
+
preds, rev_preds = self.model.fa_predictors(quantized, timbre)
|
561 |
+
|
562 |
+
pred_wave = self.model.decoder(z)
|
563 |
+
|
564 |
+
len_diff = wav_seg_target.size(-1) - pred_wave.size(-1)
|
565 |
+
if len_diff > 0:
|
566 |
+
wav_seg_target = wav_seg_target[..., len_diff // 2 : -len_diff // 2]
|
567 |
+
|
568 |
+
# discriminator loss
|
569 |
+
d_fake = self.model.discriminator(pred_wave.detach())
|
570 |
+
d_real = self.model.discriminator(wav_seg_target)
|
571 |
+
loss_d = 0
|
572 |
+
for x_fake, x_real in zip(d_fake, d_real):
|
573 |
+
loss_d += torch.mean(x_fake[-1] ** 2)
|
574 |
+
loss_d += torch.mean((1 - x_real[-1]) ** 2)
|
575 |
+
|
576 |
+
self.optimizer.zero_grad()
|
577 |
+
self.accelerator.backward(loss_d)
|
578 |
+
grad_norm_d = torch.nn.utils.clip_grad_norm_(
|
579 |
+
self.model.discriminator.parameters(), 10.0
|
580 |
+
)
|
581 |
+
self.optimizer.step("discriminator")
|
582 |
+
self.optimizer.scheduler(key="discriminator")
|
583 |
+
|
584 |
+
# generator loss
|
585 |
+
signal = AudioSignal(wav_seg_target, sample_rate=24000)
|
586 |
+
recons = AudioSignal(pred_wave, sample_rate=24000)
|
587 |
+
stft_loss = self.criterions["stft"](recons, signal)
|
588 |
+
mel_loss = self.criterions["mel"](recons, signal)
|
589 |
+
waveform_loss = self.criterions["l1"](recons, signal)
|
590 |
+
|
591 |
+
d_fake = self.model.discriminator(pred_wave)
|
592 |
+
d_real = self.model.discriminator(wav_seg_target)
|
593 |
+
|
594 |
+
loss_g = 0
|
595 |
+
for x_fake in d_fake:
|
596 |
+
loss_g += torch.mean((1 - x_fake[-1]) ** 2)
|
597 |
+
|
598 |
+
loss_feature = 0
|
599 |
+
|
600 |
+
for i in range(len(d_fake)):
|
601 |
+
for j in range(len(d_fake[i]) - 1):
|
602 |
+
loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
|
603 |
+
|
604 |
+
pred_f0, pred_uv = preds["f0"], preds["uv"]
|
605 |
+
rev_pred_f0, rev_pred_uv = rev_preds["rev_f0"], rev_preds["rev_uv"]
|
606 |
+
|
607 |
+
common_min_size = min(pred_f0.size(-2), f0_targets.size(-1))
|
608 |
+
f0_targets = f0_targets[..., :common_min_size]
|
609 |
+
real_norm = real_norm[..., :common_min_size]
|
610 |
+
|
611 |
+
f0_loss = F.smooth_l1_loss(
|
612 |
+
f0_targets, pred_f0.squeeze(-1)[..., :common_min_size]
|
613 |
+
)
|
614 |
+
uv_loss = F.smooth_l1_loss(
|
615 |
+
real_norm, pred_uv.squeeze(-1)[..., :common_min_size]
|
616 |
+
)
|
617 |
+
rev_f0_loss = (
|
618 |
+
F.smooth_l1_loss(f0_targets, rev_pred_f0.squeeze(-1)[..., :common_min_size])
|
619 |
+
if rev_pred_f0 is not None
|
620 |
+
else torch.FloatTensor([0]).to(self.accelerator.device)
|
621 |
+
)
|
622 |
+
rev_uv_loss = (
|
623 |
+
F.smooth_l1_loss(real_norm, rev_pred_uv.squeeze(-1)[..., :common_min_size])
|
624 |
+
if rev_pred_uv is not None
|
625 |
+
else torch.FloatTensor([0]).to(self.accelerator.device)
|
626 |
+
)
|
627 |
+
|
628 |
+
tot_f0_loss = f0_loss + rev_f0_loss
|
629 |
+
tot_uv_loss = uv_loss + rev_uv_loss
|
630 |
+
|
631 |
+
pred_content = preds["content"]
|
632 |
+
rev_pred_content = rev_preds["rev_content"]
|
633 |
+
|
634 |
+
target_content_latents = w2v_seg[..., :common_min_size]
|
635 |
+
|
636 |
+
content_loss = self.criterions["content"](
|
637 |
+
pred_content.transpose(1, 2)[..., :common_min_size],
|
638 |
+
target_content_latents.long(),
|
639 |
+
)
|
640 |
+
rev_content_loss = (
|
641 |
+
self.criterions["content"](
|
642 |
+
rev_pred_content.transpose(1, 2)[..., :common_min_size],
|
643 |
+
target_content_latents.long(),
|
644 |
+
)
|
645 |
+
if rev_pred_content is not None
|
646 |
+
else torch.FloatTensor([0]).to(self.accelerator.device)
|
647 |
+
)
|
648 |
+
|
649 |
+
tot_content_loss = content_loss + rev_content_loss
|
650 |
+
|
651 |
+
if self.speaker_model is not None:
|
652 |
+
spk_logits = torch.cat(
|
653 |
+
[
|
654 |
+
self.speaker_model.infer_segment(w16.cpu()[..., :wl])[1]
|
655 |
+
for w16, wl in zip(waves_16k, wave_lengths)
|
656 |
+
],
|
657 |
+
dim=0,
|
658 |
+
)
|
659 |
+
spk_labels = spk_logits.argmax(dim=-1)
|
660 |
+
else:
|
661 |
+
spk_labels = torch.zeros([len(waves_16k)], dtype=torch.long).to(
|
662 |
+
self.accelerator.device
|
663 |
+
)
|
664 |
+
|
665 |
+
spk_pred_logits = preds["timbre"]
|
666 |
+
spk_loss = F.cross_entropy(spk_pred_logits, spk_labels)
|
667 |
+
x_spk_pred_logits = rev_preds["x_timbre"]
|
668 |
+
|
669 |
+
x_spk_loss = (
|
670 |
+
F.cross_entropy(x_spk_pred_logits, spk_labels)
|
671 |
+
if x_spk_pred_logits is not None
|
672 |
+
else torch.FloatTensor([0]).to(self.accelerator.device)
|
673 |
+
)
|
674 |
+
|
675 |
+
tot_spk_loss = spk_loss + x_spk_loss
|
676 |
+
|
677 |
+
loss_gen_all = (
|
678 |
+
mel_loss * 15.0
|
679 |
+
+ loss_feature * 1.0
|
680 |
+
+ loss_g * 1.0
|
681 |
+
+ commitment_loss * 0.25
|
682 |
+
+ codebook_loss * 1.0
|
683 |
+
+ tot_f0_loss * 1.0
|
684 |
+
+ tot_uv_loss * 1.0
|
685 |
+
+ tot_content_loss * 5.0
|
686 |
+
+ tot_spk_loss * 5.0
|
687 |
+
)
|
688 |
+
|
689 |
+
self.optimizer.zero_grad()
|
690 |
+
self.accelerator.backward(loss_gen_all)
|
691 |
+
|
692 |
+
with torch.no_grad():
|
693 |
+
total_loss = loss_gen_all.item()
|
694 |
+
train_losses["stft"] = stft_loss.item()
|
695 |
+
train_losses["mel"] = mel_loss.item()
|
696 |
+
train_losses["l1"] = waveform_loss.item()
|
697 |
+
train_losses["f0"] = f0_loss.item()
|
698 |
+
train_losses["uv"] = uv_loss.item()
|
699 |
+
train_losses["content"] = content_loss.item()
|
700 |
+
train_losses["speaker"] = spk_loss.item()
|
701 |
+
train_losses["rev_f0"] = rev_f0_loss.item()
|
702 |
+
train_losses["rev_uv"] = rev_uv_loss.item()
|
703 |
+
train_losses["rev_content"] = rev_content_loss.item()
|
704 |
+
train_losses["rev_speaker"] = x_spk_loss.item()
|
705 |
+
|
706 |
+
train_losses["feature"] = loss_feature.item()
|
707 |
+
train_losses["generator"] = loss_g.item()
|
708 |
+
train_losses["commitment"] = commitment_loss.item()
|
709 |
+
train_losses["codebook"] = codebook_loss.item()
|
710 |
+
|
711 |
+
# discriminators
|
712 |
+
train_losses["discriminator"] = loss_d.item()
|
713 |
+
|
714 |
+
return total_loss, train_losses
|
715 |
+
|
716 |
+
def _inference(self, eval_wave):
|
717 |
+
"""Inference during training for test audios."""
|
718 |
+
z = self.model.encoder(
|
719 |
+
eval_wave[None, None, ...].to(self.accelerator.device).float()
|
720 |
+
)
|
721 |
+
z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
|
722 |
+
z, eval_wave[None, None, ...], n_c=self.cfg.model_params.n_c_codebooks
|
723 |
+
)
|
724 |
+
full_pred_wave = self.model.decoder(z)
|
725 |
+
return full_pred_wave[0]
|
726 |
+
|
727 |
+
def _load_model(self, checkpoint_path=None, resume_type="resume"):
|
728 |
+
"""Load model from checkpoint. If checkpoint_path is None, it will
|
729 |
+
load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
|
730 |
+
None, it will load the checkpoint specified by checkpoint_path. **Only use this
|
731 |
+
method after** ``accelerator.prepare()``.
|
732 |
+
"""
|
733 |
+
if resume_type == "resume":
|
734 |
+
if checkpoint_path is None:
|
735 |
+
available_checkpoints = glob.glob(
|
736 |
+
os.path.join(self.checkpoint_dir, "FAcodc_epoch_*_step_*.pth")
|
737 |
+
)
|
738 |
+
# find the checkpoint that has the highest step number
|
739 |
+
latest_checkpoint = max(
|
740 |
+
available_checkpoints,
|
741 |
+
key=lambda x: int(x.split("_")[-1].split(".")[0]),
|
742 |
+
)
|
743 |
+
earliest_checkpoint = min(
|
744 |
+
available_checkpoints,
|
745 |
+
key=lambda x: int(x.split("_")[-1].split(".")[0]),
|
746 |
+
)
|
747 |
+
# delete the earliest checkpoint
|
748 |
+
if (
|
749 |
+
earliest_checkpoint != latest_checkpoint
|
750 |
+
and self.accelerator.is_main_process
|
751 |
+
and len(available_checkpoints) > 4
|
752 |
+
):
|
753 |
+
os.remove(earliest_checkpoint)
|
754 |
+
print(f"Removed {earliest_checkpoint}")
|
755 |
+
else:
|
756 |
+
latest_checkpoint = checkpoint_path
|
757 |
+
|
758 |
+
self.model, self.optimizer, self.epoch, self.step = load_checkpoint(
|
759 |
+
self.model,
|
760 |
+
self.optimizer,
|
761 |
+
latest_checkpoint,
|
762 |
+
load_only_params=False,
|
763 |
+
ignore_modules=[],
|
764 |
+
is_distributed=self.accelerator.num_processes > 1,
|
765 |
+
)
|
766 |
+
|
767 |
+
else:
|
768 |
+
raise ValueError("Invalid resume type")
|
769 |
+
return checkpoint_path
|
770 |
+
|
771 |
+
def _count_parameters(self):
|
772 |
+
total_num = sum(
|
773 |
+
sum(p.numel() for p in self.model[key].parameters()) for key in self.model
|
774 |
+
)
|
775 |
+
# trainable_num = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
776 |
+
return total_num
|
models/codec/facodec/modules/JDC/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
models/codec/facodec/modules/JDC/bst.t7
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
|
3 |
+
size 21029926
|
models/codec/facodec/modules/JDC/model.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# This code is borrowed from https://github.com/yl4579/PitchExtractor/blob/main/model.py
|
7 |
+
|
8 |
+
"""
|
9 |
+
Implementation of model from:
|
10 |
+
Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
|
11 |
+
Convolutional Recurrent Neural Networks" (2019)
|
12 |
+
Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
|
13 |
+
"""
|
14 |
+
import torch
|
15 |
+
from torch import nn
|
16 |
+
|
17 |
+
|
18 |
+
class JDCNet(nn.Module):
|
19 |
+
"""
|
20 |
+
Joint Detection and Classification Network model for singing voice melody.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
|
24 |
+
super().__init__()
|
25 |
+
self.num_class = num_class
|
26 |
+
|
27 |
+
# input = (b, 1, 31, 513), b = batch size
|
28 |
+
self.conv_block = nn.Sequential(
|
29 |
+
nn.Conv2d(
|
30 |
+
in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False
|
31 |
+
), # out: (b, 64, 31, 513)
|
32 |
+
nn.BatchNorm2d(num_features=64),
|
33 |
+
nn.LeakyReLU(leaky_relu_slope, inplace=True),
|
34 |
+
nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513)
|
35 |
+
)
|
36 |
+
|
37 |
+
# res blocks
|
38 |
+
self.res_block1 = ResBlock(
|
39 |
+
in_channels=64, out_channels=128
|
40 |
+
) # (b, 128, 31, 128)
|
41 |
+
self.res_block2 = ResBlock(
|
42 |
+
in_channels=128, out_channels=192
|
43 |
+
) # (b, 192, 31, 32)
|
44 |
+
self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8)
|
45 |
+
|
46 |
+
# pool block
|
47 |
+
self.pool_block = nn.Sequential(
|
48 |
+
nn.BatchNorm2d(num_features=256),
|
49 |
+
nn.LeakyReLU(leaky_relu_slope, inplace=True),
|
50 |
+
nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2)
|
51 |
+
nn.Dropout(p=0.2),
|
52 |
+
)
|
53 |
+
|
54 |
+
# maxpool layers (for auxiliary network inputs)
|
55 |
+
# in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
|
56 |
+
self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
|
57 |
+
# in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
|
58 |
+
self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
|
59 |
+
# in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
|
60 |
+
self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
|
61 |
+
|
62 |
+
# in = (b, 640, 31, 2), out = (b, 256, 31, 2)
|
63 |
+
self.detector_conv = nn.Sequential(
|
64 |
+
nn.Conv2d(640, 256, 1, bias=False),
|
65 |
+
nn.BatchNorm2d(256),
|
66 |
+
nn.LeakyReLU(leaky_relu_slope, inplace=True),
|
67 |
+
nn.Dropout(p=0.2),
|
68 |
+
)
|
69 |
+
|
70 |
+
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
|
71 |
+
self.bilstm_classifier = nn.LSTM(
|
72 |
+
input_size=512, hidden_size=256, batch_first=True, bidirectional=True
|
73 |
+
) # (b, 31, 512)
|
74 |
+
|
75 |
+
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
|
76 |
+
self.bilstm_detector = nn.LSTM(
|
77 |
+
input_size=512, hidden_size=256, batch_first=True, bidirectional=True
|
78 |
+
) # (b, 31, 512)
|
79 |
+
|
80 |
+
# input: (b * 31, 512)
|
81 |
+
self.classifier = nn.Linear(
|
82 |
+
in_features=512, out_features=self.num_class
|
83 |
+
) # (b * 31, num_class)
|
84 |
+
|
85 |
+
# input: (b * 31, 512)
|
86 |
+
self.detector = nn.Linear(
|
87 |
+
in_features=512, out_features=2
|
88 |
+
) # (b * 31, 2) - binary classifier
|
89 |
+
|
90 |
+
# initialize weights
|
91 |
+
self.apply(self.init_weights)
|
92 |
+
|
93 |
+
def get_feature_GAN(self, x):
|
94 |
+
seq_len = x.shape[-2]
|
95 |
+
x = x.float().transpose(-1, -2)
|
96 |
+
|
97 |
+
convblock_out = self.conv_block(x)
|
98 |
+
|
99 |
+
resblock1_out = self.res_block1(convblock_out)
|
100 |
+
resblock2_out = self.res_block2(resblock1_out)
|
101 |
+
resblock3_out = self.res_block3(resblock2_out)
|
102 |
+
poolblock_out = self.pool_block[0](resblock3_out)
|
103 |
+
poolblock_out = self.pool_block[1](poolblock_out)
|
104 |
+
|
105 |
+
return poolblock_out.transpose(-1, -2)
|
106 |
+
|
107 |
+
def get_feature(self, x):
|
108 |
+
seq_len = x.shape[-2]
|
109 |
+
x = x.float().transpose(-1, -2)
|
110 |
+
|
111 |
+
convblock_out = self.conv_block(x)
|
112 |
+
|
113 |
+
resblock1_out = self.res_block1(convblock_out)
|
114 |
+
resblock2_out = self.res_block2(resblock1_out)
|
115 |
+
resblock3_out = self.res_block3(resblock2_out)
|
116 |
+
poolblock_out = self.pool_block[0](resblock3_out)
|
117 |
+
poolblock_out = self.pool_block[1](poolblock_out)
|
118 |
+
|
119 |
+
return self.pool_block[2](poolblock_out)
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
"""
|
123 |
+
Returns:
|
124 |
+
classification_prediction, detection_prediction
|
125 |
+
sizes: (b, 31, 722), (b, 31, 2)
|
126 |
+
"""
|
127 |
+
###############################
|
128 |
+
# forward pass for classifier #
|
129 |
+
###############################
|
130 |
+
seq_len = x.shape[-1]
|
131 |
+
x = x.float().transpose(-1, -2)
|
132 |
+
|
133 |
+
convblock_out = self.conv_block(x)
|
134 |
+
|
135 |
+
resblock1_out = self.res_block1(convblock_out)
|
136 |
+
resblock2_out = self.res_block2(resblock1_out)
|
137 |
+
resblock3_out = self.res_block3(resblock2_out)
|
138 |
+
|
139 |
+
poolblock_out = self.pool_block[0](resblock3_out)
|
140 |
+
poolblock_out = self.pool_block[1](poolblock_out)
|
141 |
+
GAN_feature = poolblock_out.transpose(-1, -2)
|
142 |
+
poolblock_out = self.pool_block[2](poolblock_out)
|
143 |
+
|
144 |
+
# (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
|
145 |
+
classifier_out = (
|
146 |
+
poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
|
147 |
+
)
|
148 |
+
classifier_out, _ = self.bilstm_classifier(
|
149 |
+
classifier_out
|
150 |
+
) # ignore the hidden states
|
151 |
+
|
152 |
+
classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512)
|
153 |
+
classifier_out = self.classifier(classifier_out)
|
154 |
+
classifier_out = classifier_out.view(
|
155 |
+
(-1, seq_len, self.num_class)
|
156 |
+
) # (b, 31, num_class)
|
157 |
+
|
158 |
+
# sizes: (b, 31, 722), (b, 31, 2)
|
159 |
+
# classifier output consists of predicted pitch classes per frame
|
160 |
+
# detector output consists of: (isvoice, notvoice) estimates per frame
|
161 |
+
return torch.abs(classifier_out.squeeze(-1)), GAN_feature, poolblock_out
|
162 |
+
|
163 |
+
@staticmethod
|
164 |
+
def init_weights(m):
|
165 |
+
if isinstance(m, nn.Linear):
|
166 |
+
nn.init.kaiming_uniform_(m.weight)
|
167 |
+
if m.bias is not None:
|
168 |
+
nn.init.constant_(m.bias, 0)
|
169 |
+
elif isinstance(m, nn.Conv2d):
|
170 |
+
nn.init.xavier_normal_(m.weight)
|
171 |
+
elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
|
172 |
+
for p in m.parameters():
|
173 |
+
if p.data is None:
|
174 |
+
continue
|
175 |
+
|
176 |
+
if len(p.shape) >= 2:
|
177 |
+
nn.init.orthogonal_(p.data)
|
178 |
+
else:
|
179 |
+
nn.init.normal_(p.data)
|
180 |
+
|
181 |
+
|
182 |
+
class ResBlock(nn.Module):
|
183 |
+
def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
|
184 |
+
super().__init__()
|
185 |
+
self.downsample = in_channels != out_channels
|
186 |
+
|
187 |
+
# BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
|
188 |
+
self.pre_conv = nn.Sequential(
|
189 |
+
nn.BatchNorm2d(num_features=in_channels),
|
190 |
+
nn.LeakyReLU(leaky_relu_slope, inplace=True),
|
191 |
+
nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only
|
192 |
+
)
|
193 |
+
|
194 |
+
# conv layers
|
195 |
+
self.conv = nn.Sequential(
|
196 |
+
nn.Conv2d(
|
197 |
+
in_channels=in_channels,
|
198 |
+
out_channels=out_channels,
|
199 |
+
kernel_size=3,
|
200 |
+
padding=1,
|
201 |
+
bias=False,
|
202 |
+
),
|
203 |
+
nn.BatchNorm2d(out_channels),
|
204 |
+
nn.LeakyReLU(leaky_relu_slope, inplace=True),
|
205 |
+
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
|
206 |
+
)
|
207 |
+
|
208 |
+
# 1 x 1 convolution layer to match the feature dimensions
|
209 |
+
self.conv1by1 = None
|
210 |
+
if self.downsample:
|
211 |
+
self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
|
212 |
+
|
213 |
+
def forward(self, x):
|
214 |
+
x = self.pre_conv(x)
|
215 |
+
if self.downsample:
|
216 |
+
x = self.conv(x) + self.conv1by1(x)
|
217 |
+
else:
|
218 |
+
x = self.conv(x) + x
|
219 |
+
return x
|
models/codec/facodec/modules/attentions.py
ADDED
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/attentions.py
|
7 |
+
|
8 |
+
import copy
|
9 |
+
import math
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
from . import commons
|
16 |
+
|
17 |
+
|
18 |
+
class LayerNorm(nn.Module):
|
19 |
+
def __init__(self, channels, eps=1e-5):
|
20 |
+
super().__init__()
|
21 |
+
self.channels = channels
|
22 |
+
self.eps = eps
|
23 |
+
|
24 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
25 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
x = x.transpose(1, -1)
|
29 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
30 |
+
return x.transpose(1, -1)
|
31 |
+
|
32 |
+
|
33 |
+
class Encoder(nn.Module):
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
hidden_channels,
|
37 |
+
filter_channels,
|
38 |
+
n_heads,
|
39 |
+
n_layers,
|
40 |
+
kernel_size=1,
|
41 |
+
p_dropout=0.0,
|
42 |
+
window_size=4,
|
43 |
+
**kwargs
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
self.hidden_channels = hidden_channels
|
47 |
+
self.filter_channels = filter_channels
|
48 |
+
self.n_heads = n_heads
|
49 |
+
self.n_layers = n_layers
|
50 |
+
self.kernel_size = kernel_size
|
51 |
+
self.p_dropout = p_dropout
|
52 |
+
self.window_size = window_size
|
53 |
+
|
54 |
+
self.drop = nn.Dropout(p_dropout)
|
55 |
+
self.attn_layers = nn.ModuleList()
|
56 |
+
self.norm_layers_1 = nn.ModuleList()
|
57 |
+
self.ffn_layers = nn.ModuleList()
|
58 |
+
self.norm_layers_2 = nn.ModuleList()
|
59 |
+
for i in range(self.n_layers):
|
60 |
+
self.attn_layers.append(
|
61 |
+
MultiHeadAttention(
|
62 |
+
hidden_channels,
|
63 |
+
hidden_channels,
|
64 |
+
n_heads,
|
65 |
+
p_dropout=p_dropout,
|
66 |
+
window_size=window_size,
|
67 |
+
)
|
68 |
+
)
|
69 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
70 |
+
self.ffn_layers.append(
|
71 |
+
FFN(
|
72 |
+
hidden_channels,
|
73 |
+
hidden_channels,
|
74 |
+
filter_channels,
|
75 |
+
kernel_size,
|
76 |
+
p_dropout=p_dropout,
|
77 |
+
)
|
78 |
+
)
|
79 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
80 |
+
|
81 |
+
def forward(self, x, x_mask):
|
82 |
+
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
83 |
+
x = x * x_mask
|
84 |
+
for i in range(self.n_layers):
|
85 |
+
y = self.attn_layers[i](x, x, attn_mask)
|
86 |
+
y = self.drop(y)
|
87 |
+
x = self.norm_layers_1[i](x + y)
|
88 |
+
|
89 |
+
y = self.ffn_layers[i](x, x_mask)
|
90 |
+
y = self.drop(y)
|
91 |
+
x = self.norm_layers_2[i](x + y)
|
92 |
+
x = x * x_mask
|
93 |
+
return x
|
94 |
+
|
95 |
+
|
96 |
+
class Decoder(nn.Module):
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
hidden_channels,
|
100 |
+
filter_channels,
|
101 |
+
n_heads,
|
102 |
+
n_layers,
|
103 |
+
kernel_size=1,
|
104 |
+
p_dropout=0.0,
|
105 |
+
proximal_bias=False,
|
106 |
+
proximal_init=True,
|
107 |
+
**kwargs
|
108 |
+
):
|
109 |
+
super().__init__()
|
110 |
+
self.hidden_channels = hidden_channels
|
111 |
+
self.filter_channels = filter_channels
|
112 |
+
self.n_heads = n_heads
|
113 |
+
self.n_layers = n_layers
|
114 |
+
self.kernel_size = kernel_size
|
115 |
+
self.p_dropout = p_dropout
|
116 |
+
self.proximal_bias = proximal_bias
|
117 |
+
self.proximal_init = proximal_init
|
118 |
+
|
119 |
+
self.drop = nn.Dropout(p_dropout)
|
120 |
+
self.self_attn_layers = nn.ModuleList()
|
121 |
+
self.norm_layers_0 = nn.ModuleList()
|
122 |
+
self.encdec_attn_layers = nn.ModuleList()
|
123 |
+
self.norm_layers_1 = nn.ModuleList()
|
124 |
+
self.ffn_layers = nn.ModuleList()
|
125 |
+
self.norm_layers_2 = nn.ModuleList()
|
126 |
+
for i in range(self.n_layers):
|
127 |
+
self.self_attn_layers.append(
|
128 |
+
MultiHeadAttention(
|
129 |
+
hidden_channels,
|
130 |
+
hidden_channels,
|
131 |
+
n_heads,
|
132 |
+
p_dropout=p_dropout,
|
133 |
+
proximal_bias=proximal_bias,
|
134 |
+
proximal_init=proximal_init,
|
135 |
+
)
|
136 |
+
)
|
137 |
+
self.norm_layers_0.append(LayerNorm(hidden_channels))
|
138 |
+
self.encdec_attn_layers.append(
|
139 |
+
MultiHeadAttention(
|
140 |
+
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
|
141 |
+
)
|
142 |
+
)
|
143 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
144 |
+
self.ffn_layers.append(
|
145 |
+
FFN(
|
146 |
+
hidden_channels,
|
147 |
+
hidden_channels,
|
148 |
+
filter_channels,
|
149 |
+
kernel_size,
|
150 |
+
p_dropout=p_dropout,
|
151 |
+
causal=True,
|
152 |
+
)
|
153 |
+
)
|
154 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
155 |
+
|
156 |
+
def forward(self, x, x_mask, h, h_mask):
|
157 |
+
"""
|
158 |
+
x: decoder input
|
159 |
+
h: encoder output
|
160 |
+
"""
|
161 |
+
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
|
162 |
+
device=x.device, dtype=x.dtype
|
163 |
+
)
|
164 |
+
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
165 |
+
x = x * x_mask
|
166 |
+
for i in range(self.n_layers):
|
167 |
+
y = self.self_attn_layers[i](x, x, self_attn_mask)
|
168 |
+
y = self.drop(y)
|
169 |
+
x = self.norm_layers_0[i](x + y)
|
170 |
+
|
171 |
+
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
|
172 |
+
y = self.drop(y)
|
173 |
+
x = self.norm_layers_1[i](x + y)
|
174 |
+
|
175 |
+
y = self.ffn_layers[i](x, x_mask)
|
176 |
+
y = self.drop(y)
|
177 |
+
x = self.norm_layers_2[i](x + y)
|
178 |
+
x = x * x_mask
|
179 |
+
return x
|
180 |
+
|
181 |
+
|
182 |
+
class MultiHeadAttention(nn.Module):
|
183 |
+
def __init__(
|
184 |
+
self,
|
185 |
+
channels,
|
186 |
+
out_channels,
|
187 |
+
n_heads,
|
188 |
+
p_dropout=0.0,
|
189 |
+
window_size=None,
|
190 |
+
heads_share=True,
|
191 |
+
block_length=None,
|
192 |
+
proximal_bias=False,
|
193 |
+
proximal_init=False,
|
194 |
+
):
|
195 |
+
super().__init__()
|
196 |
+
assert channels % n_heads == 0
|
197 |
+
|
198 |
+
self.channels = channels
|
199 |
+
self.out_channels = out_channels
|
200 |
+
self.n_heads = n_heads
|
201 |
+
self.p_dropout = p_dropout
|
202 |
+
self.window_size = window_size
|
203 |
+
self.heads_share = heads_share
|
204 |
+
self.block_length = block_length
|
205 |
+
self.proximal_bias = proximal_bias
|
206 |
+
self.proximal_init = proximal_init
|
207 |
+
self.attn = None
|
208 |
+
|
209 |
+
self.k_channels = channels // n_heads
|
210 |
+
self.conv_q = nn.Conv1d(channels, channels, 1)
|
211 |
+
self.conv_k = nn.Conv1d(channels, channels, 1)
|
212 |
+
self.conv_v = nn.Conv1d(channels, channels, 1)
|
213 |
+
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
214 |
+
self.drop = nn.Dropout(p_dropout)
|
215 |
+
|
216 |
+
if window_size is not None:
|
217 |
+
n_heads_rel = 1 if heads_share else n_heads
|
218 |
+
rel_stddev = self.k_channels**-0.5
|
219 |
+
self.emb_rel_k = nn.Parameter(
|
220 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
221 |
+
* rel_stddev
|
222 |
+
)
|
223 |
+
self.emb_rel_v = nn.Parameter(
|
224 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
225 |
+
* rel_stddev
|
226 |
+
)
|
227 |
+
|
228 |
+
nn.init.xavier_uniform_(self.conv_q.weight)
|
229 |
+
nn.init.xavier_uniform_(self.conv_k.weight)
|
230 |
+
nn.init.xavier_uniform_(self.conv_v.weight)
|
231 |
+
if proximal_init:
|
232 |
+
with torch.no_grad():
|
233 |
+
self.conv_k.weight.copy_(self.conv_q.weight)
|
234 |
+
self.conv_k.bias.copy_(self.conv_q.bias)
|
235 |
+
|
236 |
+
def forward(self, x, c, attn_mask=None):
|
237 |
+
q = self.conv_q(x)
|
238 |
+
k = self.conv_k(c)
|
239 |
+
v = self.conv_v(c)
|
240 |
+
|
241 |
+
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
242 |
+
|
243 |
+
x = self.conv_o(x)
|
244 |
+
return x
|
245 |
+
|
246 |
+
def attention(self, query, key, value, mask=None):
|
247 |
+
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
248 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
249 |
+
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
250 |
+
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
251 |
+
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
252 |
+
|
253 |
+
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
254 |
+
if self.window_size is not None:
|
255 |
+
assert (
|
256 |
+
t_s == t_t
|
257 |
+
), "Relative attention is only available for self-attention."
|
258 |
+
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
259 |
+
rel_logits = self._matmul_with_relative_keys(
|
260 |
+
query / math.sqrt(self.k_channels), key_relative_embeddings
|
261 |
+
)
|
262 |
+
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
263 |
+
scores = scores + scores_local
|
264 |
+
if self.proximal_bias:
|
265 |
+
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
266 |
+
scores = scores + self._attention_bias_proximal(t_s).to(
|
267 |
+
device=scores.device, dtype=scores.dtype
|
268 |
+
)
|
269 |
+
if mask is not None:
|
270 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
271 |
+
if self.block_length is not None:
|
272 |
+
assert (
|
273 |
+
t_s == t_t
|
274 |
+
), "Local attention is only available for self-attention."
|
275 |
+
block_mask = (
|
276 |
+
torch.ones_like(scores)
|
277 |
+
.triu(-self.block_length)
|
278 |
+
.tril(self.block_length)
|
279 |
+
)
|
280 |
+
scores = scores.masked_fill(block_mask == 0, -1e4)
|
281 |
+
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
282 |
+
p_attn = self.drop(p_attn)
|
283 |
+
output = torch.matmul(p_attn, value)
|
284 |
+
if self.window_size is not None:
|
285 |
+
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
286 |
+
value_relative_embeddings = self._get_relative_embeddings(
|
287 |
+
self.emb_rel_v, t_s
|
288 |
+
)
|
289 |
+
output = output + self._matmul_with_relative_values(
|
290 |
+
relative_weights, value_relative_embeddings
|
291 |
+
)
|
292 |
+
output = (
|
293 |
+
output.transpose(2, 3).contiguous().view(b, d, t_t)
|
294 |
+
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
295 |
+
return output, p_attn
|
296 |
+
|
297 |
+
def _matmul_with_relative_values(self, x, y):
|
298 |
+
"""
|
299 |
+
x: [b, h, l, m]
|
300 |
+
y: [h or 1, m, d]
|
301 |
+
ret: [b, h, l, d]
|
302 |
+
"""
|
303 |
+
ret = torch.matmul(x, y.unsqueeze(0))
|
304 |
+
return ret
|
305 |
+
|
306 |
+
def _matmul_with_relative_keys(self, x, y):
|
307 |
+
"""
|
308 |
+
x: [b, h, l, d]
|
309 |
+
y: [h or 1, m, d]
|
310 |
+
ret: [b, h, l, m]
|
311 |
+
"""
|
312 |
+
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
313 |
+
return ret
|
314 |
+
|
315 |
+
def _get_relative_embeddings(self, relative_embeddings, length):
|
316 |
+
max_relative_position = 2 * self.window_size + 1
|
317 |
+
# Pad first before slice to avoid using cond ops.
|
318 |
+
pad_length = max(length - (self.window_size + 1), 0)
|
319 |
+
slice_start_position = max((self.window_size + 1) - length, 0)
|
320 |
+
slice_end_position = slice_start_position + 2 * length - 1
|
321 |
+
if pad_length > 0:
|
322 |
+
padded_relative_embeddings = F.pad(
|
323 |
+
relative_embeddings,
|
324 |
+
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
325 |
+
)
|
326 |
+
else:
|
327 |
+
padded_relative_embeddings = relative_embeddings
|
328 |
+
used_relative_embeddings = padded_relative_embeddings[
|
329 |
+
:, slice_start_position:slice_end_position
|
330 |
+
]
|
331 |
+
return used_relative_embeddings
|
332 |
+
|
333 |
+
def _relative_position_to_absolute_position(self, x):
|
334 |
+
"""
|
335 |
+
x: [b, h, l, 2*l-1]
|
336 |
+
ret: [b, h, l, l]
|
337 |
+
"""
|
338 |
+
batch, heads, length, _ = x.size()
|
339 |
+
# Concat columns of pad to shift from relative to absolute indexing.
|
340 |
+
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
|
341 |
+
|
342 |
+
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
343 |
+
x_flat = x.view([batch, heads, length * 2 * length])
|
344 |
+
x_flat = F.pad(
|
345 |
+
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
346 |
+
)
|
347 |
+
|
348 |
+
# Reshape and slice out the padded elements.
|
349 |
+
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
|
350 |
+
:, :, :length, length - 1 :
|
351 |
+
]
|
352 |
+
return x_final
|
353 |
+
|
354 |
+
def _absolute_position_to_relative_position(self, x):
|
355 |
+
"""
|
356 |
+
x: [b, h, l, l]
|
357 |
+
ret: [b, h, l, 2*l-1]
|
358 |
+
"""
|
359 |
+
batch, heads, length, _ = x.size()
|
360 |
+
# padd along column
|
361 |
+
x = F.pad(
|
362 |
+
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
363 |
+
)
|
364 |
+
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
365 |
+
# add 0's in the beginning that will skew the elements after reshape
|
366 |
+
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
367 |
+
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
368 |
+
return x_final
|
369 |
+
|
370 |
+
def _attention_bias_proximal(self, length):
|
371 |
+
"""Bias for self-attention to encourage attention to close positions.
|
372 |
+
Args:
|
373 |
+
length: an integer scalar.
|
374 |
+
Returns:
|
375 |
+
a Tensor with shape [1, 1, length, length]
|
376 |
+
"""
|
377 |
+
r = torch.arange(length, dtype=torch.float32)
|
378 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
379 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
380 |
+
|
381 |
+
|
382 |
+
class FFN(nn.Module):
|
383 |
+
def __init__(
|
384 |
+
self,
|
385 |
+
in_channels,
|
386 |
+
out_channels,
|
387 |
+
filter_channels,
|
388 |
+
kernel_size,
|
389 |
+
p_dropout=0.0,
|
390 |
+
activation=None,
|
391 |
+
causal=False,
|
392 |
+
):
|
393 |
+
super().__init__()
|
394 |
+
self.in_channels = in_channels
|
395 |
+
self.out_channels = out_channels
|
396 |
+
self.filter_channels = filter_channels
|
397 |
+
self.kernel_size = kernel_size
|
398 |
+
self.p_dropout = p_dropout
|
399 |
+
self.activation = activation
|
400 |
+
self.causal = causal
|
401 |
+
|
402 |
+
if causal:
|
403 |
+
self.padding = self._causal_padding
|
404 |
+
else:
|
405 |
+
self.padding = self._same_padding
|
406 |
+
|
407 |
+
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
408 |
+
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
409 |
+
self.drop = nn.Dropout(p_dropout)
|
410 |
+
|
411 |
+
def forward(self, x, x_mask):
|
412 |
+
x = self.conv_1(self.padding(x * x_mask))
|
413 |
+
if self.activation == "gelu":
|
414 |
+
x = x * torch.sigmoid(1.702 * x)
|
415 |
+
else:
|
416 |
+
x = torch.relu(x)
|
417 |
+
x = self.drop(x)
|
418 |
+
x = self.conv_2(self.padding(x * x_mask))
|
419 |
+
return x * x_mask
|
420 |
+
|
421 |
+
def _causal_padding(self, x):
|
422 |
+
if self.kernel_size == 1:
|
423 |
+
return x
|
424 |
+
pad_l = self.kernel_size - 1
|
425 |
+
pad_r = 0
|
426 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
427 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
428 |
+
return x
|
429 |
+
|
430 |
+
def _same_padding(self, x):
|
431 |
+
if self.kernel_size == 1:
|
432 |
+
return x
|
433 |
+
pad_l = (self.kernel_size - 1) // 2
|
434 |
+
pad_r = self.kernel_size // 2
|
435 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
436 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
437 |
+
return x
|
models/codec/facodec/modules/commons.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
|
7 |
+
import math
|
8 |
+
import os.path
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
from munch import Munch
|
15 |
+
import json
|
16 |
+
|
17 |
+
|
18 |
+
class AttrDict(dict):
|
19 |
+
def __init__(self, *args, **kwargs):
|
20 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
21 |
+
self.__dict__ = self
|
22 |
+
|
23 |
+
|
24 |
+
def init_weights(m, mean=0.0, std=0.01):
|
25 |
+
classname = m.__class__.__name__
|
26 |
+
if classname.find("Conv") != -1:
|
27 |
+
m.weight.data.normal_(mean, std)
|
28 |
+
|
29 |
+
|
30 |
+
def get_padding(kernel_size, dilation=1):
|
31 |
+
return int((kernel_size * dilation - dilation) / 2)
|
32 |
+
|
33 |
+
|
34 |
+
def convert_pad_shape(pad_shape):
|
35 |
+
l = pad_shape[::-1]
|
36 |
+
pad_shape = [item for sublist in l for item in sublist]
|
37 |
+
return pad_shape
|
38 |
+
|
39 |
+
|
40 |
+
def intersperse(lst, item):
|
41 |
+
result = [item] * (len(lst) * 2 + 1)
|
42 |
+
result[1::2] = lst
|
43 |
+
return result
|
44 |
+
|
45 |
+
|
46 |
+
def kl_divergence(m_p, logs_p, m_q, logs_q):
|
47 |
+
"""KL(P||Q)"""
|
48 |
+
kl = (logs_q - logs_p) - 0.5
|
49 |
+
kl += (
|
50 |
+
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
|
51 |
+
)
|
52 |
+
return kl
|
53 |
+
|
54 |
+
|
55 |
+
def rand_gumbel(shape):
|
56 |
+
"""Sample from the Gumbel distribution, protect from overflows."""
|
57 |
+
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
|
58 |
+
return -torch.log(-torch.log(uniform_samples))
|
59 |
+
|
60 |
+
|
61 |
+
def rand_gumbel_like(x):
|
62 |
+
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
|
63 |
+
return g
|
64 |
+
|
65 |
+
|
66 |
+
def slice_segments(x, ids_str, segment_size=4):
|
67 |
+
ret = torch.zeros_like(x[:, :, :segment_size])
|
68 |
+
for i in range(x.size(0)):
|
69 |
+
idx_str = ids_str[i]
|
70 |
+
idx_end = idx_str + segment_size
|
71 |
+
ret[i] = x[i, :, idx_str:idx_end]
|
72 |
+
return ret
|
73 |
+
|
74 |
+
|
75 |
+
def slice_segments_audio(x, ids_str, segment_size=4):
|
76 |
+
ret = torch.zeros_like(x[:, :segment_size])
|
77 |
+
for i in range(x.size(0)):
|
78 |
+
idx_str = ids_str[i]
|
79 |
+
idx_end = idx_str + segment_size
|
80 |
+
ret[i] = x[i, idx_str:idx_end]
|
81 |
+
return ret
|
82 |
+
|
83 |
+
|
84 |
+
def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
85 |
+
b, d, t = x.size()
|
86 |
+
if x_lengths is None:
|
87 |
+
x_lengths = t
|
88 |
+
ids_str_max = x_lengths - segment_size + 1
|
89 |
+
ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
|
90 |
+
dtype=torch.long
|
91 |
+
)
|
92 |
+
ret = slice_segments(x, ids_str, segment_size)
|
93 |
+
return ret, ids_str
|
94 |
+
|
95 |
+
|
96 |
+
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
|
97 |
+
position = torch.arange(length, dtype=torch.float)
|
98 |
+
num_timescales = channels // 2
|
99 |
+
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
|
100 |
+
num_timescales - 1
|
101 |
+
)
|
102 |
+
inv_timescales = min_timescale * torch.exp(
|
103 |
+
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
|
104 |
+
)
|
105 |
+
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
|
106 |
+
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
|
107 |
+
signal = F.pad(signal, [0, 0, 0, channels % 2])
|
108 |
+
signal = signal.view(1, channels, length)
|
109 |
+
return signal
|
110 |
+
|
111 |
+
|
112 |
+
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
|
113 |
+
b, channels, length = x.size()
|
114 |
+
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
115 |
+
return x + signal.to(dtype=x.dtype, device=x.device)
|
116 |
+
|
117 |
+
|
118 |
+
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
|
119 |
+
b, channels, length = x.size()
|
120 |
+
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
121 |
+
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
|
122 |
+
|
123 |
+
|
124 |
+
def subsequent_mask(length):
|
125 |
+
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
|
126 |
+
return mask
|
127 |
+
|
128 |
+
|
129 |
+
@torch.jit.script
|
130 |
+
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
131 |
+
n_channels_int = n_channels[0]
|
132 |
+
in_act = input_a + input_b
|
133 |
+
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
134 |
+
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
135 |
+
acts = t_act * s_act
|
136 |
+
return acts
|
137 |
+
|
138 |
+
|
139 |
+
def convert_pad_shape(pad_shape):
|
140 |
+
l = pad_shape[::-1]
|
141 |
+
pad_shape = [item for sublist in l for item in sublist]
|
142 |
+
return pad_shape
|
143 |
+
|
144 |
+
|
145 |
+
def shift_1d(x):
|
146 |
+
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
147 |
+
return x
|
148 |
+
|
149 |
+
|
150 |
+
def sequence_mask(length, max_length=None):
|
151 |
+
if max_length is None:
|
152 |
+
max_length = length.max()
|
153 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
154 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
155 |
+
|
156 |
+
|
157 |
+
def generate_path(duration, mask):
|
158 |
+
"""
|
159 |
+
duration: [b, 1, t_x]
|
160 |
+
mask: [b, 1, t_y, t_x]
|
161 |
+
"""
|
162 |
+
device = duration.device
|
163 |
+
|
164 |
+
b, _, t_y, t_x = mask.shape
|
165 |
+
cum_duration = torch.cumsum(duration, -1)
|
166 |
+
|
167 |
+
cum_duration_flat = cum_duration.view(b * t_x)
|
168 |
+
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
169 |
+
path = path.view(b, t_x, t_y)
|
170 |
+
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
171 |
+
path = path.unsqueeze(1).transpose(2, 3) * mask
|
172 |
+
return path
|
173 |
+
|
174 |
+
|
175 |
+
def clip_grad_value_(parameters, clip_value, norm_type=2):
|
176 |
+
if isinstance(parameters, torch.Tensor):
|
177 |
+
parameters = [parameters]
|
178 |
+
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
179 |
+
norm_type = float(norm_type)
|
180 |
+
if clip_value is not None:
|
181 |
+
clip_value = float(clip_value)
|
182 |
+
|
183 |
+
total_norm = 0
|
184 |
+
for p in parameters:
|
185 |
+
param_norm = p.grad.data.norm(norm_type)
|
186 |
+
total_norm += param_norm.item() ** norm_type
|
187 |
+
if clip_value is not None:
|
188 |
+
p.grad.data.clamp_(min=-clip_value, max=clip_value)
|
189 |
+
total_norm = total_norm ** (1.0 / norm_type)
|
190 |
+
return total_norm
|
191 |
+
|
192 |
+
|
193 |
+
def log_norm(x, mean=-4, std=4, dim=2):
|
194 |
+
"""
|
195 |
+
normalized log mel -> mel -> norm -> log(norm)
|
196 |
+
"""
|
197 |
+
x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
|
198 |
+
return x
|
199 |
+
|
200 |
+
|
201 |
+
from huggingface_hub import hf_hub_download
|
202 |
+
|
203 |
+
|
204 |
+
def load_F0_models(path):
|
205 |
+
# load F0 model
|
206 |
+
from .JDC.model import JDCNet
|
207 |
+
|
208 |
+
F0_model = JDCNet(num_class=1, seq_len=192)
|
209 |
+
if not os.path.exists(path):
|
210 |
+
path = hf_hub_download(repo_id="Plachta/JDCnet", filename="bst.t7")
|
211 |
+
params = torch.load(path, map_location="cpu")["net"]
|
212 |
+
F0_model.load_state_dict(params)
|
213 |
+
_ = F0_model.train()
|
214 |
+
|
215 |
+
return F0_model
|
216 |
+
|
217 |
+
|
218 |
+
# Generators
|
219 |
+
from modules.dac.model.dac import Encoder, Decoder
|
220 |
+
from .quantize import FAquantizer, FApredictors
|
221 |
+
|
222 |
+
# Discriminators
|
223 |
+
from modules.dac.model.discriminator import Discriminator
|
224 |
+
|
225 |
+
|
226 |
+
def build_model(args):
|
227 |
+
encoder = Encoder(
|
228 |
+
d_model=args.DAC.encoder_dim,
|
229 |
+
strides=args.DAC.encoder_rates,
|
230 |
+
d_latent=1024,
|
231 |
+
causal=args.causal,
|
232 |
+
lstm=args.lstm,
|
233 |
+
)
|
234 |
+
|
235 |
+
quantizer = FAquantizer(
|
236 |
+
in_dim=1024,
|
237 |
+
n_p_codebooks=1,
|
238 |
+
n_c_codebooks=args.n_c_codebooks,
|
239 |
+
n_t_codebooks=2,
|
240 |
+
n_r_codebooks=3,
|
241 |
+
codebook_size=1024,
|
242 |
+
codebook_dim=8,
|
243 |
+
quantizer_dropout=0.5,
|
244 |
+
causal=args.causal,
|
245 |
+
separate_prosody_encoder=args.separate_prosody_encoder,
|
246 |
+
timbre_norm=args.timbre_norm,
|
247 |
+
)
|
248 |
+
|
249 |
+
fa_predictors = FApredictors(
|
250 |
+
in_dim=1024,
|
251 |
+
use_gr_content_f0=args.use_gr_content_f0,
|
252 |
+
use_gr_prosody_phone=args.use_gr_prosody_phone,
|
253 |
+
use_gr_residual_f0=True,
|
254 |
+
use_gr_residual_phone=True,
|
255 |
+
use_gr_timbre_content=True,
|
256 |
+
use_gr_timbre_prosody=args.use_gr_timbre_prosody,
|
257 |
+
use_gr_x_timbre=True,
|
258 |
+
norm_f0=args.norm_f0,
|
259 |
+
timbre_norm=args.timbre_norm,
|
260 |
+
use_gr_content_global_f0=args.use_gr_content_global_f0,
|
261 |
+
)
|
262 |
+
|
263 |
+
decoder = Decoder(
|
264 |
+
input_channel=1024,
|
265 |
+
channels=args.DAC.decoder_dim,
|
266 |
+
rates=args.DAC.decoder_rates,
|
267 |
+
causal=args.causal,
|
268 |
+
lstm=args.lstm,
|
269 |
+
)
|
270 |
+
|
271 |
+
discriminator = Discriminator(
|
272 |
+
rates=[],
|
273 |
+
periods=[2, 3, 5, 7, 11],
|
274 |
+
fft_sizes=[2048, 1024, 512],
|
275 |
+
sample_rate=args.DAC.sr,
|
276 |
+
bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
|
277 |
+
)
|
278 |
+
|
279 |
+
nets = Munch(
|
280 |
+
encoder=encoder,
|
281 |
+
quantizer=quantizer,
|
282 |
+
decoder=decoder,
|
283 |
+
discriminator=discriminator,
|
284 |
+
fa_predictors=fa_predictors,
|
285 |
+
)
|
286 |
+
|
287 |
+
return nets
|
288 |
+
|
289 |
+
|
290 |
+
def load_checkpoint(
|
291 |
+
model,
|
292 |
+
optimizer,
|
293 |
+
path,
|
294 |
+
load_only_params=True,
|
295 |
+
ignore_modules=[],
|
296 |
+
is_distributed=False,
|
297 |
+
):
|
298 |
+
state = torch.load(path, map_location="cpu")
|
299 |
+
params = state["net"]
|
300 |
+
for key in model:
|
301 |
+
if key in params and key not in ignore_modules:
|
302 |
+
if not is_distributed:
|
303 |
+
# strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
|
304 |
+
for k in list(params[key].keys()):
|
305 |
+
if k.startswith("module."):
|
306 |
+
params[key][k[len("module.") :]] = params[key][k]
|
307 |
+
del params[key][k]
|
308 |
+
print("%s loaded" % key)
|
309 |
+
model[key].load_state_dict(params[key], strict=True)
|
310 |
+
_ = [model[key].eval() for key in model]
|
311 |
+
|
312 |
+
if not load_only_params:
|
313 |
+
epoch = state["epoch"] + 1
|
314 |
+
iters = state["iters"]
|
315 |
+
optimizer.load_state_dict(state["optimizer"])
|
316 |
+
optimizer.load_scheduler_state_dict(state["scheduler"])
|
317 |
+
|
318 |
+
else:
|
319 |
+
epoch = state["epoch"] + 1
|
320 |
+
iters = state["iters"]
|
321 |
+
|
322 |
+
return model, optimizer, epoch, iters
|
323 |
+
|
324 |
+
|
325 |
+
def recursive_munch(d):
|
326 |
+
if isinstance(d, dict):
|
327 |
+
return Munch((k, recursive_munch(v)) for k, v in d.items())
|
328 |
+
elif isinstance(d, list):
|
329 |
+
return [recursive_munch(v) for v in d]
|
330 |
+
else:
|
331 |
+
return d
|
models/codec/facodec/modules/gradient_reversal.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from torch.autograd import Function
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
|
11 |
+
class GradientReversal(Function):
|
12 |
+
@staticmethod
|
13 |
+
def forward(ctx, x, alpha):
|
14 |
+
ctx.save_for_backward(x, alpha)
|
15 |
+
return x
|
16 |
+
|
17 |
+
@staticmethod
|
18 |
+
def backward(ctx, grad_output):
|
19 |
+
grad_input = None
|
20 |
+
_, alpha = ctx.saved_tensors
|
21 |
+
if ctx.needs_input_grad[0]:
|
22 |
+
grad_input = -alpha * grad_output
|
23 |
+
return grad_input, None
|
24 |
+
|
25 |
+
|
26 |
+
revgrad = GradientReversal.apply
|
27 |
+
|
28 |
+
|
29 |
+
class GradientReversal(nn.Module):
|
30 |
+
def __init__(self, alpha):
|
31 |
+
super().__init__()
|
32 |
+
self.alpha = torch.tensor(alpha, requires_grad=False)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
return revgrad(x, self.alpha)
|
models/codec/facodec/modules/layers.py
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from typing import Optional, Any
|
10 |
+
from torch import Tensor
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torchaudio
|
13 |
+
import torchaudio.functional as audio_F
|
14 |
+
|
15 |
+
import random
|
16 |
+
|
17 |
+
random.seed(0)
|
18 |
+
|
19 |
+
|
20 |
+
def _get_activation_fn(activ):
|
21 |
+
if activ == "relu":
|
22 |
+
return nn.ReLU()
|
23 |
+
elif activ == "lrelu":
|
24 |
+
return nn.LeakyReLU(0.2)
|
25 |
+
elif activ == "swish":
|
26 |
+
return lambda x: x * torch.sigmoid(x)
|
27 |
+
else:
|
28 |
+
raise RuntimeError(
|
29 |
+
"Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
class LinearNorm(torch.nn.Module):
|
34 |
+
def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
|
35 |
+
super(LinearNorm, self).__init__()
|
36 |
+
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
37 |
+
|
38 |
+
torch.nn.init.xavier_uniform_(
|
39 |
+
self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
|
40 |
+
)
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
return self.linear_layer(x)
|
44 |
+
|
45 |
+
|
46 |
+
class ConvNorm(torch.nn.Module):
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
in_channels,
|
50 |
+
out_channels,
|
51 |
+
kernel_size=1,
|
52 |
+
stride=1,
|
53 |
+
padding=None,
|
54 |
+
dilation=1,
|
55 |
+
bias=True,
|
56 |
+
w_init_gain="linear",
|
57 |
+
param=None,
|
58 |
+
):
|
59 |
+
super(ConvNorm, self).__init__()
|
60 |
+
if padding is None:
|
61 |
+
assert kernel_size % 2 == 1
|
62 |
+
padding = int(dilation * (kernel_size - 1) / 2)
|
63 |
+
|
64 |
+
self.conv = torch.nn.Conv1d(
|
65 |
+
in_channels,
|
66 |
+
out_channels,
|
67 |
+
kernel_size=kernel_size,
|
68 |
+
stride=stride,
|
69 |
+
padding=padding,
|
70 |
+
dilation=dilation,
|
71 |
+
bias=bias,
|
72 |
+
)
|
73 |
+
|
74 |
+
torch.nn.init.xavier_uniform_(
|
75 |
+
self.conv.weight,
|
76 |
+
gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
|
77 |
+
)
|
78 |
+
|
79 |
+
def forward(self, signal):
|
80 |
+
conv_signal = self.conv(signal)
|
81 |
+
return conv_signal
|
82 |
+
|
83 |
+
|
84 |
+
class CausualConv(nn.Module):
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
in_channels,
|
88 |
+
out_channels,
|
89 |
+
kernel_size=1,
|
90 |
+
stride=1,
|
91 |
+
padding=1,
|
92 |
+
dilation=1,
|
93 |
+
bias=True,
|
94 |
+
w_init_gain="linear",
|
95 |
+
param=None,
|
96 |
+
):
|
97 |
+
super(CausualConv, self).__init__()
|
98 |
+
if padding is None:
|
99 |
+
assert kernel_size % 2 == 1
|
100 |
+
padding = int(dilation * (kernel_size - 1) / 2) * 2
|
101 |
+
else:
|
102 |
+
self.padding = padding * 2
|
103 |
+
self.conv = nn.Conv1d(
|
104 |
+
in_channels,
|
105 |
+
out_channels,
|
106 |
+
kernel_size=kernel_size,
|
107 |
+
stride=stride,
|
108 |
+
padding=self.padding,
|
109 |
+
dilation=dilation,
|
110 |
+
bias=bias,
|
111 |
+
)
|
112 |
+
|
113 |
+
torch.nn.init.xavier_uniform_(
|
114 |
+
self.conv.weight,
|
115 |
+
gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
|
116 |
+
)
|
117 |
+
|
118 |
+
def forward(self, x):
|
119 |
+
x = self.conv(x)
|
120 |
+
x = x[:, :, : -self.padding]
|
121 |
+
return x
|
122 |
+
|
123 |
+
|
124 |
+
class CausualBlock(nn.Module):
|
125 |
+
def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
|
126 |
+
super(CausualBlock, self).__init__()
|
127 |
+
self.blocks = nn.ModuleList(
|
128 |
+
[
|
129 |
+
self._get_conv(
|
130 |
+
hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
|
131 |
+
)
|
132 |
+
for i in range(n_conv)
|
133 |
+
]
|
134 |
+
)
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
for block in self.blocks:
|
138 |
+
res = x
|
139 |
+
x = block(x)
|
140 |
+
x += res
|
141 |
+
return x
|
142 |
+
|
143 |
+
def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
|
144 |
+
layers = [
|
145 |
+
CausualConv(
|
146 |
+
hidden_dim,
|
147 |
+
hidden_dim,
|
148 |
+
kernel_size=3,
|
149 |
+
padding=dilation,
|
150 |
+
dilation=dilation,
|
151 |
+
),
|
152 |
+
_get_activation_fn(activ),
|
153 |
+
nn.BatchNorm1d(hidden_dim),
|
154 |
+
nn.Dropout(p=dropout_p),
|
155 |
+
CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
|
156 |
+
_get_activation_fn(activ),
|
157 |
+
nn.Dropout(p=dropout_p),
|
158 |
+
]
|
159 |
+
return nn.Sequential(*layers)
|
160 |
+
|
161 |
+
|
162 |
+
class ConvBlock(nn.Module):
|
163 |
+
def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
|
164 |
+
super().__init__()
|
165 |
+
self._n_groups = 8
|
166 |
+
self.blocks = nn.ModuleList(
|
167 |
+
[
|
168 |
+
self._get_conv(
|
169 |
+
hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
|
170 |
+
)
|
171 |
+
for i in range(n_conv)
|
172 |
+
]
|
173 |
+
)
|
174 |
+
|
175 |
+
def forward(self, x):
|
176 |
+
for block in self.blocks:
|
177 |
+
res = x
|
178 |
+
x = block(x)
|
179 |
+
x += res
|
180 |
+
return x
|
181 |
+
|
182 |
+
def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
|
183 |
+
layers = [
|
184 |
+
ConvNorm(
|
185 |
+
hidden_dim,
|
186 |
+
hidden_dim,
|
187 |
+
kernel_size=3,
|
188 |
+
padding=dilation,
|
189 |
+
dilation=dilation,
|
190 |
+
),
|
191 |
+
_get_activation_fn(activ),
|
192 |
+
nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
|
193 |
+
nn.Dropout(p=dropout_p),
|
194 |
+
ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
|
195 |
+
_get_activation_fn(activ),
|
196 |
+
nn.Dropout(p=dropout_p),
|
197 |
+
]
|
198 |
+
return nn.Sequential(*layers)
|
199 |
+
|
200 |
+
|
201 |
+
class LocationLayer(nn.Module):
|
202 |
+
def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
|
203 |
+
super(LocationLayer, self).__init__()
|
204 |
+
padding = int((attention_kernel_size - 1) / 2)
|
205 |
+
self.location_conv = ConvNorm(
|
206 |
+
2,
|
207 |
+
attention_n_filters,
|
208 |
+
kernel_size=attention_kernel_size,
|
209 |
+
padding=padding,
|
210 |
+
bias=False,
|
211 |
+
stride=1,
|
212 |
+
dilation=1,
|
213 |
+
)
|
214 |
+
self.location_dense = LinearNorm(
|
215 |
+
attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
|
216 |
+
)
|
217 |
+
|
218 |
+
def forward(self, attention_weights_cat):
|
219 |
+
processed_attention = self.location_conv(attention_weights_cat)
|
220 |
+
processed_attention = processed_attention.transpose(1, 2)
|
221 |
+
processed_attention = self.location_dense(processed_attention)
|
222 |
+
return processed_attention
|
223 |
+
|
224 |
+
|
225 |
+
class Attention(nn.Module):
|
226 |
+
def __init__(
|
227 |
+
self,
|
228 |
+
attention_rnn_dim,
|
229 |
+
embedding_dim,
|
230 |
+
attention_dim,
|
231 |
+
attention_location_n_filters,
|
232 |
+
attention_location_kernel_size,
|
233 |
+
):
|
234 |
+
super(Attention, self).__init__()
|
235 |
+
self.query_layer = LinearNorm(
|
236 |
+
attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
|
237 |
+
)
|
238 |
+
self.memory_layer = LinearNorm(
|
239 |
+
embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
|
240 |
+
)
|
241 |
+
self.v = LinearNorm(attention_dim, 1, bias=False)
|
242 |
+
self.location_layer = LocationLayer(
|
243 |
+
attention_location_n_filters, attention_location_kernel_size, attention_dim
|
244 |
+
)
|
245 |
+
self.score_mask_value = -float("inf")
|
246 |
+
|
247 |
+
def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
|
248 |
+
"""
|
249 |
+
PARAMS
|
250 |
+
------
|
251 |
+
query: decoder output (batch, n_mel_channels * n_frames_per_step)
|
252 |
+
processed_memory: processed encoder outputs (B, T_in, attention_dim)
|
253 |
+
attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
|
254 |
+
RETURNS
|
255 |
+
-------
|
256 |
+
alignment (batch, max_time)
|
257 |
+
"""
|
258 |
+
|
259 |
+
processed_query = self.query_layer(query.unsqueeze(1))
|
260 |
+
processed_attention_weights = self.location_layer(attention_weights_cat)
|
261 |
+
energies = self.v(
|
262 |
+
torch.tanh(processed_query + processed_attention_weights + processed_memory)
|
263 |
+
)
|
264 |
+
|
265 |
+
energies = energies.squeeze(-1)
|
266 |
+
return energies
|
267 |
+
|
268 |
+
def forward(
|
269 |
+
self,
|
270 |
+
attention_hidden_state,
|
271 |
+
memory,
|
272 |
+
processed_memory,
|
273 |
+
attention_weights_cat,
|
274 |
+
mask,
|
275 |
+
):
|
276 |
+
"""
|
277 |
+
PARAMS
|
278 |
+
------
|
279 |
+
attention_hidden_state: attention rnn last output
|
280 |
+
memory: encoder outputs
|
281 |
+
processed_memory: processed encoder outputs
|
282 |
+
attention_weights_cat: previous and cummulative attention weights
|
283 |
+
mask: binary mask for padded data
|
284 |
+
"""
|
285 |
+
alignment = self.get_alignment_energies(
|
286 |
+
attention_hidden_state, processed_memory, attention_weights_cat
|
287 |
+
)
|
288 |
+
|
289 |
+
if mask is not None:
|
290 |
+
alignment.data.masked_fill_(mask, self.score_mask_value)
|
291 |
+
|
292 |
+
attention_weights = F.softmax(alignment, dim=1)
|
293 |
+
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
|
294 |
+
attention_context = attention_context.squeeze(1)
|
295 |
+
|
296 |
+
return attention_context, attention_weights
|
297 |
+
|
298 |
+
|
299 |
+
class ForwardAttentionV2(nn.Module):
|
300 |
+
def __init__(
|
301 |
+
self,
|
302 |
+
attention_rnn_dim,
|
303 |
+
embedding_dim,
|
304 |
+
attention_dim,
|
305 |
+
attention_location_n_filters,
|
306 |
+
attention_location_kernel_size,
|
307 |
+
):
|
308 |
+
super(ForwardAttentionV2, self).__init__()
|
309 |
+
self.query_layer = LinearNorm(
|
310 |
+
attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
|
311 |
+
)
|
312 |
+
self.memory_layer = LinearNorm(
|
313 |
+
embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
|
314 |
+
)
|
315 |
+
self.v = LinearNorm(attention_dim, 1, bias=False)
|
316 |
+
self.location_layer = LocationLayer(
|
317 |
+
attention_location_n_filters, attention_location_kernel_size, attention_dim
|
318 |
+
)
|
319 |
+
self.score_mask_value = -float(1e20)
|
320 |
+
|
321 |
+
def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
|
322 |
+
"""
|
323 |
+
PARAMS
|
324 |
+
------
|
325 |
+
query: decoder output (batch, n_mel_channels * n_frames_per_step)
|
326 |
+
processed_memory: processed encoder outputs (B, T_in, attention_dim)
|
327 |
+
attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
|
328 |
+
RETURNS
|
329 |
+
-------
|
330 |
+
alignment (batch, max_time)
|
331 |
+
"""
|
332 |
+
|
333 |
+
processed_query = self.query_layer(query.unsqueeze(1))
|
334 |
+
processed_attention_weights = self.location_layer(attention_weights_cat)
|
335 |
+
energies = self.v(
|
336 |
+
torch.tanh(processed_query + processed_attention_weights + processed_memory)
|
337 |
+
)
|
338 |
+
|
339 |
+
energies = energies.squeeze(-1)
|
340 |
+
return energies
|
341 |
+
|
342 |
+
def forward(
|
343 |
+
self,
|
344 |
+
attention_hidden_state,
|
345 |
+
memory,
|
346 |
+
processed_memory,
|
347 |
+
attention_weights_cat,
|
348 |
+
mask,
|
349 |
+
log_alpha,
|
350 |
+
):
|
351 |
+
"""
|
352 |
+
PARAMS
|
353 |
+
------
|
354 |
+
attention_hidden_state: attention rnn last output
|
355 |
+
memory: encoder outputs
|
356 |
+
processed_memory: processed encoder outputs
|
357 |
+
attention_weights_cat: previous and cummulative attention weights
|
358 |
+
mask: binary mask for padded data
|
359 |
+
"""
|
360 |
+
log_energy = self.get_alignment_energies(
|
361 |
+
attention_hidden_state, processed_memory, attention_weights_cat
|
362 |
+
)
|
363 |
+
|
364 |
+
# log_energy =
|
365 |
+
|
366 |
+
if mask is not None:
|
367 |
+
log_energy.data.masked_fill_(mask, self.score_mask_value)
|
368 |
+
|
369 |
+
# attention_weights = F.softmax(alignment, dim=1)
|
370 |
+
|
371 |
+
# content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
|
372 |
+
# log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
|
373 |
+
|
374 |
+
# log_total_score = log_alpha + content_score
|
375 |
+
|
376 |
+
# previous_attention_weights = attention_weights_cat[:,0,:]
|
377 |
+
|
378 |
+
log_alpha_shift_padded = []
|
379 |
+
max_time = log_energy.size(1)
|
380 |
+
for sft in range(2):
|
381 |
+
shifted = log_alpha[:, : max_time - sft]
|
382 |
+
shift_padded = F.pad(shifted, (sft, 0), "constant", self.score_mask_value)
|
383 |
+
log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
|
384 |
+
|
385 |
+
biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)
|
386 |
+
|
387 |
+
log_alpha_new = biased + log_energy
|
388 |
+
|
389 |
+
attention_weights = F.softmax(log_alpha_new, dim=1)
|
390 |
+
|
391 |
+
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
|
392 |
+
attention_context = attention_context.squeeze(1)
|
393 |
+
|
394 |
+
return attention_context, attention_weights, log_alpha_new
|
395 |
+
|
396 |
+
|
397 |
+
class PhaseShuffle2d(nn.Module):
|
398 |
+
def __init__(self, n=2):
|
399 |
+
super(PhaseShuffle2d, self).__init__()
|
400 |
+
self.n = n
|
401 |
+
self.random = random.Random(1)
|
402 |
+
|
403 |
+
def forward(self, x, move=None):
|
404 |
+
# x.size = (B, C, M, L)
|
405 |
+
if move is None:
|
406 |
+
move = self.random.randint(-self.n, self.n)
|
407 |
+
|
408 |
+
if move == 0:
|
409 |
+
return x
|
410 |
+
else:
|
411 |
+
left = x[:, :, :, :move]
|
412 |
+
right = x[:, :, :, move:]
|
413 |
+
shuffled = torch.cat([right, left], dim=3)
|
414 |
+
return shuffled
|
415 |
+
|
416 |
+
|
417 |
+
class PhaseShuffle1d(nn.Module):
|
418 |
+
def __init__(self, n=2):
|
419 |
+
super(PhaseShuffle1d, self).__init__()
|
420 |
+
self.n = n
|
421 |
+
self.random = random.Random(1)
|
422 |
+
|
423 |
+
def forward(self, x, move=None):
|
424 |
+
# x.size = (B, C, M, L)
|
425 |
+
if move is None:
|
426 |
+
move = self.random.randint(-self.n, self.n)
|
427 |
+
|
428 |
+
if move == 0:
|
429 |
+
return x
|
430 |
+
else:
|
431 |
+
left = x[:, :, :move]
|
432 |
+
right = x[:, :, move:]
|
433 |
+
shuffled = torch.cat([right, left], dim=2)
|
434 |
+
|
435 |
+
return shuffled
|
436 |
+
|
437 |
+
|
438 |
+
class MFCC(nn.Module):
|
439 |
+
def __init__(self, n_mfcc=40, n_mels=80):
|
440 |
+
super(MFCC, self).__init__()
|
441 |
+
self.n_mfcc = n_mfcc
|
442 |
+
self.n_mels = n_mels
|
443 |
+
self.norm = "ortho"
|
444 |
+
dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
|
445 |
+
self.register_buffer("dct_mat", dct_mat)
|
446 |
+
|
447 |
+
def forward(self, mel_specgram):
|
448 |
+
if len(mel_specgram.shape) == 2:
|
449 |
+
mel_specgram = mel_specgram.unsqueeze(0)
|
450 |
+
unsqueezed = True
|
451 |
+
else:
|
452 |
+
unsqueezed = False
|
453 |
+
# (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
|
454 |
+
# -> (channel, time, n_mfcc).tranpose(...)
|
455 |
+
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
|
456 |
+
|
457 |
+
# unpack batch
|
458 |
+
if unsqueezed:
|
459 |
+
mfcc = mfcc.squeeze(0)
|
460 |
+
return mfcc
|
models/codec/facodec/modules/quantize.py
ADDED
@@ -0,0 +1,741 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from modules.dac.nn.quantize import ResidualVectorQuantize
|
7 |
+
from torch import nn
|
8 |
+
from .wavenet import WN
|
9 |
+
from .style_encoder import StyleEncoder
|
10 |
+
from .gradient_reversal import GradientReversal
|
11 |
+
import torch
|
12 |
+
import torchaudio
|
13 |
+
import torchaudio.functional as audio_F
|
14 |
+
import numpy as np
|
15 |
+
from ..alias_free_torch import *
|
16 |
+
from torch.nn.utils import weight_norm
|
17 |
+
from torch import nn, sin, pow
|
18 |
+
from einops.layers.torch import Rearrange
|
19 |
+
from modules.dac.model.encodec import SConv1d
|
20 |
+
|
21 |
+
|
22 |
+
def init_weights(m):
|
23 |
+
if isinstance(m, nn.Conv1d):
|
24 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
25 |
+
nn.init.constant_(m.bias, 0)
|
26 |
+
|
27 |
+
|
28 |
+
def WNConv1d(*args, **kwargs):
|
29 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
30 |
+
|
31 |
+
|
32 |
+
def WNConvTranspose1d(*args, **kwargs):
|
33 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
34 |
+
|
35 |
+
|
36 |
+
class SnakeBeta(nn.Module):
|
37 |
+
"""
|
38 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
39 |
+
Shape:
|
40 |
+
- Input: (B, C, T)
|
41 |
+
- Output: (B, C, T), same shape as the input
|
42 |
+
Parameters:
|
43 |
+
- alpha - trainable parameter that controls frequency
|
44 |
+
- beta - trainable parameter that controls magnitude
|
45 |
+
References:
|
46 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
47 |
+
https://arxiv.org/abs/2006.08195
|
48 |
+
Examples:
|
49 |
+
>>> a1 = snakebeta(256)
|
50 |
+
>>> x = torch.randn(256)
|
51 |
+
>>> x = a1(x)
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
56 |
+
):
|
57 |
+
"""
|
58 |
+
Initialization.
|
59 |
+
INPUT:
|
60 |
+
- in_features: shape of the input
|
61 |
+
- alpha - trainable parameter that controls frequency
|
62 |
+
- beta - trainable parameter that controls magnitude
|
63 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
64 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
65 |
+
alpha will be trained along with the rest of your model.
|
66 |
+
"""
|
67 |
+
super(SnakeBeta, self).__init__()
|
68 |
+
self.in_features = in_features
|
69 |
+
|
70 |
+
# initialize alpha
|
71 |
+
self.alpha_logscale = alpha_logscale
|
72 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
73 |
+
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
|
74 |
+
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
|
75 |
+
else: # linear scale alphas initialized to ones
|
76 |
+
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
|
77 |
+
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
|
78 |
+
|
79 |
+
self.alpha.requires_grad = alpha_trainable
|
80 |
+
self.beta.requires_grad = alpha_trainable
|
81 |
+
|
82 |
+
self.no_div_by_zero = 0.000000001
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
"""
|
86 |
+
Forward pass of the function.
|
87 |
+
Applies the function to the input elementwise.
|
88 |
+
SnakeBeta := x + 1/b * sin^2 (xa)
|
89 |
+
"""
|
90 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
91 |
+
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
92 |
+
if self.alpha_logscale:
|
93 |
+
alpha = torch.exp(alpha)
|
94 |
+
beta = torch.exp(beta)
|
95 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
96 |
+
|
97 |
+
return x
|
98 |
+
|
99 |
+
|
100 |
+
class ResidualUnit(nn.Module):
|
101 |
+
def __init__(self, dim: int = 16, dilation: int = 1):
|
102 |
+
super().__init__()
|
103 |
+
pad = ((7 - 1) * dilation) // 2
|
104 |
+
self.block = nn.Sequential(
|
105 |
+
Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
|
106 |
+
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
107 |
+
Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
|
108 |
+
WNConv1d(dim, dim, kernel_size=1),
|
109 |
+
)
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
return x + self.block(x)
|
113 |
+
|
114 |
+
|
115 |
+
class CNNLSTM(nn.Module):
|
116 |
+
def __init__(self, indim, outdim, head, global_pred=False):
|
117 |
+
super().__init__()
|
118 |
+
self.global_pred = global_pred
|
119 |
+
self.model = nn.Sequential(
|
120 |
+
ResidualUnit(indim, dilation=1),
|
121 |
+
ResidualUnit(indim, dilation=2),
|
122 |
+
ResidualUnit(indim, dilation=3),
|
123 |
+
Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
|
124 |
+
Rearrange("b c t -> b t c"),
|
125 |
+
)
|
126 |
+
self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
# x: [B, C, T]
|
130 |
+
x = self.model(x)
|
131 |
+
if self.global_pred:
|
132 |
+
x = torch.mean(x, dim=1, keepdim=False)
|
133 |
+
outs = [head(x) for head in self.heads]
|
134 |
+
return outs
|
135 |
+
|
136 |
+
|
137 |
+
def sequence_mask(length, max_length=None):
|
138 |
+
if max_length is None:
|
139 |
+
max_length = length.max()
|
140 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
141 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
142 |
+
|
143 |
+
|
144 |
+
class MFCC(nn.Module):
|
145 |
+
def __init__(self, n_mfcc=40, n_mels=80):
|
146 |
+
super(MFCC, self).__init__()
|
147 |
+
self.n_mfcc = n_mfcc
|
148 |
+
self.n_mels = n_mels
|
149 |
+
self.norm = "ortho"
|
150 |
+
dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
|
151 |
+
self.register_buffer("dct_mat", dct_mat)
|
152 |
+
|
153 |
+
def forward(self, mel_specgram):
|
154 |
+
if len(mel_specgram.shape) == 2:
|
155 |
+
mel_specgram = mel_specgram.unsqueeze(0)
|
156 |
+
unsqueezed = True
|
157 |
+
else:
|
158 |
+
unsqueezed = False
|
159 |
+
# (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
|
160 |
+
# -> (channel, time, n_mfcc).tranpose(...)
|
161 |
+
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
|
162 |
+
|
163 |
+
# unpack batch
|
164 |
+
if unsqueezed:
|
165 |
+
mfcc = mfcc.squeeze(0)
|
166 |
+
return mfcc
|
167 |
+
|
168 |
+
|
169 |
+
class FAquantizer(nn.Module):
|
170 |
+
def __init__(
|
171 |
+
self,
|
172 |
+
in_dim=1024,
|
173 |
+
n_p_codebooks=1,
|
174 |
+
n_c_codebooks=2,
|
175 |
+
n_t_codebooks=2,
|
176 |
+
n_r_codebooks=3,
|
177 |
+
codebook_size=1024,
|
178 |
+
codebook_dim=8,
|
179 |
+
quantizer_dropout=0.5,
|
180 |
+
causal=False,
|
181 |
+
separate_prosody_encoder=False,
|
182 |
+
timbre_norm=False,
|
183 |
+
):
|
184 |
+
super(FAquantizer, self).__init__()
|
185 |
+
conv1d_type = SConv1d # if causal else nn.Conv1d
|
186 |
+
self.prosody_quantizer = ResidualVectorQuantize(
|
187 |
+
input_dim=in_dim,
|
188 |
+
n_codebooks=n_p_codebooks,
|
189 |
+
codebook_size=codebook_size,
|
190 |
+
codebook_dim=codebook_dim,
|
191 |
+
quantizer_dropout=quantizer_dropout,
|
192 |
+
)
|
193 |
+
|
194 |
+
self.content_quantizer = ResidualVectorQuantize(
|
195 |
+
input_dim=in_dim,
|
196 |
+
n_codebooks=n_c_codebooks,
|
197 |
+
codebook_size=codebook_size,
|
198 |
+
codebook_dim=codebook_dim,
|
199 |
+
quantizer_dropout=quantizer_dropout,
|
200 |
+
)
|
201 |
+
|
202 |
+
if not timbre_norm:
|
203 |
+
self.timbre_quantizer = ResidualVectorQuantize(
|
204 |
+
input_dim=in_dim,
|
205 |
+
n_codebooks=n_t_codebooks,
|
206 |
+
codebook_size=codebook_size,
|
207 |
+
codebook_dim=codebook_dim,
|
208 |
+
quantizer_dropout=quantizer_dropout,
|
209 |
+
)
|
210 |
+
else:
|
211 |
+
self.timbre_encoder = StyleEncoder(
|
212 |
+
in_dim=80, hidden_dim=512, out_dim=in_dim
|
213 |
+
)
|
214 |
+
self.timbre_linear = nn.Linear(1024, 1024 * 2)
|
215 |
+
self.timbre_linear.bias.data[:1024] = 1
|
216 |
+
self.timbre_linear.bias.data[1024:] = 0
|
217 |
+
self.timbre_norm = nn.LayerNorm(1024, elementwise_affine=False)
|
218 |
+
|
219 |
+
self.residual_quantizer = ResidualVectorQuantize(
|
220 |
+
input_dim=in_dim,
|
221 |
+
n_codebooks=n_r_codebooks,
|
222 |
+
codebook_size=codebook_size,
|
223 |
+
codebook_dim=codebook_dim,
|
224 |
+
quantizer_dropout=quantizer_dropout,
|
225 |
+
)
|
226 |
+
|
227 |
+
if separate_prosody_encoder:
|
228 |
+
self.melspec_linear = conv1d_type(
|
229 |
+
in_channels=20, out_channels=256, kernel_size=1, causal=causal
|
230 |
+
)
|
231 |
+
self.melspec_encoder = WN(
|
232 |
+
hidden_channels=256,
|
233 |
+
kernel_size=5,
|
234 |
+
dilation_rate=1,
|
235 |
+
n_layers=8,
|
236 |
+
gin_channels=0,
|
237 |
+
p_dropout=0.2,
|
238 |
+
causal=causal,
|
239 |
+
)
|
240 |
+
self.melspec_linear2 = conv1d_type(
|
241 |
+
in_channels=256, out_channels=1024, kernel_size=1, causal=causal
|
242 |
+
)
|
243 |
+
else:
|
244 |
+
pass
|
245 |
+
self.separate_prosody_encoder = separate_prosody_encoder
|
246 |
+
|
247 |
+
self.prob_random_mask_residual = 0.75
|
248 |
+
|
249 |
+
SPECT_PARAMS = {
|
250 |
+
"n_fft": 2048,
|
251 |
+
"win_length": 1200,
|
252 |
+
"hop_length": 300,
|
253 |
+
}
|
254 |
+
MEL_PARAMS = {
|
255 |
+
"n_mels": 80,
|
256 |
+
}
|
257 |
+
|
258 |
+
self.to_mel = torchaudio.transforms.MelSpectrogram(
|
259 |
+
n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
|
260 |
+
)
|
261 |
+
self.mel_mean, self.mel_std = -4, 4
|
262 |
+
self.frame_rate = 24000 / 300
|
263 |
+
self.hop_length = 300
|
264 |
+
|
265 |
+
self.is_timbre_norm = timbre_norm
|
266 |
+
if timbre_norm:
|
267 |
+
self.forward = self.forward_v2
|
268 |
+
|
269 |
+
def preprocess(self, wave_tensor, n_bins=20):
|
270 |
+
mel_tensor = self.to_mel(wave_tensor.squeeze(1))
|
271 |
+
mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
|
272 |
+
return mel_tensor[:, :n_bins, : int(wave_tensor.size(-1) / self.hop_length)]
|
273 |
+
|
274 |
+
@torch.no_grad()
|
275 |
+
def decode(self, codes):
|
276 |
+
code_c, code_p, code_t = codes.split([1, 1, 2], dim=1)
|
277 |
+
|
278 |
+
z_c = self.content_quantizer.from_codes(code_c)[0]
|
279 |
+
z_p = self.prosody_quantizer.from_codes(code_p)[0]
|
280 |
+
z_t = self.timbre_quantizer.from_codes(code_t)[0]
|
281 |
+
|
282 |
+
z = z_c + z_p + z_t
|
283 |
+
|
284 |
+
return z, [z_c, z_p, z_t]
|
285 |
+
|
286 |
+
@torch.no_grad()
|
287 |
+
def encode(self, x, wave_segments, n_c=1):
|
288 |
+
outs = 0
|
289 |
+
if self.separate_prosody_encoder:
|
290 |
+
prosody_feature = self.preprocess(wave_segments)
|
291 |
+
|
292 |
+
f0_input = prosody_feature # (B, T, 20)
|
293 |
+
f0_input = self.melspec_linear(f0_input)
|
294 |
+
f0_input = self.melspec_encoder(
|
295 |
+
f0_input,
|
296 |
+
torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
|
297 |
+
.to(f0_input.device)
|
298 |
+
.bool(),
|
299 |
+
)
|
300 |
+
f0_input = self.melspec_linear2(f0_input)
|
301 |
+
|
302 |
+
common_min_size = min(f0_input.size(2), x.size(2))
|
303 |
+
f0_input = f0_input[:, :, :common_min_size]
|
304 |
+
|
305 |
+
x = x[:, :, :common_min_size]
|
306 |
+
|
307 |
+
(
|
308 |
+
z_p,
|
309 |
+
codes_p,
|
310 |
+
latents_p,
|
311 |
+
commitment_loss_p,
|
312 |
+
codebook_loss_p,
|
313 |
+
) = self.prosody_quantizer(f0_input, 1)
|
314 |
+
outs += z_p.detach()
|
315 |
+
else:
|
316 |
+
(
|
317 |
+
z_p,
|
318 |
+
codes_p,
|
319 |
+
latents_p,
|
320 |
+
commitment_loss_p,
|
321 |
+
codebook_loss_p,
|
322 |
+
) = self.prosody_quantizer(x, 1)
|
323 |
+
outs += z_p.detach()
|
324 |
+
|
325 |
+
(
|
326 |
+
z_c,
|
327 |
+
codes_c,
|
328 |
+
latents_c,
|
329 |
+
commitment_loss_c,
|
330 |
+
codebook_loss_c,
|
331 |
+
) = self.content_quantizer(x, n_c)
|
332 |
+
outs += z_c.detach()
|
333 |
+
|
334 |
+
timbre_residual_feature = x - z_p.detach() - z_c.detach()
|
335 |
+
|
336 |
+
(
|
337 |
+
z_t,
|
338 |
+
codes_t,
|
339 |
+
latents_t,
|
340 |
+
commitment_loss_t,
|
341 |
+
codebook_loss_t,
|
342 |
+
) = self.timbre_quantizer(timbre_residual_feature, 2)
|
343 |
+
outs += z_t # we should not detach timbre
|
344 |
+
|
345 |
+
residual_feature = timbre_residual_feature - z_t
|
346 |
+
|
347 |
+
(
|
348 |
+
z_r,
|
349 |
+
codes_r,
|
350 |
+
latents_r,
|
351 |
+
commitment_loss_r,
|
352 |
+
codebook_loss_r,
|
353 |
+
) = self.residual_quantizer(residual_feature, 3)
|
354 |
+
|
355 |
+
return [codes_c, codes_p, codes_t, codes_r], [z_c, z_p, z_t, z_r]
|
356 |
+
|
357 |
+
def forward(
|
358 |
+
self, x, wave_segments, noise_added_flags, recon_noisy_flags, n_c=2, n_t=2
|
359 |
+
):
|
360 |
+
# timbre = self.timbre_encoder(mels, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
|
361 |
+
# timbre = self.timbre_encoder(mel_segments, torch.ones(mel_segments.size(0), 1, mel_segments.size(2)).bool().to(mel_segments.device))
|
362 |
+
outs = 0
|
363 |
+
if self.separate_prosody_encoder:
|
364 |
+
prosody_feature = self.preprocess(wave_segments)
|
365 |
+
|
366 |
+
f0_input = prosody_feature # (B, T, 20)
|
367 |
+
f0_input = self.melspec_linear(f0_input)
|
368 |
+
f0_input = self.melspec_encoder(
|
369 |
+
f0_input,
|
370 |
+
torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
|
371 |
+
.to(f0_input.device)
|
372 |
+
.bool(),
|
373 |
+
)
|
374 |
+
f0_input = self.melspec_linear2(f0_input)
|
375 |
+
|
376 |
+
common_min_size = min(f0_input.size(2), x.size(2))
|
377 |
+
f0_input = f0_input[:, :, :common_min_size]
|
378 |
+
|
379 |
+
x = x[:, :, :common_min_size]
|
380 |
+
|
381 |
+
(
|
382 |
+
z_p,
|
383 |
+
codes_p,
|
384 |
+
latents_p,
|
385 |
+
commitment_loss_p,
|
386 |
+
codebook_loss_p,
|
387 |
+
) = self.prosody_quantizer(f0_input, 1)
|
388 |
+
outs += z_p.detach()
|
389 |
+
else:
|
390 |
+
(
|
391 |
+
z_p,
|
392 |
+
codes_p,
|
393 |
+
latents_p,
|
394 |
+
commitment_loss_p,
|
395 |
+
codebook_loss_p,
|
396 |
+
) = self.prosody_quantizer(x, 1)
|
397 |
+
outs += z_p.detach()
|
398 |
+
|
399 |
+
(
|
400 |
+
z_c,
|
401 |
+
codes_c,
|
402 |
+
latents_c,
|
403 |
+
commitment_loss_c,
|
404 |
+
codebook_loss_c,
|
405 |
+
) = self.content_quantizer(x, n_c)
|
406 |
+
outs += z_c.detach()
|
407 |
+
|
408 |
+
timbre_residual_feature = x - z_p.detach() - z_c.detach()
|
409 |
+
|
410 |
+
(
|
411 |
+
z_t,
|
412 |
+
codes_t,
|
413 |
+
latents_t,
|
414 |
+
commitment_loss_t,
|
415 |
+
codebook_loss_t,
|
416 |
+
) = self.timbre_quantizer(timbre_residual_feature, n_t)
|
417 |
+
outs += z_t # we should not detach timbre
|
418 |
+
|
419 |
+
residual_feature = timbre_residual_feature - z_t
|
420 |
+
|
421 |
+
(
|
422 |
+
z_r,
|
423 |
+
codes_r,
|
424 |
+
latents_r,
|
425 |
+
commitment_loss_r,
|
426 |
+
codebook_loss_r,
|
427 |
+
) = self.residual_quantizer(residual_feature, 3)
|
428 |
+
|
429 |
+
bsz = z_r.shape[0]
|
430 |
+
res_mask = np.random.choice(
|
431 |
+
[0, 1],
|
432 |
+
size=bsz,
|
433 |
+
p=[
|
434 |
+
self.prob_random_mask_residual,
|
435 |
+
1 - self.prob_random_mask_residual,
|
436 |
+
],
|
437 |
+
)
|
438 |
+
res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1)
|
439 |
+
res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)
|
440 |
+
noise_must_on = noise_added_flags * recon_noisy_flags
|
441 |
+
noise_must_off = noise_added_flags * (~recon_noisy_flags)
|
442 |
+
res_mask[noise_must_on] = 1
|
443 |
+
res_mask[noise_must_off] = 0
|
444 |
+
|
445 |
+
outs += z_r * res_mask
|
446 |
+
|
447 |
+
quantized = [z_p, z_c, z_t, z_r]
|
448 |
+
commitment_losses = (
|
449 |
+
commitment_loss_p
|
450 |
+
+ commitment_loss_c
|
451 |
+
+ commitment_loss_t
|
452 |
+
+ commitment_loss_r
|
453 |
+
)
|
454 |
+
codebook_losses = (
|
455 |
+
codebook_loss_p + codebook_loss_c + codebook_loss_t + codebook_loss_r
|
456 |
+
)
|
457 |
+
|
458 |
+
return outs, quantized, commitment_losses, codebook_losses
|
459 |
+
|
460 |
+
def forward_v2(
|
461 |
+
self,
|
462 |
+
x,
|
463 |
+
wave_segments,
|
464 |
+
n_c=1,
|
465 |
+
n_t=2,
|
466 |
+
full_waves=None,
|
467 |
+
wave_lens=None,
|
468 |
+
return_codes=False,
|
469 |
+
):
|
470 |
+
# timbre = self.timbre_encoder(x, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
|
471 |
+
if full_waves is None:
|
472 |
+
mel = self.preprocess(wave_segments, n_bins=80)
|
473 |
+
timbre = self.timbre_encoder(
|
474 |
+
mel, torch.ones(mel.size(0), 1, mel.size(2)).bool().to(mel.device)
|
475 |
+
)
|
476 |
+
else:
|
477 |
+
mel = self.preprocess(full_waves, n_bins=80)
|
478 |
+
timbre = self.timbre_encoder(
|
479 |
+
mel,
|
480 |
+
sequence_mask(wave_lens // self.hop_length, mel.size(-1)).unsqueeze(1),
|
481 |
+
)
|
482 |
+
outs = 0
|
483 |
+
if self.separate_prosody_encoder:
|
484 |
+
prosody_feature = self.preprocess(wave_segments)
|
485 |
+
|
486 |
+
f0_input = prosody_feature # (B, T, 20)
|
487 |
+
f0_input = self.melspec_linear(f0_input)
|
488 |
+
f0_input = self.melspec_encoder(
|
489 |
+
f0_input,
|
490 |
+
torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
|
491 |
+
.to(f0_input.device)
|
492 |
+
.bool(),
|
493 |
+
)
|
494 |
+
f0_input = self.melspec_linear2(f0_input)
|
495 |
+
|
496 |
+
common_min_size = min(f0_input.size(2), x.size(2))
|
497 |
+
f0_input = f0_input[:, :, :common_min_size]
|
498 |
+
|
499 |
+
x = x[:, :, :common_min_size]
|
500 |
+
|
501 |
+
(
|
502 |
+
z_p,
|
503 |
+
codes_p,
|
504 |
+
latents_p,
|
505 |
+
commitment_loss_p,
|
506 |
+
codebook_loss_p,
|
507 |
+
) = self.prosody_quantizer(f0_input, 1)
|
508 |
+
outs += z_p.detach()
|
509 |
+
else:
|
510 |
+
(
|
511 |
+
z_p,
|
512 |
+
codes_p,
|
513 |
+
latents_p,
|
514 |
+
commitment_loss_p,
|
515 |
+
codebook_loss_p,
|
516 |
+
) = self.prosody_quantizer(x, 1)
|
517 |
+
outs += z_p.detach()
|
518 |
+
|
519 |
+
(
|
520 |
+
z_c,
|
521 |
+
codes_c,
|
522 |
+
latents_c,
|
523 |
+
commitment_loss_c,
|
524 |
+
codebook_loss_c,
|
525 |
+
) = self.content_quantizer(x, n_c)
|
526 |
+
outs += z_c.detach()
|
527 |
+
|
528 |
+
residual_feature = x - z_p.detach() - z_c.detach()
|
529 |
+
|
530 |
+
(
|
531 |
+
z_r,
|
532 |
+
codes_r,
|
533 |
+
latents_r,
|
534 |
+
commitment_loss_r,
|
535 |
+
codebook_loss_r,
|
536 |
+
) = self.residual_quantizer(residual_feature, 3)
|
537 |
+
|
538 |
+
bsz = z_r.shape[0]
|
539 |
+
res_mask = np.random.choice(
|
540 |
+
[0, 1],
|
541 |
+
size=bsz,
|
542 |
+
p=[
|
543 |
+
self.prob_random_mask_residual,
|
544 |
+
1 - self.prob_random_mask_residual,
|
545 |
+
],
|
546 |
+
)
|
547 |
+
res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1)
|
548 |
+
res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)
|
549 |
+
|
550 |
+
if not self.training:
|
551 |
+
res_mask = torch.ones_like(res_mask)
|
552 |
+
outs += z_r * res_mask
|
553 |
+
|
554 |
+
quantized = [z_p, z_c, z_r]
|
555 |
+
codes = [codes_p, codes_c, codes_r]
|
556 |
+
commitment_losses = commitment_loss_p + commitment_loss_c + commitment_loss_r
|
557 |
+
codebook_losses = codebook_loss_p + codebook_loss_c + codebook_loss_r
|
558 |
+
|
559 |
+
style = self.timbre_linear(timbre).unsqueeze(2) # (B, 2d, 1)
|
560 |
+
gamma, beta = style.chunk(2, 1) # (B, d, 1)
|
561 |
+
outs = outs.transpose(1, 2)
|
562 |
+
outs = self.timbre_norm(outs)
|
563 |
+
outs = outs.transpose(1, 2)
|
564 |
+
outs = outs * gamma + beta
|
565 |
+
|
566 |
+
if return_codes:
|
567 |
+
return outs, quantized, commitment_losses, codebook_losses, timbre, codes
|
568 |
+
else:
|
569 |
+
return outs, quantized, commitment_losses, codebook_losses, timbre
|
570 |
+
|
571 |
+
def voice_conversion(self, z, ref_wave):
|
572 |
+
ref_mel = self.preprocess(ref_wave, n_bins=80)
|
573 |
+
ref_timbre = self.timbre_encoder(
|
574 |
+
ref_mel,
|
575 |
+
sequence_mask(
|
576 |
+
torch.LongTensor([ref_wave.size(-1)]).to(z.device) // self.hop_length,
|
577 |
+
ref_mel.size(-1),
|
578 |
+
).unsqueeze(1),
|
579 |
+
)
|
580 |
+
style = self.timbre_linear(ref_timbre).unsqueeze(2) # (B, 2d, 1)
|
581 |
+
gamma, beta = style.chunk(2, 1) # (B, d, 1)
|
582 |
+
outs = z.transpose(1, 2)
|
583 |
+
outs = self.timbre_norm(outs)
|
584 |
+
outs = outs.transpose(1, 2)
|
585 |
+
outs = outs * gamma + beta
|
586 |
+
|
587 |
+
return outs
|
588 |
+
|
589 |
+
|
590 |
+
class FApredictors(nn.Module):
|
591 |
+
def __init__(
|
592 |
+
self,
|
593 |
+
in_dim=1024,
|
594 |
+
use_gr_content_f0=False,
|
595 |
+
use_gr_prosody_phone=False,
|
596 |
+
use_gr_residual_f0=False,
|
597 |
+
use_gr_residual_phone=False,
|
598 |
+
use_gr_timbre_content=True,
|
599 |
+
use_gr_timbre_prosody=True,
|
600 |
+
use_gr_x_timbre=False,
|
601 |
+
norm_f0=True,
|
602 |
+
timbre_norm=False,
|
603 |
+
use_gr_content_global_f0=False,
|
604 |
+
):
|
605 |
+
super(FApredictors, self).__init__()
|
606 |
+
self.f0_predictor = CNNLSTM(in_dim, 1, 2)
|
607 |
+
self.phone_predictor = CNNLSTM(in_dim, 1024, 1)
|
608 |
+
if timbre_norm:
|
609 |
+
self.timbre_predictor = nn.Linear(in_dim, 20000)
|
610 |
+
else:
|
611 |
+
self.timbre_predictor = CNNLSTM(in_dim, 20000, 1, global_pred=True)
|
612 |
+
|
613 |
+
self.use_gr_content_f0 = use_gr_content_f0
|
614 |
+
self.use_gr_prosody_phone = use_gr_prosody_phone
|
615 |
+
self.use_gr_residual_f0 = use_gr_residual_f0
|
616 |
+
self.use_gr_residual_phone = use_gr_residual_phone
|
617 |
+
self.use_gr_timbre_content = use_gr_timbre_content
|
618 |
+
self.use_gr_timbre_prosody = use_gr_timbre_prosody
|
619 |
+
self.use_gr_x_timbre = use_gr_x_timbre
|
620 |
+
|
621 |
+
self.rev_f0_predictor = nn.Sequential(
|
622 |
+
GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 2)
|
623 |
+
)
|
624 |
+
self.rev_content_predictor = nn.Sequential(
|
625 |
+
GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1024, 1)
|
626 |
+
)
|
627 |
+
self.rev_timbre_predictor = nn.Sequential(
|
628 |
+
GradientReversal(alpha=1.0), CNNLSTM(in_dim, 20000, 1, global_pred=True)
|
629 |
+
)
|
630 |
+
|
631 |
+
self.norm_f0 = norm_f0
|
632 |
+
self.timbre_norm = timbre_norm
|
633 |
+
if timbre_norm:
|
634 |
+
self.forward = self.forward_v2
|
635 |
+
self.global_f0_predictor = nn.Linear(in_dim, 1)
|
636 |
+
|
637 |
+
self.use_gr_content_global_f0 = use_gr_content_global_f0
|
638 |
+
if use_gr_content_global_f0:
|
639 |
+
self.rev_global_f0_predictor = nn.Sequential(
|
640 |
+
GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 1, global_pred=True)
|
641 |
+
)
|
642 |
+
|
643 |
+
def forward(self, quantized):
|
644 |
+
prosody_latent = quantized[0]
|
645 |
+
content_latent = quantized[1]
|
646 |
+
timbre_latent = quantized[2]
|
647 |
+
residual_latent = quantized[3]
|
648 |
+
content_pred = self.phone_predictor(content_latent)[0]
|
649 |
+
|
650 |
+
if self.norm_f0:
|
651 |
+
spk_pred = self.timbre_predictor(timbre_latent)[0]
|
652 |
+
f0_pred, uv_pred = self.f0_predictor(prosody_latent)
|
653 |
+
else:
|
654 |
+
spk_pred = self.timbre_predictor(timbre_latent + prosody_latent)[0]
|
655 |
+
f0_pred, uv_pred = self.f0_predictor(prosody_latent + timbre_latent)
|
656 |
+
|
657 |
+
prosody_rev_latent = torch.zeros_like(quantized[0])
|
658 |
+
if self.use_gr_content_f0:
|
659 |
+
prosody_rev_latent += quantized[1]
|
660 |
+
if self.use_gr_timbre_prosody:
|
661 |
+
prosody_rev_latent += quantized[2]
|
662 |
+
if self.use_gr_residual_f0:
|
663 |
+
prosody_rev_latent += quantized[3]
|
664 |
+
rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
|
665 |
+
|
666 |
+
content_rev_latent = torch.zeros_like(quantized[1])
|
667 |
+
if self.use_gr_prosody_phone:
|
668 |
+
content_rev_latent += quantized[0]
|
669 |
+
if self.use_gr_timbre_content:
|
670 |
+
content_rev_latent += quantized[2]
|
671 |
+
if self.use_gr_residual_phone:
|
672 |
+
content_rev_latent += quantized[3]
|
673 |
+
rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
|
674 |
+
|
675 |
+
if self.norm_f0:
|
676 |
+
timbre_rev_latent = quantized[0] + quantized[1] + quantized[3]
|
677 |
+
else:
|
678 |
+
timbre_rev_latent = quantized[1] + quantized[3]
|
679 |
+
if self.use_gr_x_timbre:
|
680 |
+
x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
|
681 |
+
else:
|
682 |
+
x_spk_pred = None
|
683 |
+
|
684 |
+
preds = {
|
685 |
+
"f0": f0_pred,
|
686 |
+
"uv": uv_pred,
|
687 |
+
"content": content_pred,
|
688 |
+
"timbre": spk_pred,
|
689 |
+
}
|
690 |
+
|
691 |
+
rev_preds = {
|
692 |
+
"rev_f0": rev_f0_pred,
|
693 |
+
"rev_uv": rev_uv_pred,
|
694 |
+
"rev_content": rev_content_pred,
|
695 |
+
"x_timbre": x_spk_pred,
|
696 |
+
}
|
697 |
+
return preds, rev_preds
|
698 |
+
|
699 |
+
def forward_v2(self, quantized, timbre):
|
700 |
+
prosody_latent = quantized[0]
|
701 |
+
content_latent = quantized[1]
|
702 |
+
residual_latent = quantized[2]
|
703 |
+
content_pred = self.phone_predictor(content_latent)[0]
|
704 |
+
|
705 |
+
spk_pred = self.timbre_predictor(timbre)
|
706 |
+
f0_pred, uv_pred = self.f0_predictor(prosody_latent)
|
707 |
+
|
708 |
+
prosody_rev_latent = torch.zeros_like(prosody_latent)
|
709 |
+
if self.use_gr_content_f0:
|
710 |
+
prosody_rev_latent += content_latent
|
711 |
+
if self.use_gr_residual_f0:
|
712 |
+
prosody_rev_latent += residual_latent
|
713 |
+
rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
|
714 |
+
|
715 |
+
content_rev_latent = torch.zeros_like(content_latent)
|
716 |
+
if self.use_gr_prosody_phone:
|
717 |
+
content_rev_latent += prosody_latent
|
718 |
+
if self.use_gr_residual_phone:
|
719 |
+
content_rev_latent += residual_latent
|
720 |
+
rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
|
721 |
+
|
722 |
+
timbre_rev_latent = prosody_latent + content_latent + residual_latent
|
723 |
+
if self.use_gr_x_timbre:
|
724 |
+
x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
|
725 |
+
else:
|
726 |
+
x_spk_pred = None
|
727 |
+
|
728 |
+
preds = {
|
729 |
+
"f0": f0_pred,
|
730 |
+
"uv": uv_pred,
|
731 |
+
"content": content_pred,
|
732 |
+
"timbre": spk_pred,
|
733 |
+
}
|
734 |
+
|
735 |
+
rev_preds = {
|
736 |
+
"rev_f0": rev_f0_pred,
|
737 |
+
"rev_uv": rev_uv_pred,
|
738 |
+
"rev_content": rev_content_pred,
|
739 |
+
"x_timbre": x_spk_pred,
|
740 |
+
}
|
741 |
+
return preds, rev_preds
|
models/codec/facodec/modules/style_encoder.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/styleencoder.py
|
7 |
+
|
8 |
+
from . import attentions
|
9 |
+
from torch import nn
|
10 |
+
import torch
|
11 |
+
from torch.nn import functional as F
|
12 |
+
|
13 |
+
|
14 |
+
class Mish(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super(Mish, self).__init__()
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
return x * torch.tanh(F.softplus(x))
|
20 |
+
|
21 |
+
|
22 |
+
class Conv1dGLU(nn.Module):
|
23 |
+
"""
|
24 |
+
Conv1d + GLU(Gated Linear Unit) with residual connection.
|
25 |
+
For GLU refer to https://arxiv.org/abs/1612.08083 paper.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, in_channels, out_channels, kernel_size, dropout):
|
29 |
+
super(Conv1dGLU, self).__init__()
|
30 |
+
self.out_channels = out_channels
|
31 |
+
self.conv1 = nn.Conv1d(
|
32 |
+
in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2
|
33 |
+
)
|
34 |
+
self.dropout = nn.Dropout(dropout)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
residual = x
|
38 |
+
x = self.conv1(x)
|
39 |
+
x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
|
40 |
+
x = x1 * torch.sigmoid(x2)
|
41 |
+
x = residual + self.dropout(x)
|
42 |
+
return x
|
43 |
+
|
44 |
+
|
45 |
+
class StyleEncoder(torch.nn.Module):
|
46 |
+
def __init__(self, in_dim=513, hidden_dim=128, out_dim=256):
|
47 |
+
|
48 |
+
super().__init__()
|
49 |
+
|
50 |
+
self.in_dim = in_dim # Linear 513 wav2vec 2.0 1024
|
51 |
+
self.hidden_dim = hidden_dim
|
52 |
+
self.out_dim = out_dim
|
53 |
+
self.kernel_size = 5
|
54 |
+
self.n_head = 2
|
55 |
+
self.dropout = 0.1
|
56 |
+
|
57 |
+
self.spectral = nn.Sequential(
|
58 |
+
nn.Conv1d(self.in_dim, self.hidden_dim, 1),
|
59 |
+
Mish(),
|
60 |
+
nn.Dropout(self.dropout),
|
61 |
+
nn.Conv1d(self.hidden_dim, self.hidden_dim, 1),
|
62 |
+
Mish(),
|
63 |
+
nn.Dropout(self.dropout),
|
64 |
+
)
|
65 |
+
|
66 |
+
self.temporal = nn.Sequential(
|
67 |
+
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
|
68 |
+
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
|
69 |
+
)
|
70 |
+
|
71 |
+
self.slf_attn = attentions.MultiHeadAttention(
|
72 |
+
self.hidden_dim,
|
73 |
+
self.hidden_dim,
|
74 |
+
self.n_head,
|
75 |
+
p_dropout=self.dropout,
|
76 |
+
proximal_bias=False,
|
77 |
+
proximal_init=True,
|
78 |
+
)
|
79 |
+
self.atten_drop = nn.Dropout(self.dropout)
|
80 |
+
self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1)
|
81 |
+
|
82 |
+
def forward(self, x, mask=None):
|
83 |
+
|
84 |
+
# spectral
|
85 |
+
x = self.spectral(x) * mask
|
86 |
+
# temporal
|
87 |
+
x = self.temporal(x) * mask
|
88 |
+
|
89 |
+
# self-attention
|
90 |
+
attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1)
|
91 |
+
y = self.slf_attn(x, x, attn_mask=attn_mask)
|
92 |
+
x = x + self.atten_drop(y)
|
93 |
+
|
94 |
+
# fc
|
95 |
+
x = self.fc(x)
|
96 |
+
|
97 |
+
# temoral average pooling
|
98 |
+
w = self.temporal_avg_pool(x, mask=mask)
|
99 |
+
|
100 |
+
return w
|
101 |
+
|
102 |
+
def temporal_avg_pool(self, x, mask=None):
|
103 |
+
if mask is None:
|
104 |
+
out = torch.mean(x, dim=2)
|
105 |
+
else:
|
106 |
+
len_ = mask.sum(dim=2)
|
107 |
+
x = x.sum(dim=2)
|
108 |
+
|
109 |
+
out = torch.div(x, len_)
|
110 |
+
return out
|
models/codec/facodec/modules/wavenet.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/modules.py
|
7 |
+
|
8 |
+
import math
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
|
13 |
+
from modules.dac.model.encodec import SConv1d
|
14 |
+
|
15 |
+
from . import commons
|
16 |
+
|
17 |
+
LRELU_SLOPE = 0.1
|
18 |
+
|
19 |
+
|
20 |
+
class LayerNorm(nn.Module):
|
21 |
+
def __init__(self, channels, eps=1e-5):
|
22 |
+
super().__init__()
|
23 |
+
self.channels = channels
|
24 |
+
self.eps = eps
|
25 |
+
|
26 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
27 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
x = x.transpose(1, -1)
|
31 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
32 |
+
return x.transpose(1, -1)
|
33 |
+
|
34 |
+
|
35 |
+
class ConvReluNorm(nn.Module):
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
in_channels,
|
39 |
+
hidden_channels,
|
40 |
+
out_channels,
|
41 |
+
kernel_size,
|
42 |
+
n_layers,
|
43 |
+
p_dropout,
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
self.in_channels = in_channels
|
47 |
+
self.hidden_channels = hidden_channels
|
48 |
+
self.out_channels = out_channels
|
49 |
+
self.kernel_size = kernel_size
|
50 |
+
self.n_layers = n_layers
|
51 |
+
self.p_dropout = p_dropout
|
52 |
+
assert n_layers > 1, "Number of layers should be larger than 0."
|
53 |
+
|
54 |
+
self.conv_layers = nn.ModuleList()
|
55 |
+
self.norm_layers = nn.ModuleList()
|
56 |
+
self.conv_layers.append(
|
57 |
+
nn.Conv1d(
|
58 |
+
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
|
59 |
+
)
|
60 |
+
)
|
61 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
62 |
+
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
|
63 |
+
for _ in range(n_layers - 1):
|
64 |
+
self.conv_layers.append(
|
65 |
+
nn.Conv1d(
|
66 |
+
hidden_channels,
|
67 |
+
hidden_channels,
|
68 |
+
kernel_size,
|
69 |
+
padding=kernel_size // 2,
|
70 |
+
)
|
71 |
+
)
|
72 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
73 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
74 |
+
self.proj.weight.data.zero_()
|
75 |
+
self.proj.bias.data.zero_()
|
76 |
+
|
77 |
+
def forward(self, x, x_mask):
|
78 |
+
x_org = x
|
79 |
+
for i in range(self.n_layers):
|
80 |
+
x = self.conv_layers[i](x * x_mask)
|
81 |
+
x = self.norm_layers[i](x)
|
82 |
+
x = self.relu_drop(x)
|
83 |
+
x = x_org + self.proj(x)
|
84 |
+
return x * x_mask
|
85 |
+
|
86 |
+
|
87 |
+
class DDSConv(nn.Module):
|
88 |
+
"""
|
89 |
+
Dialted and Depth-Separable Convolution
|
90 |
+
"""
|
91 |
+
|
92 |
+
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
|
93 |
+
super().__init__()
|
94 |
+
self.channels = channels
|
95 |
+
self.kernel_size = kernel_size
|
96 |
+
self.n_layers = n_layers
|
97 |
+
self.p_dropout = p_dropout
|
98 |
+
|
99 |
+
self.drop = nn.Dropout(p_dropout)
|
100 |
+
self.convs_sep = nn.ModuleList()
|
101 |
+
self.convs_1x1 = nn.ModuleList()
|
102 |
+
self.norms_1 = nn.ModuleList()
|
103 |
+
self.norms_2 = nn.ModuleList()
|
104 |
+
for i in range(n_layers):
|
105 |
+
dilation = kernel_size**i
|
106 |
+
padding = (kernel_size * dilation - dilation) // 2
|
107 |
+
self.convs_sep.append(
|
108 |
+
nn.Conv1d(
|
109 |
+
channels,
|
110 |
+
channels,
|
111 |
+
kernel_size,
|
112 |
+
groups=channels,
|
113 |
+
dilation=dilation,
|
114 |
+
padding=padding,
|
115 |
+
)
|
116 |
+
)
|
117 |
+
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
118 |
+
self.norms_1.append(LayerNorm(channels))
|
119 |
+
self.norms_2.append(LayerNorm(channels))
|
120 |
+
|
121 |
+
def forward(self, x, x_mask, g=None):
|
122 |
+
if g is not None:
|
123 |
+
x = x + g
|
124 |
+
for i in range(self.n_layers):
|
125 |
+
y = self.convs_sep[i](x * x_mask)
|
126 |
+
y = self.norms_1[i](y)
|
127 |
+
y = F.gelu(y)
|
128 |
+
y = self.convs_1x1[i](y)
|
129 |
+
y = self.norms_2[i](y)
|
130 |
+
y = F.gelu(y)
|
131 |
+
y = self.drop(y)
|
132 |
+
x = x + y
|
133 |
+
return x * x_mask
|
134 |
+
|
135 |
+
|
136 |
+
class WN(torch.nn.Module):
|
137 |
+
def __init__(
|
138 |
+
self,
|
139 |
+
hidden_channels,
|
140 |
+
kernel_size,
|
141 |
+
dilation_rate,
|
142 |
+
n_layers,
|
143 |
+
gin_channels=0,
|
144 |
+
p_dropout=0,
|
145 |
+
causal=False,
|
146 |
+
):
|
147 |
+
super(WN, self).__init__()
|
148 |
+
conv1d_type = SConv1d
|
149 |
+
assert kernel_size % 2 == 1
|
150 |
+
self.hidden_channels = hidden_channels
|
151 |
+
self.kernel_size = (kernel_size,)
|
152 |
+
self.dilation_rate = dilation_rate
|
153 |
+
self.n_layers = n_layers
|
154 |
+
self.gin_channels = gin_channels
|
155 |
+
self.p_dropout = p_dropout
|
156 |
+
|
157 |
+
self.in_layers = torch.nn.ModuleList()
|
158 |
+
self.res_skip_layers = torch.nn.ModuleList()
|
159 |
+
self.drop = nn.Dropout(p_dropout)
|
160 |
+
|
161 |
+
if gin_channels != 0:
|
162 |
+
self.cond_layer = conv1d_type(
|
163 |
+
gin_channels, 2 * hidden_channels * n_layers, 1, norm="weight_norm"
|
164 |
+
)
|
165 |
+
|
166 |
+
for i in range(n_layers):
|
167 |
+
dilation = dilation_rate**i
|
168 |
+
padding = int((kernel_size * dilation - dilation) / 2)
|
169 |
+
in_layer = conv1d_type(
|
170 |
+
hidden_channels,
|
171 |
+
2 * hidden_channels,
|
172 |
+
kernel_size,
|
173 |
+
dilation=dilation,
|
174 |
+
padding=padding,
|
175 |
+
norm="weight_norm",
|
176 |
+
causal=causal,
|
177 |
+
)
|
178 |
+
self.in_layers.append(in_layer)
|
179 |
+
|
180 |
+
# last one is not necessary
|
181 |
+
if i < n_layers - 1:
|
182 |
+
res_skip_channels = 2 * hidden_channels
|
183 |
+
else:
|
184 |
+
res_skip_channels = hidden_channels
|
185 |
+
|
186 |
+
res_skip_layer = conv1d_type(
|
187 |
+
hidden_channels, res_skip_channels, 1, norm="weight_norm", causal=causal
|
188 |
+
)
|
189 |
+
self.res_skip_layers.append(res_skip_layer)
|
190 |
+
|
191 |
+
def forward(self, x, x_mask, g=None, **kwargs):
|
192 |
+
output = torch.zeros_like(x)
|
193 |
+
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
194 |
+
|
195 |
+
if g is not None:
|
196 |
+
g = self.cond_layer(g)
|
197 |
+
|
198 |
+
for i in range(self.n_layers):
|
199 |
+
x_in = self.in_layers[i](x)
|
200 |
+
if g is not None:
|
201 |
+
cond_offset = i * 2 * self.hidden_channels
|
202 |
+
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
203 |
+
else:
|
204 |
+
g_l = torch.zeros_like(x_in)
|
205 |
+
|
206 |
+
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
|
207 |
+
acts = self.drop(acts)
|
208 |
+
|
209 |
+
res_skip_acts = self.res_skip_layers[i](acts)
|
210 |
+
if i < self.n_layers - 1:
|
211 |
+
res_acts = res_skip_acts[:, : self.hidden_channels, :]
|
212 |
+
x = (x + res_acts) * x_mask
|
213 |
+
output = output + res_skip_acts[:, self.hidden_channels :, :]
|
214 |
+
else:
|
215 |
+
output = output + res_skip_acts
|
216 |
+
return output * x_mask
|
217 |
+
|
218 |
+
def remove_weight_norm(self):
|
219 |
+
if self.gin_channels != 0:
|
220 |
+
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
221 |
+
for l in self.in_layers:
|
222 |
+
torch.nn.utils.remove_weight_norm(l)
|
223 |
+
for l in self.res_skip_layers:
|
224 |
+
torch.nn.utils.remove_weight_norm(l)
|
models/codec/facodec/optimizer.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os, sys
|
7 |
+
import os.path as osp
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
from torch.optim import Optimizer
|
12 |
+
from functools import reduce
|
13 |
+
from torch.optim import AdamW
|
14 |
+
|
15 |
+
|
16 |
+
class MultiOptimizer:
|
17 |
+
def __init__(self, optimizers={}, schedulers={}):
|
18 |
+
self.optimizers = optimizers
|
19 |
+
self.schedulers = schedulers
|
20 |
+
self.keys = list(optimizers.keys())
|
21 |
+
self.param_groups = reduce(
|
22 |
+
lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()]
|
23 |
+
)
|
24 |
+
|
25 |
+
def state_dict(self):
|
26 |
+
state_dicts = [(key, self.optimizers[key].state_dict()) for key in self.keys]
|
27 |
+
return state_dicts
|
28 |
+
|
29 |
+
def scheduler_state_dict(self):
|
30 |
+
state_dicts = [(key, self.schedulers[key].state_dict()) for key in self.keys]
|
31 |
+
return state_dicts
|
32 |
+
|
33 |
+
def load_state_dict(self, state_dict):
|
34 |
+
for key, val in state_dict:
|
35 |
+
try:
|
36 |
+
self.optimizers[key].load_state_dict(val)
|
37 |
+
except:
|
38 |
+
print("Unloaded %s" % key)
|
39 |
+
|
40 |
+
def load_scheduler_state_dict(self, state_dict):
|
41 |
+
for key, val in state_dict:
|
42 |
+
try:
|
43 |
+
self.schedulers[key].load_state_dict(val)
|
44 |
+
except:
|
45 |
+
print("Unloaded %s" % key)
|
46 |
+
|
47 |
+
def step(self, key=None, scaler=None):
|
48 |
+
keys = [key] if key is not None else self.keys
|
49 |
+
_ = [self._step(key, scaler) for key in keys]
|
50 |
+
|
51 |
+
def _step(self, key, scaler=None):
|
52 |
+
if scaler is not None:
|
53 |
+
scaler.step(self.optimizers[key])
|
54 |
+
scaler.update()
|
55 |
+
else:
|
56 |
+
self.optimizers[key].step()
|
57 |
+
|
58 |
+
def zero_grad(self, key=None):
|
59 |
+
if key is not None:
|
60 |
+
self.optimizers[key].zero_grad()
|
61 |
+
else:
|
62 |
+
_ = [self.optimizers[key].zero_grad() for key in self.keys]
|
63 |
+
|
64 |
+
def scheduler(self, *args, key=None):
|
65 |
+
if key is not None:
|
66 |
+
self.schedulers[key].step(*args)
|
67 |
+
else:
|
68 |
+
_ = [self.schedulers[key].step_batch(*args) for key in self.keys]
|
69 |
+
|
70 |
+
|
71 |
+
def define_scheduler(optimizer, params):
|
72 |
+
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params["gamma"])
|
73 |
+
|
74 |
+
return scheduler
|
75 |
+
|
76 |
+
|
77 |
+
def build_optimizer(model_dict, scheduler_params_dict, lr, type="AdamW"):
|
78 |
+
optim = {}
|
79 |
+
for key, model in model_dict.items():
|
80 |
+
model_parameters = model.parameters()
|
81 |
+
parameters_names = []
|
82 |
+
parameters_names.append(
|
83 |
+
[name_param_pair[0] for name_param_pair in model.named_parameters()]
|
84 |
+
)
|
85 |
+
if type == "AdamW":
|
86 |
+
optim[key] = AdamW(
|
87 |
+
model_parameters,
|
88 |
+
lr=lr,
|
89 |
+
betas=(0.9, 0.98),
|
90 |
+
eps=1e-9,
|
91 |
+
weight_decay=0.1,
|
92 |
+
)
|
93 |
+
else:
|
94 |
+
raise ValueError("Unknown optimizer type: %s" % type)
|
95 |
+
|
96 |
+
schedulers = dict(
|
97 |
+
[
|
98 |
+
(key, torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999996))
|
99 |
+
for key, opt in optim.items()
|
100 |
+
]
|
101 |
+
)
|
102 |
+
|
103 |
+
multi_optim = MultiOptimizer(optim, schedulers)
|
104 |
+
return multi_optim
|
models/codec/kmeans/repcodec_model.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from concurrent.futures import ALL_COMPLETED
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from einops import rearrange, repeat
|
13 |
+
|
14 |
+
from models.codec.amphion_codec.quantize import ResidualVQ
|
15 |
+
from models.codec.kmeans.vocos import VocosBackbone
|
16 |
+
|
17 |
+
|
18 |
+
def init_weights(m):
|
19 |
+
if isinstance(m, nn.Conv1d):
|
20 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
21 |
+
nn.init.constant_(m.bias, 0)
|
22 |
+
if isinstance(m, nn.Linear):
|
23 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
24 |
+
nn.init.constant_(m.bias, 0)
|
25 |
+
|
26 |
+
|
27 |
+
def compute_codebook_perplexity(indices, codebook_size):
|
28 |
+
indices = indices.flatten()
|
29 |
+
prob = torch.bincount(indices, minlength=codebook_size).float() / indices.size(0)
|
30 |
+
perp = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10)))
|
31 |
+
return perp
|
32 |
+
|
33 |
+
|
34 |
+
class RepCodec(nn.Module):
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
codebook_size=8192,
|
38 |
+
hidden_size=1024,
|
39 |
+
codebook_dim=8,
|
40 |
+
vocos_dim=384,
|
41 |
+
vocos_intermediate_dim=2048,
|
42 |
+
vocos_num_layers=12,
|
43 |
+
num_quantizers=1,
|
44 |
+
downsample_scale=1,
|
45 |
+
cfg=None,
|
46 |
+
):
|
47 |
+
super().__init__()
|
48 |
+
codebook_size = (
|
49 |
+
cfg.codebook_size
|
50 |
+
if cfg is not None and hasattr(cfg, "codebook_size")
|
51 |
+
else codebook_size
|
52 |
+
)
|
53 |
+
codebook_dim = (
|
54 |
+
cfg.codebook_dim
|
55 |
+
if cfg is not None and hasattr(cfg, "codebook_dim")
|
56 |
+
else codebook_dim
|
57 |
+
)
|
58 |
+
hidden_size = (
|
59 |
+
cfg.hidden_size
|
60 |
+
if cfg is not None and hasattr(cfg, "hidden_size")
|
61 |
+
else hidden_size
|
62 |
+
)
|
63 |
+
vocos_dim = (
|
64 |
+
cfg.vocos_dim
|
65 |
+
if cfg is not None and hasattr(cfg, "vocos_dim")
|
66 |
+
else vocos_dim
|
67 |
+
)
|
68 |
+
vocos_intermediate_dim = (
|
69 |
+
cfg.vocos_intermediate_dim
|
70 |
+
if cfg is not None and hasattr(cfg, "vocos_dim")
|
71 |
+
else vocos_intermediate_dim
|
72 |
+
)
|
73 |
+
vocos_num_layers = (
|
74 |
+
cfg.vocos_num_layers
|
75 |
+
if cfg is not None and hasattr(cfg, "vocos_dim")
|
76 |
+
else vocos_num_layers
|
77 |
+
)
|
78 |
+
num_quantizers = (
|
79 |
+
cfg.num_quantizers
|
80 |
+
if cfg is not None and hasattr(cfg, "num_quantizers")
|
81 |
+
else num_quantizers
|
82 |
+
)
|
83 |
+
downsample_scale = (
|
84 |
+
cfg.downsample_scale
|
85 |
+
if cfg is not None and hasattr(cfg, "downsample_scale")
|
86 |
+
else downsample_scale
|
87 |
+
)
|
88 |
+
|
89 |
+
self.codebook_size = codebook_size
|
90 |
+
self.codebook_dim = codebook_dim
|
91 |
+
self.hidden_size = hidden_size
|
92 |
+
self.vocos_dim = vocos_dim
|
93 |
+
self.vocos_intermediate_dim = vocos_intermediate_dim
|
94 |
+
self.vocos_num_layers = vocos_num_layers
|
95 |
+
self.num_quantizers = num_quantizers
|
96 |
+
self.downsample_scale = downsample_scale
|
97 |
+
|
98 |
+
if self.downsample_scale != None and self.downsample_scale > 1:
|
99 |
+
self.down = nn.Conv1d(
|
100 |
+
self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
|
101 |
+
)
|
102 |
+
self.up = nn.Conv1d(
|
103 |
+
self.hidden_size, self.hidden_size, kernel_size=3, stride=1, padding=1
|
104 |
+
)
|
105 |
+
|
106 |
+
self.encoder = nn.Sequential(
|
107 |
+
VocosBackbone(
|
108 |
+
input_channels=self.hidden_size,
|
109 |
+
dim=self.vocos_dim,
|
110 |
+
intermediate_dim=self.vocos_intermediate_dim,
|
111 |
+
num_layers=self.vocos_num_layers,
|
112 |
+
adanorm_num_embeddings=None,
|
113 |
+
),
|
114 |
+
nn.Linear(self.vocos_dim, self.hidden_size),
|
115 |
+
)
|
116 |
+
self.decoder = nn.Sequential(
|
117 |
+
VocosBackbone(
|
118 |
+
input_channels=self.hidden_size,
|
119 |
+
dim=self.vocos_dim,
|
120 |
+
intermediate_dim=self.vocos_intermediate_dim,
|
121 |
+
num_layers=self.vocos_num_layers,
|
122 |
+
adanorm_num_embeddings=None,
|
123 |
+
),
|
124 |
+
nn.Linear(self.vocos_dim, self.hidden_size),
|
125 |
+
)
|
126 |
+
|
127 |
+
self.quantizer = ResidualVQ(
|
128 |
+
input_dim=hidden_size,
|
129 |
+
num_quantizers=num_quantizers,
|
130 |
+
codebook_size=codebook_size,
|
131 |
+
codebook_dim=codebook_dim,
|
132 |
+
quantizer_type="fvq",
|
133 |
+
quantizer_dropout=0.0,
|
134 |
+
commitment=0.15,
|
135 |
+
codebook_loss_weight=1.0,
|
136 |
+
use_l2_normlize=True,
|
137 |
+
)
|
138 |
+
|
139 |
+
self.reset_parameters()
|
140 |
+
|
141 |
+
def forward(self, x):
|
142 |
+
|
143 |
+
# downsample
|
144 |
+
if self.downsample_scale != None and self.downsample_scale > 1:
|
145 |
+
x = x.transpose(1, 2)
|
146 |
+
x = self.down(x)
|
147 |
+
x = F.gelu(x)
|
148 |
+
x = x.transpose(1, 2)
|
149 |
+
|
150 |
+
# encoder
|
151 |
+
x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
|
152 |
+
|
153 |
+
# vq
|
154 |
+
(
|
155 |
+
quantized_out,
|
156 |
+
all_indices,
|
157 |
+
all_commit_losses,
|
158 |
+
all_codebook_losses,
|
159 |
+
_,
|
160 |
+
) = self.quantizer(x)
|
161 |
+
|
162 |
+
# decoder
|
163 |
+
x = self.decoder(quantized_out)
|
164 |
+
|
165 |
+
# up
|
166 |
+
if self.downsample_scale != None and self.downsample_scale > 1:
|
167 |
+
x = x.transpose(1, 2)
|
168 |
+
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
169 |
+
x_rec = self.up(x).transpose(1, 2)
|
170 |
+
|
171 |
+
codebook_loss = (all_codebook_losses + all_commit_losses).mean()
|
172 |
+
all_indices = all_indices
|
173 |
+
|
174 |
+
return x_rec, codebook_loss, all_indices
|
175 |
+
|
176 |
+
def quantize(self, x):
|
177 |
+
|
178 |
+
if self.downsample_scale != None and self.downsample_scale > 1:
|
179 |
+
x = x.transpose(1, 2)
|
180 |
+
x = self.down(x)
|
181 |
+
x = F.gelu(x)
|
182 |
+
x = x.transpose(1, 2)
|
183 |
+
|
184 |
+
x = self.encoder(x.transpose(1, 2)).transpose(1, 2)
|
185 |
+
|
186 |
+
(
|
187 |
+
quantized_out,
|
188 |
+
all_indices,
|
189 |
+
all_commit_losses,
|
190 |
+
all_codebook_losses,
|
191 |
+
_,
|
192 |
+
) = self.quantizer(x)
|
193 |
+
|
194 |
+
if all_indices.shape[0] == 1:
|
195 |
+
return all_indices.squeeze(0), quantized_out.transpose(1, 2)
|
196 |
+
return all_indices, quantized_out.transpose(1, 2)
|
197 |
+
|
198 |
+
def reset_parameters(self):
|
199 |
+
self.apply(init_weights)
|
200 |
+
|
201 |
+
|
202 |
+
if __name__ == "__main__":
|
203 |
+
repcodec = RepCodec(vocos_dim=1024, downsample_scale=2)
|
204 |
+
print(repcodec)
|
205 |
+
print(sum(p.numel() for p in repcodec.parameters()) / 1e6)
|
206 |
+
x = torch.randn(5, 10, 1024)
|
207 |
+
x_rec, codebook_loss, all_indices = repcodec(x)
|
208 |
+
print(x_rec.shape, codebook_loss, all_indices.shape)
|
209 |
+
vq_id, emb = repcodec.quantize(x)
|
210 |
+
print(vq_id.shape, emb.shape)
|
models/codec/kmeans/vocos.py
ADDED
@@ -0,0 +1,850 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from typing import Optional, Tuple
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import scipy
|
10 |
+
import torch
|
11 |
+
from torch import nn, view_as_real, view_as_complex
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
14 |
+
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
|
15 |
+
|
16 |
+
|
17 |
+
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
|
18 |
+
"""
|
19 |
+
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
x (Tensor): Input tensor.
|
23 |
+
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
Tensor: Element-wise logarithm of the input tensor with clipping applied.
|
27 |
+
"""
|
28 |
+
return torch.log(torch.clip(x, min=clip_val))
|
29 |
+
|
30 |
+
|
31 |
+
def symlog(x: torch.Tensor) -> torch.Tensor:
|
32 |
+
return torch.sign(x) * torch.log1p(x.abs())
|
33 |
+
|
34 |
+
|
35 |
+
def symexp(x: torch.Tensor) -> torch.Tensor:
|
36 |
+
return torch.sign(x) * (torch.exp(x.abs()) - 1)
|
37 |
+
|
38 |
+
|
39 |
+
class STFT(nn.Module):
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
n_fft: int,
|
43 |
+
hop_length: int,
|
44 |
+
win_length: int,
|
45 |
+
center=True,
|
46 |
+
):
|
47 |
+
super().__init__()
|
48 |
+
self.center = center
|
49 |
+
self.n_fft = n_fft
|
50 |
+
self.hop_length = hop_length
|
51 |
+
self.win_length = win_length
|
52 |
+
window = torch.hann_window(win_length)
|
53 |
+
self.register_buffer("window", window)
|
54 |
+
|
55 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
56 |
+
# x: (B, T * hop_length)
|
57 |
+
|
58 |
+
if not self.center:
|
59 |
+
pad = self.win_length - self.hop_length
|
60 |
+
x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
|
61 |
+
|
62 |
+
stft_spec = torch.stft(
|
63 |
+
x,
|
64 |
+
self.n_fft,
|
65 |
+
hop_length=self.hop_length,
|
66 |
+
win_length=self.win_length,
|
67 |
+
window=self.window,
|
68 |
+
center=self.center,
|
69 |
+
return_complex=False,
|
70 |
+
) # (B, n_fft // 2 + 1, T, 2)
|
71 |
+
|
72 |
+
rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2)
|
73 |
+
imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2)
|
74 |
+
|
75 |
+
log_mag = torch.log(
|
76 |
+
torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
|
77 |
+
) # (B, n_fft // 2 + 1, T)
|
78 |
+
phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
|
79 |
+
|
80 |
+
return log_mag, phase
|
81 |
+
|
82 |
+
|
83 |
+
class ISTFT(nn.Module):
|
84 |
+
"""
|
85 |
+
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
|
86 |
+
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
|
87 |
+
See issue: https://github.com/pytorch/pytorch/issues/62323
|
88 |
+
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
|
89 |
+
The NOLA constraint is met as we trim padded samples anyway.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
n_fft (int): Size of Fourier transform.
|
93 |
+
hop_length (int): The distance between neighboring sliding window frames.
|
94 |
+
win_length (int): The size of window frame and STFT filter.
|
95 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
96 |
+
"""
|
97 |
+
|
98 |
+
def __init__(
|
99 |
+
self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
|
100 |
+
):
|
101 |
+
super().__init__()
|
102 |
+
if padding not in ["center", "same"]:
|
103 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
104 |
+
self.padding = padding
|
105 |
+
self.n_fft = n_fft
|
106 |
+
self.hop_length = hop_length
|
107 |
+
self.win_length = win_length
|
108 |
+
window = torch.hann_window(win_length)
|
109 |
+
self.register_buffer("window", window)
|
110 |
+
|
111 |
+
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
112 |
+
"""
|
113 |
+
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
|
117 |
+
N is the number of frequency bins, and T is the number of time frames.
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
|
121 |
+
"""
|
122 |
+
if self.padding == "center":
|
123 |
+
# Fallback to pytorch native implementation
|
124 |
+
return torch.istft(
|
125 |
+
spec,
|
126 |
+
self.n_fft,
|
127 |
+
self.hop_length,
|
128 |
+
self.win_length,
|
129 |
+
self.window,
|
130 |
+
center=True,
|
131 |
+
)
|
132 |
+
elif self.padding == "same":
|
133 |
+
pad = (self.win_length - self.hop_length) // 2
|
134 |
+
else:
|
135 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
136 |
+
|
137 |
+
assert spec.dim() == 3, "Expected a 3D tensor as input"
|
138 |
+
B, N, T = spec.shape
|
139 |
+
|
140 |
+
# Inverse FFT
|
141 |
+
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
|
142 |
+
ifft = ifft * self.window[None, :, None]
|
143 |
+
|
144 |
+
# Overlap and Add
|
145 |
+
output_size = (T - 1) * self.hop_length + self.win_length
|
146 |
+
y = torch.nn.functional.fold(
|
147 |
+
ifft,
|
148 |
+
output_size=(1, output_size),
|
149 |
+
kernel_size=(1, self.win_length),
|
150 |
+
stride=(1, self.hop_length),
|
151 |
+
)[:, 0, 0, pad:-pad]
|
152 |
+
|
153 |
+
# Window envelope
|
154 |
+
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
|
155 |
+
window_envelope = torch.nn.functional.fold(
|
156 |
+
window_sq,
|
157 |
+
output_size=(1, output_size),
|
158 |
+
kernel_size=(1, self.win_length),
|
159 |
+
stride=(1, self.hop_length),
|
160 |
+
).squeeze()[pad:-pad]
|
161 |
+
|
162 |
+
# Normalize
|
163 |
+
assert (window_envelope > 1e-11).all()
|
164 |
+
y = y / window_envelope
|
165 |
+
|
166 |
+
return y
|
167 |
+
|
168 |
+
|
169 |
+
class MDCT(nn.Module):
|
170 |
+
"""
|
171 |
+
Modified Discrete Cosine Transform (MDCT) module.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
frame_len (int): Length of the MDCT frame.
|
175 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
176 |
+
"""
|
177 |
+
|
178 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
179 |
+
super().__init__()
|
180 |
+
if padding not in ["center", "same"]:
|
181 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
182 |
+
self.padding = padding
|
183 |
+
self.frame_len = frame_len
|
184 |
+
N = frame_len // 2
|
185 |
+
n0 = (N + 1) / 2
|
186 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
187 |
+
self.register_buffer("window", window)
|
188 |
+
|
189 |
+
pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
|
190 |
+
post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
|
191 |
+
# view_as_real: NCCL Backend does not support ComplexFloat data type
|
192 |
+
# https://github.com/pytorch/pytorch/issues/71613
|
193 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
194 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
195 |
+
|
196 |
+
def forward(self, audio: torch.Tensor) -> torch.Tensor:
|
197 |
+
"""
|
198 |
+
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
|
202 |
+
and T is the length of the audio.
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
|
206 |
+
and N is the number of frequency bins.
|
207 |
+
"""
|
208 |
+
if self.padding == "center":
|
209 |
+
audio = torch.nn.functional.pad(
|
210 |
+
audio, (self.frame_len // 2, self.frame_len // 2)
|
211 |
+
)
|
212 |
+
elif self.padding == "same":
|
213 |
+
# hop_length is 1/2 frame_len
|
214 |
+
audio = torch.nn.functional.pad(
|
215 |
+
audio, (self.frame_len // 4, self.frame_len // 4)
|
216 |
+
)
|
217 |
+
else:
|
218 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
219 |
+
|
220 |
+
x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
|
221 |
+
N = self.frame_len // 2
|
222 |
+
x = x * self.window.expand(x.shape)
|
223 |
+
X = torch.fft.fft(
|
224 |
+
x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
|
225 |
+
)[..., :N]
|
226 |
+
res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
|
227 |
+
return torch.real(res) * np.sqrt(2)
|
228 |
+
|
229 |
+
|
230 |
+
class IMDCT(nn.Module):
|
231 |
+
"""
|
232 |
+
Inverse Modified Discrete Cosine Transform (IMDCT) module.
|
233 |
+
|
234 |
+
Args:
|
235 |
+
frame_len (int): Length of the MDCT frame.
|
236 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
237 |
+
"""
|
238 |
+
|
239 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
240 |
+
super().__init__()
|
241 |
+
if padding not in ["center", "same"]:
|
242 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
243 |
+
self.padding = padding
|
244 |
+
self.frame_len = frame_len
|
245 |
+
N = frame_len // 2
|
246 |
+
n0 = (N + 1) / 2
|
247 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
248 |
+
self.register_buffer("window", window)
|
249 |
+
|
250 |
+
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
|
251 |
+
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
|
252 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
253 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
254 |
+
|
255 |
+
def forward(self, X: torch.Tensor) -> torch.Tensor:
|
256 |
+
"""
|
257 |
+
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
|
261 |
+
L is the number of frames, and N is the number of frequency bins.
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
|
265 |
+
"""
|
266 |
+
B, L, N = X.shape
|
267 |
+
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
|
268 |
+
Y[..., :N] = X
|
269 |
+
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
|
270 |
+
y = torch.fft.ifft(
|
271 |
+
Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
|
272 |
+
)
|
273 |
+
y = (
|
274 |
+
torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
|
275 |
+
* np.sqrt(N)
|
276 |
+
* np.sqrt(2)
|
277 |
+
)
|
278 |
+
result = y * self.window.expand(y.shape)
|
279 |
+
output_size = (1, (L + 1) * N)
|
280 |
+
audio = torch.nn.functional.fold(
|
281 |
+
result.transpose(1, 2),
|
282 |
+
output_size=output_size,
|
283 |
+
kernel_size=(1, self.frame_len),
|
284 |
+
stride=(1, self.frame_len // 2),
|
285 |
+
)[:, 0, 0, :]
|
286 |
+
|
287 |
+
if self.padding == "center":
|
288 |
+
pad = self.frame_len // 2
|
289 |
+
elif self.padding == "same":
|
290 |
+
pad = self.frame_len // 4
|
291 |
+
else:
|
292 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
293 |
+
|
294 |
+
audio = audio[:, pad:-pad]
|
295 |
+
return audio
|
296 |
+
|
297 |
+
|
298 |
+
class FourierHead(nn.Module):
|
299 |
+
"""Base class for inverse fourier modules."""
|
300 |
+
|
301 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
302 |
+
"""
|
303 |
+
Args:
|
304 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
305 |
+
L is the sequence length, and H denotes the model dimension.
|
306 |
+
|
307 |
+
Returns:
|
308 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
309 |
+
"""
|
310 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
311 |
+
|
312 |
+
|
313 |
+
class ISTFTHead(FourierHead):
|
314 |
+
"""
|
315 |
+
ISTFT Head module for predicting STFT complex coefficients.
|
316 |
+
|
317 |
+
Args:
|
318 |
+
dim (int): Hidden dimension of the model.
|
319 |
+
n_fft (int): Size of Fourier transform.
|
320 |
+
hop_length (int): The distance between neighboring sliding window frames, which should align with
|
321 |
+
the resolution of the input features.
|
322 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
323 |
+
"""
|
324 |
+
|
325 |
+
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
|
326 |
+
super().__init__()
|
327 |
+
out_dim = n_fft + 2
|
328 |
+
self.out = torch.nn.Linear(dim, out_dim)
|
329 |
+
self.istft = ISTFT(
|
330 |
+
n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
|
331 |
+
)
|
332 |
+
|
333 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
334 |
+
"""
|
335 |
+
Forward pass of the ISTFTHead module.
|
336 |
+
|
337 |
+
Args:
|
338 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
339 |
+
L is the sequence length, and H denotes the model dimension.
|
340 |
+
|
341 |
+
Returns:
|
342 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
343 |
+
"""
|
344 |
+
x = self.out(x).transpose(1, 2)
|
345 |
+
mag, p = x.chunk(2, dim=1)
|
346 |
+
mag = torch.exp(mag)
|
347 |
+
mag = torch.clip(
|
348 |
+
mag, max=1e2
|
349 |
+
) # safeguard to prevent excessively large magnitudes
|
350 |
+
# wrapping happens here. These two lines produce real and imaginary value
|
351 |
+
x = torch.cos(p)
|
352 |
+
y = torch.sin(p)
|
353 |
+
# recalculating phase here does not produce anything new
|
354 |
+
# only costs time
|
355 |
+
# phase = torch.atan2(y, x)
|
356 |
+
# S = mag * torch.exp(phase * 1j)
|
357 |
+
# better directly produce the complex value
|
358 |
+
S = mag * (x + 1j * y)
|
359 |
+
audio = self.istft(S)
|
360 |
+
return audio
|
361 |
+
|
362 |
+
|
363 |
+
class IMDCTSymExpHead(FourierHead):
|
364 |
+
"""
|
365 |
+
IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
|
366 |
+
|
367 |
+
Args:
|
368 |
+
dim (int): Hidden dimension of the model.
|
369 |
+
mdct_frame_len (int): Length of the MDCT frame.
|
370 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
371 |
+
sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
|
372 |
+
based on perceptual scaling. Defaults to None.
|
373 |
+
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
374 |
+
"""
|
375 |
+
|
376 |
+
def __init__(
|
377 |
+
self,
|
378 |
+
dim: int,
|
379 |
+
mdct_frame_len: int,
|
380 |
+
padding: str = "same",
|
381 |
+
sample_rate: Optional[int] = None,
|
382 |
+
clip_audio: bool = False,
|
383 |
+
):
|
384 |
+
super().__init__()
|
385 |
+
out_dim = mdct_frame_len // 2
|
386 |
+
self.out = nn.Linear(dim, out_dim)
|
387 |
+
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
388 |
+
self.clip_audio = clip_audio
|
389 |
+
|
390 |
+
if sample_rate is not None:
|
391 |
+
# optionally init the last layer following mel-scale
|
392 |
+
m_max = _hz_to_mel(sample_rate // 2)
|
393 |
+
m_pts = torch.linspace(0, m_max, out_dim)
|
394 |
+
f_pts = _mel_to_hz(m_pts)
|
395 |
+
scale = 1 - (f_pts / f_pts.max())
|
396 |
+
|
397 |
+
with torch.no_grad():
|
398 |
+
self.out.weight.mul_(scale.view(-1, 1))
|
399 |
+
|
400 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
401 |
+
"""
|
402 |
+
Forward pass of the IMDCTSymExpHead module.
|
403 |
+
|
404 |
+
Args:
|
405 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
406 |
+
L is the sequence length, and H denotes the model dimension.
|
407 |
+
|
408 |
+
Returns:
|
409 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
410 |
+
"""
|
411 |
+
x = self.out(x)
|
412 |
+
x = symexp(x)
|
413 |
+
x = torch.clip(
|
414 |
+
x, min=-1e2, max=1e2
|
415 |
+
) # safeguard to prevent excessively large magnitudes
|
416 |
+
audio = self.imdct(x)
|
417 |
+
if self.clip_audio:
|
418 |
+
audio = torch.clip(x, min=-1.0, max=1.0)
|
419 |
+
|
420 |
+
return audio
|
421 |
+
|
422 |
+
|
423 |
+
class IMDCTCosHead(FourierHead):
|
424 |
+
"""
|
425 |
+
IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
|
426 |
+
|
427 |
+
Args:
|
428 |
+
dim (int): Hidden dimension of the model.
|
429 |
+
mdct_frame_len (int): Length of the MDCT frame.
|
430 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
431 |
+
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
432 |
+
"""
|
433 |
+
|
434 |
+
def __init__(
|
435 |
+
self,
|
436 |
+
dim: int,
|
437 |
+
mdct_frame_len: int,
|
438 |
+
padding: str = "same",
|
439 |
+
clip_audio: bool = False,
|
440 |
+
):
|
441 |
+
super().__init__()
|
442 |
+
self.clip_audio = clip_audio
|
443 |
+
self.out = nn.Linear(dim, mdct_frame_len)
|
444 |
+
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
445 |
+
|
446 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
447 |
+
"""
|
448 |
+
Forward pass of the IMDCTCosHead module.
|
449 |
+
|
450 |
+
Args:
|
451 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
452 |
+
L is the sequence length, and H denotes the model dimension.
|
453 |
+
|
454 |
+
Returns:
|
455 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
456 |
+
"""
|
457 |
+
x = self.out(x)
|
458 |
+
m, p = x.chunk(2, dim=2)
|
459 |
+
m = torch.exp(m).clip(
|
460 |
+
max=1e2
|
461 |
+
) # safeguard to prevent excessively large magnitudes
|
462 |
+
audio = self.imdct(m * torch.cos(p))
|
463 |
+
if self.clip_audio:
|
464 |
+
audio = torch.clip(x, min=-1.0, max=1.0)
|
465 |
+
return audio
|
466 |
+
|
467 |
+
|
468 |
+
class ConvNeXtBlock(nn.Module):
|
469 |
+
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
470 |
+
|
471 |
+
Args:
|
472 |
+
dim (int): Number of input channels.
|
473 |
+
intermediate_dim (int): Dimensionality of the intermediate layer.
|
474 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
475 |
+
Defaults to None.
|
476 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
477 |
+
None means non-conditional LayerNorm. Defaults to None.
|
478 |
+
"""
|
479 |
+
|
480 |
+
def __init__(
|
481 |
+
self,
|
482 |
+
dim: int,
|
483 |
+
intermediate_dim: int,
|
484 |
+
layer_scale_init_value: float,
|
485 |
+
adanorm_num_embeddings: Optional[int] = None,
|
486 |
+
):
|
487 |
+
super().__init__()
|
488 |
+
self.dwconv = nn.Conv1d(
|
489 |
+
dim, dim, kernel_size=7, padding=3, groups=dim
|
490 |
+
) # depthwise conv
|
491 |
+
self.adanorm = adanorm_num_embeddings is not None
|
492 |
+
if adanorm_num_embeddings:
|
493 |
+
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
494 |
+
else:
|
495 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
496 |
+
self.pwconv1 = nn.Linear(
|
497 |
+
dim, intermediate_dim
|
498 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
499 |
+
self.act = nn.GELU()
|
500 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
501 |
+
self.gamma = (
|
502 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
503 |
+
if layer_scale_init_value > 0
|
504 |
+
else None
|
505 |
+
)
|
506 |
+
|
507 |
+
def forward(
|
508 |
+
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
509 |
+
) -> torch.Tensor:
|
510 |
+
residual = x
|
511 |
+
x = self.dwconv(x)
|
512 |
+
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
513 |
+
if self.adanorm:
|
514 |
+
assert cond_embedding_id is not None
|
515 |
+
x = self.norm(x, cond_embedding_id)
|
516 |
+
else:
|
517 |
+
x = self.norm(x)
|
518 |
+
x = self.pwconv1(x)
|
519 |
+
x = self.act(x)
|
520 |
+
x = self.pwconv2(x)
|
521 |
+
if self.gamma is not None:
|
522 |
+
x = self.gamma * x
|
523 |
+
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
524 |
+
|
525 |
+
x = residual + x
|
526 |
+
return x
|
527 |
+
|
528 |
+
|
529 |
+
class AdaLayerNorm(nn.Module):
|
530 |
+
"""
|
531 |
+
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
|
532 |
+
|
533 |
+
Args:
|
534 |
+
num_embeddings (int): Number of embeddings.
|
535 |
+
embedding_dim (int): Dimension of the embeddings.
|
536 |
+
"""
|
537 |
+
|
538 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
|
539 |
+
super().__init__()
|
540 |
+
self.eps = eps
|
541 |
+
self.dim = embedding_dim
|
542 |
+
self.scale = nn.Embedding(
|
543 |
+
num_embeddings=num_embeddings, embedding_dim=embedding_dim
|
544 |
+
)
|
545 |
+
self.shift = nn.Embedding(
|
546 |
+
num_embeddings=num_embeddings, embedding_dim=embedding_dim
|
547 |
+
)
|
548 |
+
torch.nn.init.ones_(self.scale.weight)
|
549 |
+
torch.nn.init.zeros_(self.shift.weight)
|
550 |
+
|
551 |
+
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
|
552 |
+
scale = self.scale(cond_embedding_id)
|
553 |
+
shift = self.shift(cond_embedding_id)
|
554 |
+
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
|
555 |
+
x = x * scale + shift
|
556 |
+
return x
|
557 |
+
|
558 |
+
|
559 |
+
class ResBlock1(nn.Module):
|
560 |
+
"""
|
561 |
+
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
|
562 |
+
but without upsampling layers.
|
563 |
+
|
564 |
+
Args:
|
565 |
+
dim (int): Number of input channels.
|
566 |
+
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
|
567 |
+
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
|
568 |
+
Defaults to (1, 3, 5).
|
569 |
+
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
|
570 |
+
Defaults to 0.1.
|
571 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
572 |
+
Defaults to None.
|
573 |
+
"""
|
574 |
+
|
575 |
+
def __init__(
|
576 |
+
self,
|
577 |
+
dim: int,
|
578 |
+
kernel_size: int = 3,
|
579 |
+
dilation: Tuple[int, int, int] = (1, 3, 5),
|
580 |
+
lrelu_slope: float = 0.1,
|
581 |
+
layer_scale_init_value: Optional[float] = None,
|
582 |
+
):
|
583 |
+
super().__init__()
|
584 |
+
self.lrelu_slope = lrelu_slope
|
585 |
+
self.convs1 = nn.ModuleList(
|
586 |
+
[
|
587 |
+
weight_norm(
|
588 |
+
nn.Conv1d(
|
589 |
+
dim,
|
590 |
+
dim,
|
591 |
+
kernel_size,
|
592 |
+
1,
|
593 |
+
dilation=dilation[0],
|
594 |
+
padding=self.get_padding(kernel_size, dilation[0]),
|
595 |
+
)
|
596 |
+
),
|
597 |
+
weight_norm(
|
598 |
+
nn.Conv1d(
|
599 |
+
dim,
|
600 |
+
dim,
|
601 |
+
kernel_size,
|
602 |
+
1,
|
603 |
+
dilation=dilation[1],
|
604 |
+
padding=self.get_padding(kernel_size, dilation[1]),
|
605 |
+
)
|
606 |
+
),
|
607 |
+
weight_norm(
|
608 |
+
nn.Conv1d(
|
609 |
+
dim,
|
610 |
+
dim,
|
611 |
+
kernel_size,
|
612 |
+
1,
|
613 |
+
dilation=dilation[2],
|
614 |
+
padding=self.get_padding(kernel_size, dilation[2]),
|
615 |
+
)
|
616 |
+
),
|
617 |
+
]
|
618 |
+
)
|
619 |
+
|
620 |
+
self.convs2 = nn.ModuleList(
|
621 |
+
[
|
622 |
+
weight_norm(
|
623 |
+
nn.Conv1d(
|
624 |
+
dim,
|
625 |
+
dim,
|
626 |
+
kernel_size,
|
627 |
+
1,
|
628 |
+
dilation=1,
|
629 |
+
padding=self.get_padding(kernel_size, 1),
|
630 |
+
)
|
631 |
+
),
|
632 |
+
weight_norm(
|
633 |
+
nn.Conv1d(
|
634 |
+
dim,
|
635 |
+
dim,
|
636 |
+
kernel_size,
|
637 |
+
1,
|
638 |
+
dilation=1,
|
639 |
+
padding=self.get_padding(kernel_size, 1),
|
640 |
+
)
|
641 |
+
),
|
642 |
+
weight_norm(
|
643 |
+
nn.Conv1d(
|
644 |
+
dim,
|
645 |
+
dim,
|
646 |
+
kernel_size,
|
647 |
+
1,
|
648 |
+
dilation=1,
|
649 |
+
padding=self.get_padding(kernel_size, 1),
|
650 |
+
)
|
651 |
+
),
|
652 |
+
]
|
653 |
+
)
|
654 |
+
|
655 |
+
self.gamma = nn.ParameterList(
|
656 |
+
[
|
657 |
+
(
|
658 |
+
nn.Parameter(
|
659 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
660 |
+
)
|
661 |
+
if layer_scale_init_value is not None
|
662 |
+
else None
|
663 |
+
),
|
664 |
+
(
|
665 |
+
nn.Parameter(
|
666 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
667 |
+
)
|
668 |
+
if layer_scale_init_value is not None
|
669 |
+
else None
|
670 |
+
),
|
671 |
+
(
|
672 |
+
nn.Parameter(
|
673 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
674 |
+
)
|
675 |
+
if layer_scale_init_value is not None
|
676 |
+
else None
|
677 |
+
),
|
678 |
+
]
|
679 |
+
)
|
680 |
+
|
681 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
682 |
+
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
|
683 |
+
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
|
684 |
+
xt = c1(xt)
|
685 |
+
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
|
686 |
+
xt = c2(xt)
|
687 |
+
if gamma is not None:
|
688 |
+
xt = gamma * xt
|
689 |
+
x = xt + x
|
690 |
+
return x
|
691 |
+
|
692 |
+
def remove_weight_norm(self):
|
693 |
+
for l in self.convs1:
|
694 |
+
remove_weight_norm(l)
|
695 |
+
for l in self.convs2:
|
696 |
+
remove_weight_norm(l)
|
697 |
+
|
698 |
+
@staticmethod
|
699 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
700 |
+
return int((kernel_size * dilation - dilation) / 2)
|
701 |
+
|
702 |
+
|
703 |
+
class Backbone(nn.Module):
|
704 |
+
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
|
705 |
+
|
706 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
707 |
+
"""
|
708 |
+
Args:
|
709 |
+
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
|
710 |
+
C denotes output features, and L is the sequence length.
|
711 |
+
|
712 |
+
Returns:
|
713 |
+
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
|
714 |
+
and H denotes the model dimension.
|
715 |
+
"""
|
716 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
717 |
+
|
718 |
+
|
719 |
+
class VocosBackbone(Backbone):
|
720 |
+
"""
|
721 |
+
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
|
722 |
+
|
723 |
+
Args:
|
724 |
+
input_channels (int): Number of input features channels.
|
725 |
+
dim (int): Hidden dimension of the model.
|
726 |
+
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
|
727 |
+
num_layers (int): Number of ConvNeXtBlock layers.
|
728 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
|
729 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
730 |
+
None means non-conditional model. Defaults to None.
|
731 |
+
"""
|
732 |
+
|
733 |
+
def __init__(
|
734 |
+
self,
|
735 |
+
input_channels: int,
|
736 |
+
dim: int,
|
737 |
+
intermediate_dim: int,
|
738 |
+
num_layers: int,
|
739 |
+
layer_scale_init_value: Optional[float] = None,
|
740 |
+
adanorm_num_embeddings: Optional[int] = None,
|
741 |
+
):
|
742 |
+
super().__init__()
|
743 |
+
self.input_channels = input_channels
|
744 |
+
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
|
745 |
+
self.adanorm = adanorm_num_embeddings is not None
|
746 |
+
if adanorm_num_embeddings:
|
747 |
+
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
748 |
+
else:
|
749 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
750 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
|
751 |
+
self.convnext = nn.ModuleList(
|
752 |
+
[
|
753 |
+
ConvNeXtBlock(
|
754 |
+
dim=dim,
|
755 |
+
intermediate_dim=intermediate_dim,
|
756 |
+
layer_scale_init_value=layer_scale_init_value,
|
757 |
+
adanorm_num_embeddings=adanorm_num_embeddings,
|
758 |
+
)
|
759 |
+
for _ in range(num_layers)
|
760 |
+
]
|
761 |
+
)
|
762 |
+
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
763 |
+
self.apply(self._init_weights)
|
764 |
+
|
765 |
+
def _init_weights(self, m):
|
766 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
767 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
768 |
+
nn.init.constant_(m.bias, 0)
|
769 |
+
|
770 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
771 |
+
bandwidth_id = kwargs.get("bandwidth_id", None)
|
772 |
+
x = self.embed(x)
|
773 |
+
if self.adanorm:
|
774 |
+
assert bandwidth_id is not None
|
775 |
+
x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
|
776 |
+
else:
|
777 |
+
x = self.norm(x.transpose(1, 2))
|
778 |
+
x = x.transpose(1, 2)
|
779 |
+
for conv_block in self.convnext:
|
780 |
+
x = conv_block(x, cond_embedding_id=bandwidth_id)
|
781 |
+
x = self.final_layer_norm(x.transpose(1, 2))
|
782 |
+
return x
|
783 |
+
|
784 |
+
|
785 |
+
class VocosResNetBackbone(Backbone):
|
786 |
+
"""
|
787 |
+
Vocos backbone module built with ResBlocks.
|
788 |
+
|
789 |
+
Args:
|
790 |
+
input_channels (int): Number of input features channels.
|
791 |
+
dim (int): Hidden dimension of the model.
|
792 |
+
num_blocks (int): Number of ResBlock1 blocks.
|
793 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
|
794 |
+
"""
|
795 |
+
|
796 |
+
def __init__(
|
797 |
+
self,
|
798 |
+
input_channels,
|
799 |
+
dim,
|
800 |
+
num_blocks,
|
801 |
+
layer_scale_init_value=None,
|
802 |
+
):
|
803 |
+
super().__init__()
|
804 |
+
self.input_channels = input_channels
|
805 |
+
self.embed = weight_norm(
|
806 |
+
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
|
807 |
+
)
|
808 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
|
809 |
+
self.resnet = nn.Sequential(
|
810 |
+
*[
|
811 |
+
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
|
812 |
+
for _ in range(num_blocks)
|
813 |
+
]
|
814 |
+
)
|
815 |
+
|
816 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
817 |
+
x = self.embed(x)
|
818 |
+
x = self.resnet(x)
|
819 |
+
x = x.transpose(1, 2)
|
820 |
+
return x
|
821 |
+
|
822 |
+
|
823 |
+
class Vocos(nn.Module):
|
824 |
+
def __init__(
|
825 |
+
self,
|
826 |
+
input_channels: int = 256,
|
827 |
+
dim: int = 384,
|
828 |
+
intermediate_dim: int = 1152,
|
829 |
+
num_layers: int = 8,
|
830 |
+
adanorm_num_embeddings: int = 4,
|
831 |
+
n_fft: int = 800,
|
832 |
+
hop_size: int = 200,
|
833 |
+
padding: str = "same",
|
834 |
+
):
|
835 |
+
super().__init__()
|
836 |
+
|
837 |
+
self.backbone = VocosBackbone(
|
838 |
+
input_channels=input_channels,
|
839 |
+
dim=dim,
|
840 |
+
intermediate_dim=intermediate_dim,
|
841 |
+
num_layers=num_layers,
|
842 |
+
adanorm_num_embeddings=adanorm_num_embeddings,
|
843 |
+
)
|
844 |
+
self.head = ISTFTHead(dim, n_fft, hop_size, padding)
|
845 |
+
|
846 |
+
def forward(self, x):
|
847 |
+
x = self.backbone(x)
|
848 |
+
x = self.head(x)
|
849 |
+
|
850 |
+
return x[:, None, :]
|
models/codec/ns3_codec/README.md
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## FACodec: Speech Codec with Attribute Factorization used for NaturalSpeech 3
|
2 |
+
|
3 |
+
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/pdf/2403.03100.pdf)
|
4 |
+
[![demo](https://img.shields.io/badge/FACodec-Demo-red)](https://speechresearch.github.io/naturalspeech3/)
|
5 |
+
[![model](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-pink)](https://huggingface.co/amphion/naturalspeech3_facodec)
|
6 |
+
[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Spaces-yellow)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec)
|
7 |
+
|
8 |
+
## Overview
|
9 |
+
|
10 |
+
FACodec is a core component of the advanced text-to-speech (TTS) model NaturalSpeech 3. FACodec converts complex speech waveform into disentangled subspaces representing speech attributes of content, prosody, timbre, and acoustic details and reconstruct high-quality speech waveform from these attributes. FACodec decomposes complex speech into subspaces representing different attributes, thus simplifying the modeling of speech representation.
|
11 |
+
|
12 |
+
Research can use FACodec to develop different modes of TTS models, such as non-autoregressive based discrete diffusion (NaturalSpeech 3) or autoregressive models (like VALL-E).
|
13 |
+
|
14 |
+
<br>
|
15 |
+
<div align="center">
|
16 |
+
<img src="../../../imgs/ns3/ns3_overview.png" width="65%">
|
17 |
+
</div>
|
18 |
+
<br>
|
19 |
+
|
20 |
+
<br>
|
21 |
+
<div align="center">
|
22 |
+
<img src="../../../imgs/ns3/ns3_facodec.png" width="100%">
|
23 |
+
</div>
|
24 |
+
<br>
|
25 |
+
|
26 |
+
## Useage
|
27 |
+
|
28 |
+
Download the pre-trained FACodec model from HuggingFace: [Pretrained FACodec checkpoint](https://huggingface.co/amphion/naturalspeech3_facodec)
|
29 |
+
|
30 |
+
Install Amphion
|
31 |
+
```bash
|
32 |
+
git clone https://github.com/open-mmlab/Amphion.git
|
33 |
+
```
|
34 |
+
|
35 |
+
Few lines of code to use the pre-trained FACodec model
|
36 |
+
```python
|
37 |
+
from Amphion.models.codec.ns3_codec import FACodecEncoder, FACodecDecoder
|
38 |
+
from huggingface_hub import hf_hub_download
|
39 |
+
|
40 |
+
fa_encoder = FACodecEncoder(
|
41 |
+
ngf=32,
|
42 |
+
up_ratios=[2, 4, 5, 5],
|
43 |
+
out_channels=256,
|
44 |
+
)
|
45 |
+
|
46 |
+
fa_decoder = FACodecDecoder(
|
47 |
+
in_channels=256,
|
48 |
+
upsample_initial_channel=1024,
|
49 |
+
ngf=32,
|
50 |
+
up_ratios=[5, 5, 4, 2],
|
51 |
+
vq_num_q_c=2,
|
52 |
+
vq_num_q_p=1,
|
53 |
+
vq_num_q_r=3,
|
54 |
+
vq_dim=256,
|
55 |
+
codebook_dim=8,
|
56 |
+
codebook_size_prosody=10,
|
57 |
+
codebook_size_content=10,
|
58 |
+
codebook_size_residual=10,
|
59 |
+
use_gr_x_timbre=True,
|
60 |
+
use_gr_residual_f0=True,
|
61 |
+
use_gr_residual_phone=True,
|
62 |
+
)
|
63 |
+
|
64 |
+
encoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder.bin")
|
65 |
+
decoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder.bin")
|
66 |
+
|
67 |
+
fa_encoder.load_state_dict(torch.load(encoder_ckpt))
|
68 |
+
fa_decoder.load_state_dict(torch.load(decoder_ckpt))
|
69 |
+
|
70 |
+
fa_encoder.eval()
|
71 |
+
fa_decoder.eval()
|
72 |
+
|
73 |
+
```
|
74 |
+
|
75 |
+
Inference
|
76 |
+
```python
|
77 |
+
test_wav_path = "test.wav"
|
78 |
+
test_wav = librosa.load(test_wav_path, sr=16000)[0]
|
79 |
+
test_wav = torch.from_numpy(test_wav).float()
|
80 |
+
test_wav = test_wav.unsqueeze(0).unsqueeze(0)
|
81 |
+
|
82 |
+
with torch.no_grad():
|
83 |
+
|
84 |
+
# encode
|
85 |
+
enc_out = fa_encoder(test_wav)
|
86 |
+
print(enc_out.shape)
|
87 |
+
|
88 |
+
# quantize
|
89 |
+
vq_post_emb, vq_id, _, quantized, spk_embs = fa_decoder(enc_out, eval_vq=False, vq=True)
|
90 |
+
|
91 |
+
# latent after quantization
|
92 |
+
print(vq_post_emb.shape)
|
93 |
+
|
94 |
+
# codes
|
95 |
+
print("vq id shape:", vq_id.shape)
|
96 |
+
|
97 |
+
# get prosody code
|
98 |
+
prosody_code = vq_id[:1]
|
99 |
+
print("prosody code shape:", prosody_code.shape)
|
100 |
+
|
101 |
+
# get content code
|
102 |
+
cotent_code = vq_id[1:3]
|
103 |
+
print("content code shape:", cotent_code.shape)
|
104 |
+
|
105 |
+
# get residual code (acoustic detail codes)
|
106 |
+
residual_code = vq_id[3:]
|
107 |
+
print("residual code shape:", residual_code.shape)
|
108 |
+
|
109 |
+
# speaker embedding
|
110 |
+
print("speaker embedding shape:", spk_embs.shape)
|
111 |
+
|
112 |
+
# decode (recommand)
|
113 |
+
recon_wav = fa_decoder.inference(vq_post_emb, spk_embs)
|
114 |
+
print(recon_wav.shape)
|
115 |
+
sf.write("recon.wav", recon_wav[0][0].cpu().numpy(), 16000)
|
116 |
+
```
|
117 |
+
|
118 |
+
FACodec can achieve zero-shot voice conversion with FACodecEncoderV2/FACodecDecoderV2 or FACodecRedecoder
|
119 |
+
```python
|
120 |
+
from Amphion.models.codec.ns3_codec import FACodecEncoderV2, FACodecDecoderV2
|
121 |
+
|
122 |
+
# Same parameters as FACodecEncoder/FACodecDecoder
|
123 |
+
fa_encoder_v2 = FACodecEncoderV2(...)
|
124 |
+
fa_decoder_v2 = FACodecDecoderV2(...)
|
125 |
+
|
126 |
+
encoder_v2_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder_v2.bin")
|
127 |
+
decoder_v2_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder_v2.bin")
|
128 |
+
|
129 |
+
fa_encoder_v2.load_state_dict(torch.load(encoder_v2_ckpt))
|
130 |
+
fa_decoder_v2.load_state_dict(torch.load(decoder_v2_ckpt))
|
131 |
+
|
132 |
+
with torch.no_grad():
|
133 |
+
enc_out_a = fa_encoder_v2(wav_a)
|
134 |
+
prosody_a = fa_encoder_v2.get_prosody_feature(wav_a)
|
135 |
+
enc_out_b = fa_encoder_v2(wav_b)
|
136 |
+
prosody_b = fa_encoder_v2.get_prosody_feature(wav_b)
|
137 |
+
|
138 |
+
vq_post_emb_a, vq_id_a, _, quantized, spk_embs_a = fa_decoder_v2(
|
139 |
+
enc_out_a, prosody_a, eval_vq=False, vq=True
|
140 |
+
)
|
141 |
+
vq_post_emb_b, vq_id_b, _, quantized, spk_embs_b = fa_decoder_v2(
|
142 |
+
enc_out_b, prosody_b, eval_vq=False, vq=True
|
143 |
+
)
|
144 |
+
|
145 |
+
vq_post_emb_a_to_b = fa_decoder_v2.vq2emb(vq_id_a, use_residual=False)
|
146 |
+
recon_wav_a_to_b = fa_decoder_v2.inference(vq_post_emb_a_to_b, spk_embs_b)
|
147 |
+
```
|
148 |
+
|
149 |
+
or
|
150 |
+
|
151 |
+
```python
|
152 |
+
from Amphion.models.codec.ns3_codec import FACodecRedecoder
|
153 |
+
|
154 |
+
fa_redecoder = FACodecRedecoder()
|
155 |
+
|
156 |
+
redecoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_redecoder.bin")
|
157 |
+
|
158 |
+
fa_redecoder.load_state_dict(torch.load(redecoder_ckpt))
|
159 |
+
|
160 |
+
with torch.no_grad():
|
161 |
+
enc_out_a = fa_encoder(wav_a)
|
162 |
+
enc_out_b = fa_encoder(wav_b)
|
163 |
+
|
164 |
+
vq_post_emb_a, vq_id_a, _, quantized_a, spk_embs_a = fa_decoder(enc_out_a, eval_vq=False, vq=True)
|
165 |
+
vq_post_emb_b, vq_id_b, _, quantized_b, spk_embs_b = fa_decoder(enc_out_b, eval_vq=False, vq=True)
|
166 |
+
|
167 |
+
# convert speaker
|
168 |
+
vq_post_emb_a_to_b = fa_redecoder.vq2emb(vq_id_a, spk_embs_b, use_residual=False)
|
169 |
+
recon_wav_a_to_b = fa_redecoder.inference(vq_post_emb_a_to_b, spk_embs_b)
|
170 |
+
|
171 |
+
sf.write("recon_a_to_b.wav", recon_wav_a_to_b[0][0].cpu().numpy(), 16000)
|
172 |
+
```
|
173 |
+
|
174 |
+
## Q&A
|
175 |
+
|
176 |
+
Q1: What audio sample rate does FACodec support? What is the hop size? How many codes will be generated for each frame?
|
177 |
+
|
178 |
+
A1: FACodec supports 16KHz speech audio. The hop size is 200 samples, and (16000/200) * 6 (total number of codebooks) codes will be generated for each frame.
|
179 |
+
|
180 |
+
Q2: Is it possible to train an autoregressive TTS model like VALL-E using FACodec?
|
181 |
+
|
182 |
+
A2: Yes. In fact, the authors of NaturalSpeech 3 have already employ explore the autoregressive generative model for discrete token generation with FACodec. They use an autoregressive language model to generate prosody codes, followed by a non-autoregressive model to generate the remaining content and acoustic details codes.
|
183 |
+
|
184 |
+
Q3: Is it possible to train a latent diffusion TTS model like NaturalSpeech2 using FACodec?
|
185 |
+
|
186 |
+
A3: Yes. You can use the latent getted after quanzaition as the modelling target for the latent diffusion model.
|
187 |
+
|
188 |
+
Q4: Can FACodec compress and reconstruct audio from other domains? Such as sound effects, music, etc.
|
189 |
+
|
190 |
+
A4: Since FACodec is designed for speech, it may not be suitable for other audio domains. However, it is possible to use the FACodec model to compress and reconstruct audio from other domains, but the quality may not be as good as the original audio.
|
191 |
+
|
192 |
+
Q5: Can FACodec be used for content feature for some other tasks like voice conversion?
|
193 |
+
|
194 |
+
A5: I think the answer is yes. Researchers can use the content code of FACodec as the content feature for voice conversion. We hope to see more research in this direction.
|
195 |
+
|
196 |
+
## Citations
|
197 |
+
|
198 |
+
If you use our FACodec model, please cite the following paper:
|
199 |
+
|
200 |
+
```bibtex
|
201 |
+
@article{ju2024naturalspeech,
|
202 |
+
title={NaturalSpeech 3: Zero-Shot Speech Synthesis with Factorized Codec and Diffusion Models},
|
203 |
+
author={Ju, Zeqian and Wang, Yuancheng and Shen, Kai and Tan, Xu and Xin, Detai and Yang, Dongchao and Liu, Yanqing and Leng, Yichong and Song, Kaitao and Tang, Siliang and others},
|
204 |
+
journal={arXiv preprint arXiv:2403.03100},
|
205 |
+
year={2024}
|
206 |
+
}
|
207 |
+
|
208 |
+
@article{zhang2023amphion,
|
209 |
+
title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
|
210 |
+
author={Xueyao Zhang and Liumeng Xue and Yicheng Gu and Yuancheng Wang and Haorui He and Chaoren Wang and Xi Chen and Zihao Fang and Haopeng Chen and Junan Zhang and Tze Ying Tang and Lexiao Zou and Mingxuan Wang and Jun Han and Kai Chen and Haizhou Li and Zhizheng Wu},
|
211 |
+
journal={arXiv},
|
212 |
+
year={2024},
|
213 |
+
volume={abs/2312.09911}
|
214 |
+
}
|
215 |
+
```
|
216 |
+
|
models/codec/ns3_codec/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .facodec import *
|
models/codec/ns3_codec/alias_free_torch/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
|
3 |
+
from .filter import *
|
4 |
+
from .resample import *
|
5 |
+
from .act import *
|
models/codec/ns3_codec/alias_free_torch/act.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
from .resample import UpSample1d, DownSample1d
|
5 |
+
|
6 |
+
|
7 |
+
class Activation1d(nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
activation,
|
11 |
+
up_ratio: int = 2,
|
12 |
+
down_ratio: int = 2,
|
13 |
+
up_kernel_size: int = 12,
|
14 |
+
down_kernel_size: int = 12,
|
15 |
+
):
|
16 |
+
super().__init__()
|
17 |
+
self.up_ratio = up_ratio
|
18 |
+
self.down_ratio = down_ratio
|
19 |
+
self.act = activation
|
20 |
+
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
21 |
+
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
22 |
+
|
23 |
+
# x: [B,C,T]
|
24 |
+
def forward(self, x):
|
25 |
+
x = self.upsample(x)
|
26 |
+
x = self.act(x)
|
27 |
+
x = self.downsample(x)
|
28 |
+
|
29 |
+
return x
|
models/codec/ns3_codec/alias_free_torch/filter.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import math
|
7 |
+
|
8 |
+
if "sinc" in dir(torch):
|
9 |
+
sinc = torch.sinc
|
10 |
+
else:
|
11 |
+
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
12 |
+
# https://adefossez.github.io/julius/julius/core.html
|
13 |
+
def sinc(x: torch.Tensor):
|
14 |
+
"""
|
15 |
+
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
16 |
+
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
17 |
+
"""
|
18 |
+
return torch.where(
|
19 |
+
x == 0,
|
20 |
+
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
21 |
+
torch.sin(math.pi * x) / math.pi / x,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
26 |
+
# https://adefossez.github.io/julius/julius/lowpass.html
|
27 |
+
def kaiser_sinc_filter1d(
|
28 |
+
cutoff, half_width, kernel_size
|
29 |
+
): # return filter [1,1,kernel_size]
|
30 |
+
even = kernel_size % 2 == 0
|
31 |
+
half_size = kernel_size // 2
|
32 |
+
|
33 |
+
# For kaiser window
|
34 |
+
delta_f = 4 * half_width
|
35 |
+
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
36 |
+
if A > 50.0:
|
37 |
+
beta = 0.1102 * (A - 8.7)
|
38 |
+
elif A >= 21.0:
|
39 |
+
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
40 |
+
else:
|
41 |
+
beta = 0.0
|
42 |
+
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
43 |
+
|
44 |
+
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
45 |
+
if even:
|
46 |
+
time = torch.arange(-half_size, half_size) + 0.5
|
47 |
+
else:
|
48 |
+
time = torch.arange(kernel_size) - half_size
|
49 |
+
if cutoff == 0:
|
50 |
+
filter_ = torch.zeros_like(time)
|
51 |
+
else:
|
52 |
+
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
53 |
+
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
54 |
+
# of the constant component in the input signal.
|
55 |
+
filter_ /= filter_.sum()
|
56 |
+
filter = filter_.view(1, 1, kernel_size)
|
57 |
+
|
58 |
+
return filter
|
59 |
+
|
60 |
+
|
61 |
+
class LowPassFilter1d(nn.Module):
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
cutoff=0.5,
|
65 |
+
half_width=0.6,
|
66 |
+
stride: int = 1,
|
67 |
+
padding: bool = True,
|
68 |
+
padding_mode: str = "replicate",
|
69 |
+
kernel_size: int = 12,
|
70 |
+
):
|
71 |
+
# kernel_size should be even number for stylegan3 setup,
|
72 |
+
# in this implementation, odd number is also possible.
|
73 |
+
super().__init__()
|
74 |
+
if cutoff < -0.0:
|
75 |
+
raise ValueError("Minimum cutoff must be larger than zero.")
|
76 |
+
if cutoff > 0.5:
|
77 |
+
raise ValueError("A cutoff above 0.5 does not make sense.")
|
78 |
+
self.kernel_size = kernel_size
|
79 |
+
self.even = kernel_size % 2 == 0
|
80 |
+
self.pad_left = kernel_size // 2 - int(self.even)
|
81 |
+
self.pad_right = kernel_size // 2
|
82 |
+
self.stride = stride
|
83 |
+
self.padding = padding
|
84 |
+
self.padding_mode = padding_mode
|
85 |
+
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
86 |
+
self.register_buffer("filter", filter)
|
87 |
+
|
88 |
+
# input [B, C, T]
|
89 |
+
def forward(self, x):
|
90 |
+
_, C, _ = x.shape
|
91 |
+
|
92 |
+
if self.padding:
|
93 |
+
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
94 |
+
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
95 |
+
|
96 |
+
return out
|
models/codec/ns3_codec/alias_free_torch/resample.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from .filter import LowPassFilter1d
|
6 |
+
from .filter import kaiser_sinc_filter1d
|
7 |
+
|
8 |
+
|
9 |
+
class UpSample1d(nn.Module):
|
10 |
+
def __init__(self, ratio=2, kernel_size=None):
|
11 |
+
super().__init__()
|
12 |
+
self.ratio = ratio
|
13 |
+
self.kernel_size = (
|
14 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
15 |
+
)
|
16 |
+
self.stride = ratio
|
17 |
+
self.pad = self.kernel_size // ratio - 1
|
18 |
+
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
19 |
+
self.pad_right = (
|
20 |
+
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
21 |
+
)
|
22 |
+
filter = kaiser_sinc_filter1d(
|
23 |
+
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
24 |
+
)
|
25 |
+
self.register_buffer("filter", filter)
|
26 |
+
|
27 |
+
# x: [B, C, T]
|
28 |
+
def forward(self, x):
|
29 |
+
_, C, _ = x.shape
|
30 |
+
|
31 |
+
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
32 |
+
x = self.ratio * F.conv_transpose1d(
|
33 |
+
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
34 |
+
)
|
35 |
+
x = x[..., self.pad_left : -self.pad_right]
|
36 |
+
|
37 |
+
return x
|
38 |
+
|
39 |
+
|
40 |
+
class DownSample1d(nn.Module):
|
41 |
+
def __init__(self, ratio=2, kernel_size=None):
|
42 |
+
super().__init__()
|
43 |
+
self.ratio = ratio
|
44 |
+
self.kernel_size = (
|
45 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
46 |
+
)
|
47 |
+
self.lowpass = LowPassFilter1d(
|
48 |
+
cutoff=0.5 / ratio,
|
49 |
+
half_width=0.6 / ratio,
|
50 |
+
stride=ratio,
|
51 |
+
kernel_size=self.kernel_size,
|
52 |
+
)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
xx = self.lowpass(x)
|
56 |
+
|
57 |
+
return xx
|